import dataclasses
import itertools
import platform
import time
from typing import Optional, Tuple

from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
    WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler,
)
from model import Transformer as LLaMA
from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler

import torch
import torch._inductor.config


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.assert_indirect_indexing = False


@dataclasses.dataclass
class GPTModelConfig:
    """Description of one benchmarked model together with its target numbers."""

    # Model name; handed to ``module.from_name`` when the model is built.
    name: str
    # Model class; must expose a ``from_name(name)`` constructor.
    module: type
    # Precision/quantization label; the value "int8" enables weight-only
    # int8 quantization when the model is loaded.
    mode: Optional[str]
    # Quantization handler class; instantiated with the model and asked to
    # ``convert_for_runtime()`` when ``mode == "int8"``.
    quantizer: type
    # Target values reported next to the measured values.
    token_per_sec: float
    memory_bandwidth: float
    compilation_time: float


def device_sync(device):
    """Block until pending work on ``device`` has finished.

    Args:
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or a
            ``torch.device``.

    CUDA devices are synchronized, CPU needs no synchronization, and any other
    device only produces a printed notice.
    """
    # str() lets torch.device objects take the same path as plain strings.
    device_name = str(device)
    if "cuda" in device_name:
        torch.cuda.synchronize(device)
    elif "cpu" in device_name:
        pass
    else:
        print(f"device={device} is not yet supported")


