import dataclasses
import itertools
import platform
import time
from typing import Optional, Tuple

from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
    WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler,
)
from model import Transformer as LLaMA
from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler

import torch
import torch._inductor.config


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# Experimental feature to reduce compilation times; expected to be on by default in a future release.
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.assert_indirect_indexing = False


@dataclasses.dataclass
class GPTModelConfig:
    name: str
    module: type
    mode: Optional[str]
    quantizer: type
    # Target numbers that the measured results are compared against.
    token_per_sec: float
    memory_bandwidth: float
    compilation_time: float


def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "cpu" in device:
        pass
    else:
        print(f"device={device} is not yet supported")


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    # Exponential race: argmax(p_i / q_i) with q_i ~ Exp(1) picks index i with
    # probability proportional to p_i, so no host-device sync is needed.
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


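# Hypothetical sanity check (not called by the benchmark): drawing many samples
# with the exponential-race trick should approximately recover the input
# distribution, matching what torch.multinomial would produce.
def _demo_sample_trick() -> None:
    probs = torch.tensor([0.1, 0.2, 0.7])
    draws = torch.stack(
        [multinomial_sample_one_no_sync(probs) for _ in range(10_000)]
    )
    # Empirical frequencies; expect roughly [0.1, 0.2, 0.7].
    print(torch.bincount(draws.flatten().long(), minlength=3).float() / 10_000)

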
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # pivot is the k-th largest logit; everything below it is masked to -inf.
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


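# Hypothetical illustration (not called by the benchmark): with top_k=2, only
# the two largest logits keep nonzero probability after the masking above.
def _demo_logits_to_probs() -> None:
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    print(logits_to_probs(logits, temperature=1.0, top_k=2))

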
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    # logits: [B, S, vocab]; sample from the distribution at the last position.
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


@torch.compile(fullgraph=True)
def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]


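# mode="reduce-overhead" captures the compiled decode step with CUDA graphs,
# which removes the per-token kernel-launch overhead.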
@torch.compile(fullgraph=True, mode="reduce-overhead")
def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    **sampling_kwargs,
):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.nn.attention.sdpa_kernel(
            torch.nn.attention.SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            # Clone: with CUDA graphs the outputs reuse the same buffers on the
            # next replay, so keep a copy of each step's result.
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs
) -> torch.Tensor:
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs
    )
    seq[T + 1 :] = torch.cat(generated_tokens)
    return seq


def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16):
    with torch.device("meta"):
        model = x.module.from_name(x.name)
    model = model.to(dtype=precision)

    if x.mode == "int8":
        print("Using int8 weight-only quantization!")
        model = x.quantizer(model).convert_for_runtime()

    # The benchmark measures speed rather than quality, so fill the model with
    # random weights instead of loading a real checkpoint.
    state_dict = model.state_dict()
    for k, v in state_dict.items():
        state_dict[k] = torch.nn.Parameter(
            torch.randn(v.shape, device=device).to(dtype=v.dtype),
            requires_grad=v.requires_grad,
        )
    model.load_state_dict(state_dict, assign=True)
    return model.eval()


# Only count activated parameters and buffers.
def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(child.parameters(), child.buffers())
            )

    # Remove the inactivated experts from the model size if this is a
    # mixture-of-experts architecture, since only activated experts are loaded.
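    # Worked example with hypothetical numbers: with 8 experts of which 2 are
    # activated per token, the subtraction below removes 6/8 of each expert
    # module's weight bytes.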
    if hasattr(model.config, "num_experts"):
        config = model.config
        for submodule in model.modules():
            if isinstance(
                submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8)
            ):
                model_size -= (
                    sum(
                        p.numel() * p.dtype.itemsize
                        for p in itertools.chain(
                            submodule.parameters(), submodule.buffers()
                        )
                    )
                    * (config.num_experts - config.num_activated_experts)
                    / config.num_experts
                )

    return model_size


def run_experiment(
    x: GPTModelConfig,
    num_samples: int = 5,
    max_new_tokens: int = 200,
    top_k: int = 200,
    temperature: float = 0.8,
    device: str = "cuda",
) -> Tuple[float, float, float]:
    print(f"Loading model {x.name}")
    t0 = time.time()
    model = _load_model(x, device=device)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    # A fixed prompt of hard-coded token ids; only its length matters for the
    # throughput measurement.
    prompt = torch.tensor(
        [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32
    )
    prompt_length = prompt.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)

    aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []}
    # Iteration i == -1 is a warm-up run that triggers compilation and is
    # excluded from the aggregated metrics.
    start = -1
    compilation_time = None

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG

        t0 = time.perf_counter()
        y = generate(
            model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
        )

        if i == -1:
            compilation_time = time.perf_counter() - t0
            print(f"Compilation time: {compilation_time:.2f} seconds")
            continue

        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        # Effective bandwidth in GB/s: each generated token reads the activated
        # weights once, so bytes moved per second is model_size * tokens/sec.
        aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9)

    token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item()
    memory_bandwidth = torch.mean(
        torch.tensor(aggregate_metrics["memory_bandwidth"])
    ).item()
    print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec")
    print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    return token_per_sec, memory_bandwidth, compilation_time


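# The three run_* functions below build identical Experiment rows. A minimal
# factoring sketch (hypothetical, not wired in below) that assumes the same
# positional Experiment fields those functions pass:
def _experiment_rows(model, token_per_sec, memory_bandwidth, compilation_time, device):
    from benchmark import Experiment

    rows = [
        ("token_per_sec", model.token_per_sec, token_per_sec),
        ("memory_bandwidth(GB/s)", model.memory_bandwidth, memory_bandwidth),
        ("compilation_time(s)", model.compilation_time, compilation_time),
    ]
    return [
        Experiment(
            model.name,
            metric,
            target,
            f"{actual:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        )
        for metric, target, actual in rows
    ]

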
# token_per_sec and memory_bandwidth targets are for an A100-40GB, which differs from the more common A100-80GB.
def run_llama2_7b_bf16(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "bfloat16",
        LLaMAWeightOnlyInt8QuantHandler,
        94,
        1253,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth targets are for an A100-40GB, which differs from the more common A100-80GB.
def run_llama2_7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "int8",
        LLaMAWeightOnlyInt8QuantHandler,
        144,
        957,
        172,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth targets are for an A100-40GB, which differs from the more common A100-80GB.
def run_mixtral_8x7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    # The original number of layers was reduced from 32 to 16 to fit CI memory limits.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "int8",
        MixtralMoEWeightOnlyInt8QuantHandler,
        175,
        1130,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]
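

if __name__ == "__main__":
    # Hypothetical standalone entry point (the functions above are normally
    # driven from benchmark.py): run one configuration and print its rows.
    for experiment in run_llama2_7b_bf16():
        print(experiment)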