# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
import argparse

from typing import Optional, Tuple

import torch

from executorch.examples.models.llama.experimental.load_gguf_q4_0 import load_gguf_q4_0
from sentencepiece import SentencePieceProcessor


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    # Temperature scaling, then optional top-k filtering before softmax.
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    # Sample the next token from the logits at the last position.
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def encode_tokens(tokenizer, string, bos=True, device="cpu"):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)


def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = model(x)
    return sample(logits, **sampling_kwargs)


def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # Run the model on the full prompt and return only the first sampled token.
    return decode_one_token(model, x, **sampling_kwargs)[0]


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    print(f"cur_token: {cur_token}")
    new_tokens, new_probs = [], []
    for _ in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token.view(1, -1), **sampling_kwargs
            )
            new_tokens.append(next_token.clone())
            # print(next_token)
            callback(next_token)
            new_probs.append(next_prob.clone())
            # No KV cache: append the new token and re-feed the whole sequence next step.
            cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)
            # print(cur_token)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    T = prompt.size(0)
    T_new = T + max_new_tokens
    # if interactive:
    #     max_seq_length = 350
    # else:
    #     max_seq_length = min(T_new, model.params.max_seq_len)

    device, dtype = prompt.device, prompt.dtype

    # with torch.device(device):
    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    # input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
    seq[T] = next_token
    callback(next_token)

    cur_tokens = torch.cat((prompt, next_token), dim=0)
    # input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model,
        cur_tokens.view(1, -1),
        # input_pos,
        max_new_tokens - 1,
        callback=callback,
        **sampling_kwargs,
    )
    seq[T + 1 :] = torch.cat(generated_tokens)

    return seq


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gguf_file",
        type=str,
        help="The GGUF file to load.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        help="The tokenizer.model path.",
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
    )

    args = parser.parse_args()

    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
    encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")

    pt_model = load_gguf_q4_0(args.gguf_file)

    max_new_tokens = 100
    buffer = [tokenizer.decode(encoded.tolist())]
    period_id = tokenizer.encode(".")[0]
    done_generating = False

    def callback(x):
        nonlocal done_generating
        if done_generating:
            return
        # Prepend a period so SentencePiece keeps the leading whitespace, then strip it.
        buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
        if x.item() == tokenizer.eos_id():
            done_generating = True
        if len(buffer) == 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    generate(
        pt_model,
        encoded,
        max_new_tokens,
        interactive=False,
        callback=callback,
        temperature=1.0,
        top_k=10,
    )


if __name__ == "__main__":
    main()