def get_arch_name() -> str:
    """Return the CUDA device name if CUDA is available, else the CPU architecture."""
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Sample one index along the last dimension of ``probs_sort``.

    Exponential(1) noise is drawn with the same shape as ``probs_sort`` and the
    argmax of ``probs_sort / noise`` is returned, keeping the reduced dimension
    and converting the index to ``torch.int``.
    """
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Divisor applied to the logits; clamped to at least 1e-5 so
            a temperature of 0 does not divide by zero.
        top_k: When given, every logit strictly below the k-th largest one is
            set to -inf before the softmax. ``k`` is capped at the size of the
            last dimension.

    Returns:
        Softmax of the (scaled, optionally filtered) logits along dim -1.
    """
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # The smallest of the retained values acts as the cut-off threshold.
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits of the last position.

    Only ``logits[0, -1]`` (first batch entry, last position) is used.

    Returns:
        A tuple ``(idx_next, probs)``: the sampled index and the probabilities
        it was drawn from.
    """
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


@torch.compile(fullgraph=True)
def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    """Run the model over the whole prompt and sample the first new token.

    Args:
        model: Module called as ``model(x, input_pos)``.
        x: Prompt tokens.
        input_pos: Positions of the prompt tokens.
        **sampling_kwargs: Forwarded to ``sample`` (``temperature``, ``top_k``).

    Returns:
        The sampled token index (the probabilities are discarded).
    """
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]

85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker@torch.compile(fullgraph=True, mode="reduce-overhead") 87*da0073e9SAndroid Build Coastguard Workerdef decode_one_token( 88*da0073e9SAndroid Build Coastguard Worker model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs 89*da0073e9SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, torch.Tensor]: 90*da0073e9SAndroid Build Coastguard Worker # input_pos: [B, 1] 91*da0073e9SAndroid Build Coastguard Worker assert input_pos.shape[-1] == 1 92*da0073e9SAndroid Build Coastguard Worker logits = model(x, input_pos) 93*da0073e9SAndroid Build Coastguard Worker return sample(logits, **sampling_kwargs) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Workerdef decode_n_tokens( 97*da0073e9SAndroid Build Coastguard Worker model: torch.nn.Module, 98*da0073e9SAndroid Build Coastguard Worker cur_token: torch.Tensor, 99*da0073e9SAndroid Build Coastguard Worker input_pos: torch.Tensor, 100*da0073e9SAndroid Build Coastguard Worker num_new_tokens: int, 101*da0073e9SAndroid Build Coastguard Worker **sampling_kwargs, 102*da0073e9SAndroid Build Coastguard Worker): 103*da0073e9SAndroid Build Coastguard Worker new_tokens, new_probs = [], [] 104*da0073e9SAndroid Build Coastguard Worker for i in range(num_new_tokens): 105*da0073e9SAndroid Build Coastguard Worker with torch.nn.attention.sdpa_kernel( 106*da0073e9SAndroid Build Coastguard Worker torch.nn.attention.SDPBackend.MATH 107*da0073e9SAndroid Build Coastguard Worker ): # Actually better for Inductor to codegen attention here 108*da0073e9SAndroid Build Coastguard Worker next_token, next_prob = decode_one_token( 109*da0073e9SAndroid Build Coastguard Worker model, cur_token, input_pos, **sampling_kwargs 110*da0073e9SAndroid Build Coastguard Worker ) 111*da0073e9SAndroid Build Coastguard Worker input_pos += 1 112*da0073e9SAndroid Build Coastguard 
Worker new_tokens.append(next_token.clone()) 113*da0073e9SAndroid Build Coastguard Worker new_probs.append(next_prob.clone()) 114*da0073e9SAndroid Build Coastguard Worker cur_token = next_token.view(1, -1) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker return new_tokens, new_probs 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker@torch.no_grad() 120*da0073e9SAndroid Build Coastguard Workerdef generate( 121*da0073e9SAndroid Build Coastguard Worker model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs 122*da0073e9SAndroid Build Coastguard Worker) -> torch.Tensor: 123*da0073e9SAndroid Build Coastguard Worker device, dtype = prompt.device, prompt.dtype 124*da0073e9SAndroid Build Coastguard Worker T = prompt.size(0) 125*da0073e9SAndroid Build Coastguard Worker T_new = T + max_new_tokens 126*da0073e9SAndroid Build Coastguard Worker max_seq_length = min(T_new, model.config.block_size) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker with torch.device(device): 129*da0073e9SAndroid Build Coastguard Worker model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker # create an empty tensor of the expected final shape and fill in the current tokens 132*da0073e9SAndroid Build Coastguard Worker empty = torch.empty(T_new, dtype=dtype, device=device) 133*da0073e9SAndroid Build Coastguard Worker empty[:T] = prompt 134*da0073e9SAndroid Build Coastguard Worker seq = empty 135*da0073e9SAndroid Build Coastguard Worker input_pos = torch.arange(0, T, device=device) 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) 138*da0073e9SAndroid Build Coastguard Worker seq[T] = next_token 
139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker input_pos = torch.tensor([T], device=device, dtype=torch.int) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker generated_tokens, _ = decode_n_tokens( 143*da0073e9SAndroid Build Coastguard Worker model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs 144*da0073e9SAndroid Build Coastguard Worker ) 145*da0073e9SAndroid Build Coastguard Worker seq[T + 1 :] = torch.cat(generated_tokens) 146*da0073e9SAndroid Build Coastguard Worker return seq 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Workerdef _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16): 150*da0073e9SAndroid Build Coastguard Worker with torch.device("meta"): 151*da0073e9SAndroid Build Coastguard Worker model = x.module.from_name(x.name) 152*da0073e9SAndroid Build Coastguard Worker model = model.to(dtype=precision) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker if x.mode == "int8": 155*da0073e9SAndroid Build Coastguard Worker print("Using int8 weight-only quantization!") 156*da0073e9SAndroid Build Coastguard Worker model = x.quantizer(model).convert_for_runtime() 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker state_dict = model.state_dict() 159*da0073e9SAndroid Build Coastguard Worker for k, v in state_dict.items(): 160*da0073e9SAndroid Build Coastguard Worker state_dict[k] = torch.nn.Parameter( 161*da0073e9SAndroid Build Coastguard Worker torch.randn(v.shape, device=device).to(dtype=v.dtype), 162*da0073e9SAndroid Build Coastguard Worker requires_grad=v.requires_grad, 163*da0073e9SAndroid Build Coastguard Worker ) 164*da0073e9SAndroid Build Coastguard Worker model.load_state_dict(state_dict, assign=True) 165*da0073e9SAndroid Build Coastguard Worker return model.eval() 
166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker# Only count activated parameters and buffers. 169*da0073e9SAndroid Build Coastguard Workerdef _get_model_size(model): 170*da0073e9SAndroid Build Coastguard Worker model_size = 0 171*da0073e9SAndroid Build Coastguard Worker for name, child in model.named_children(): 172*da0073e9SAndroid Build Coastguard Worker if not isinstance(child, torch.nn.Embedding): 173*da0073e9SAndroid Build Coastguard Worker model_size += sum( 174*da0073e9SAndroid Build Coastguard Worker p.numel() * p.dtype.itemsize 175*da0073e9SAndroid Build Coastguard Worker for p in itertools.chain(child.parameters(), child.buffers()) 176*da0073e9SAndroid Build Coastguard Worker ) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker # Remove the inactivated experts from the model size if this is mixture of experts 179*da0073e9SAndroid Build Coastguard Worker # architecture, since only activated experts are loaded. 
180*da0073e9SAndroid Build Coastguard Worker if hasattr(model.config, "num_experts"): 181*da0073e9SAndroid Build Coastguard Worker config = model.config 182*da0073e9SAndroid Build Coastguard Worker for submodule in model.modules(): 183*da0073e9SAndroid Build Coastguard Worker if isinstance( 184*da0073e9SAndroid Build Coastguard Worker submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8) 185*da0073e9SAndroid Build Coastguard Worker ): 186*da0073e9SAndroid Build Coastguard Worker model_size -= ( 187*da0073e9SAndroid Build Coastguard Worker sum( 188*da0073e9SAndroid Build Coastguard Worker p.numel() * p.dtype.itemsize 189*da0073e9SAndroid Build Coastguard Worker for p in itertools.chain( 190*da0073e9SAndroid Build Coastguard Worker submodule.parameters(), child.buffers() 191*da0073e9SAndroid Build Coastguard Worker ) 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker * (config.num_experts - config.num_activated_experts) 194*da0073e9SAndroid Build Coastguard Worker / config.num_experts 195*da0073e9SAndroid Build Coastguard Worker ) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker return model_size 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Workerdef run_experiment( 201*da0073e9SAndroid Build Coastguard Worker x: GPTModelConfig, 202*da0073e9SAndroid Build Coastguard Worker num_samples: int = 5, 203*da0073e9SAndroid Build Coastguard Worker max_new_tokens: int = 200, 204*da0073e9SAndroid Build Coastguard Worker top_k: int = 200, 205*da0073e9SAndroid Build Coastguard Worker temperature: float = 0.8, 206*da0073e9SAndroid Build Coastguard Worker device: str = "cuda", 207*da0073e9SAndroid Build Coastguard Worker) -> None: 208*da0073e9SAndroid Build Coastguard Worker print(f"Loading model {x.name}") 209*da0073e9SAndroid Build Coastguard Worker t0 = time.time() 210*da0073e9SAndroid Build Coastguard 
Worker model = _load_model(x, device=device) 211*da0073e9SAndroid Build Coastguard Worker device_sync(device=device) # MKG 212*da0073e9SAndroid Build Coastguard Worker print(f"Time to load model: {time.time() - t0:.02f} seconds") 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker prompt = torch.tensor( 215*da0073e9SAndroid Build Coastguard Worker [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32 216*da0073e9SAndroid Build Coastguard Worker ) 217*da0073e9SAndroid Build Coastguard Worker prompt_length = prompt.size(0) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(1234) 220*da0073e9SAndroid Build Coastguard Worker model_size = _get_model_size(model) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} 223*da0073e9SAndroid Build Coastguard Worker start = -1 224*da0073e9SAndroid Build Coastguard Worker compilation_time = None 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker for i in range(start, num_samples): 227*da0073e9SAndroid Build Coastguard Worker device_sync(device=device) # MKG 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker t0 = time.perf_counter() 230*da0073e9SAndroid Build Coastguard Worker y = generate( 231*da0073e9SAndroid Build Coastguard Worker model, prompt, max_new_tokens, temperature=temperature, top_k=top_k 232*da0073e9SAndroid Build Coastguard Worker ) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker if i == -1: 235*da0073e9SAndroid Build Coastguard Worker compilation_time = time.perf_counter() - t0 236*da0073e9SAndroid Build Coastguard Worker print(f"Compilation time: {compilation_time:.2f} seconds") 237*da0073e9SAndroid Build Coastguard Worker continue 238*da0073e9SAndroid Build Coastguard Worker 
239*da0073e9SAndroid Build Coastguard Worker device_sync(device=device) # MKG 240*da0073e9SAndroid Build Coastguard Worker t = time.perf_counter() - t0 241*da0073e9SAndroid Build Coastguard Worker tokens_generated = y.size(0) - prompt_length 242*da0073e9SAndroid Build Coastguard Worker tokens_sec = tokens_generated / t 243*da0073e9SAndroid Build Coastguard Worker aggregate_metrics["tokens_per_sec"].append(tokens_sec) 244*da0073e9SAndroid Build Coastguard Worker aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() 247*da0073e9SAndroid Build Coastguard Worker memory_bandwidth = torch.mean( 248*da0073e9SAndroid Build Coastguard Worker torch.tensor(aggregate_metrics["memory_bandwidth"]) 249*da0073e9SAndroid Build Coastguard Worker ).item() 250*da0073e9SAndroid Build Coastguard Worker print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") 251*da0073e9SAndroid Build Coastguard Worker print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") 252*da0073e9SAndroid Build Coastguard Worker print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 253*da0073e9SAndroid Build Coastguard Worker return token_per_sec, memory_bandwidth, compilation_time 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. 
def run_llama2_7b_bf16(device: str = "cuda"):
    """Benchmark Llama-2-7b in bfloat16 and return one Experiment per metric.

    Args:
        device: Device string the benchmark runs on.

    Returns:
        A list of three ``Experiment`` records (tokens per second, memory
        bandwidth, compilation time), each pairing the target value from the
        model config with the measured value formatted to two decimals.
    """
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "bfloat16",
        LLaMAWeightOnlyInt8QuantHandler,
        94,
        1253,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    arch_name = get_arch_name()
    # (metric name, target value, measured value)
    measurements = (
        ("token_per_sec", model.token_per_sec, token_per_sec),
        ("memory_bandwidth(GB/s)", model.memory_bandwidth, memory_bandwidth),
        ("compilation_time(s)", model.compilation_time, compilation_time),
    )
    return [
        Experiment(
            model.name,
            metric,
            target,
            f"{actual:.02f}",
            model.mode,
            device,
            arch_name,
            True,
        )
        for metric, target, actual in measurements
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_int8(device: str = "cuda"):
    """Benchmark Llama-2-7b with int8 weight-only quantization.

    Args:
        device: Device string the benchmark runs on.

    Returns:
        A list of three ``Experiment`` records (tokens per second, memory
        bandwidth, compilation time), each pairing the target value from the
        model config with the measured value formatted to two decimals.
    """
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "int8",
        LLaMAWeightOnlyInt8QuantHandler,
        144,
        957,
        172,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    arch_name = get_arch_name()
    # (metric name, target value, measured value)
    measurements = (
        ("token_per_sec", model.token_per_sec, token_per_sec),
        ("memory_bandwidth(GB/s)", model.memory_bandwidth, memory_bandwidth),
        ("compilation_time(s)", model.compilation_time, compilation_time),
    )
    return [
        Experiment(
            model.name,
            metric,
            target,
            f"{actual:.02f}",
            model.mode,
            device,
            arch_name,
            True,
        )
        for metric, target, actual in measurements
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_int8(device: str = "cuda"):
    """Benchmark Mixtral-8x7B with int8 weight-only quantization.

    Args:
        device: Device string the benchmark runs on.

    Returns:
        A list of three ``Experiment`` records (tokens per second, memory
        bandwidth, compilation time), each pairing the target value from the
        model config with the measured value formatted to two decimals.
    """
    from benchmark import Experiment

    # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "int8",
        MixtralMoEWeightOnlyInt8QuantHandler,
        175,
        1130,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    arch_name = get_arch_name()
    # (metric name, target value, measured value)
    measurements = (
        ("token_per_sec", model.token_per_sec, token_per_sec),
        ("memory_bandwidth(GB/s)", model.memory_bandwidth, memory_bandwidth),
        ("compilation_time(s)", model.compilation_time, compilation_time),
    )
    return [
        Experiment(
            model.name,
            metric,
            target,
            f"{actual:.02f}",
            model.mode,
            device,
            arch_name,
            True,
        )
        for metric, target, actual in measurements
    ]