# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
import argparse

from typing import Optional, Tuple

import torch

from executorch.examples.models.llama.experimental.load_gguf_q4_0 import load_gguf_q4_0
from sentencepiece import SentencePieceProcessor


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

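# Added explanatory note (not in the upstream gpt-fast source): the "exponential
# race" above is a standard Gumbel-max-style trick. Dividing each probability by an
# independent Exp(1) draw and taking the argmax yields an index distributed exactly
# like torch.multinomial(probs, 1), while avoiding the device-to-host synchronization
# that torch.multinomial triggers on CUDA. A small illustrative sketch:
#
#   probs = torch.tensor([0.1, 0.2, 0.7])
#   q = torch.empty_like(probs).exponential_(1)
#   idx = torch.argmax(probs / q)  # ~ Categorical(probs), no sync point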

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

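# Added worked example (illustrative, not in the original): with temperature=0.5 and
# top_k=2, logits [2.0, 1.0, 0.5] are rescaled to [4.0, 2.0, 1.0]; everything below
# the 2nd-largest value (2.0) is masked to -inf, and softmax then spreads all
# probability over the two surviving tokens:
#
#   logits_to_probs(torch.tensor([2.0, 1.0, 0.5]), temperature=0.5, top_k=2)
#   # -> tensor([0.8808, 0.1192, 0.0000])  (approximately)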

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

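# Added note: `sample` assumes `logits` is shaped (batch, seq_len, vocab_size); only
# batch element 0 is used, and `logits[0, -1]` selects the distribution over the next
# token following the last position of the input sequence.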

def encode_tokens(tokenizer, string, bos=True, device="cpu"):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

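# Usage sketch (illustrative; the exact ids depend on the tokenizer model file):
#
#   sp = SentencePieceProcessor(model_file="tokenizer.model")  # placeholder path
#   ids = encode_tokens(sp, "Hello, my name is", bos=True)
#   # ids is a 1-D torch.int tensor whose first element is sp.bos_id()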

def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = model(x)
    return sample(logits, **sampling_kwargs)


def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    return decode_one_token(model, x, **sampling_kwargs)[0]

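# Added note: in this experimental cache-free path, `prefill` simply runs one forward
# pass over the full prompt and keeps only the sampled next token (element [0] of the
# (token, probs) tuple returned by decode_one_token); there is no KV cache to set up.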

def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    print(f"cur_token: {cur_token}")
    new_tokens, new_probs = [], []
    for _ in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token.view(1, -1), **sampling_kwargs
            )
            new_tokens.append(next_token.clone())
            # print(next_token)
            callback(next_token)
            new_probs.append(next_prob.clone())
            cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)
            # print(cur_token)

    return new_tokens, new_probs

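# Added note: since the KV-cache/`input_pos` plumbing is commented out in generate()
# below, every iteration re-runs the model over the entire sequence produced so far
# (`cur_token` grows by one token per step), so the per-step cost grows with sequence
# length. The sdp_kernel context just pins scaled-dot-product attention to the math
# backend.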

@torch.no_grad()
def generate(
    model: torch.nn.Module,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # prompt length and final sequence length (prompt + newly generated tokens)
    T = prompt.size(0)
    T_new = T + max_new_tokens
    # if interactive:
    #     max_seq_length = 350
    # else:
    #     max_seq_length = min(T_new, model.params.max_seq_len)

    device, dtype = prompt.device, prompt.dtype

    # with torch.device(device):
    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    # input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
    seq[T] = next_token
    callback(next_token)

    cur_tokens = torch.cat((prompt, next_token), dim=0)
    # input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model,
        cur_tokens.view(1, -1),
        # input_pos,
        max_new_tokens - 1,
        callback=callback,
        **sampling_kwargs,
    )
    seq[T + 1 :] = torch.cat(generated_tokens)

    return seq

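# Usage sketch (illustrative; `pt_model` and `tokenizer` are built in main() below):
#
#   prompt_ids = encode_tokens(tokenizer, "Hello, my name is", bos=True)
#   out = generate(pt_model, prompt_ids, 20, interactive=False, temperature=0.8, top_k=40)
#   print(tokenizer.decode(out.tolist()))
#
# `out` has length len(prompt_ids) + 20: the prompt ids followed by the generated ids.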

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gguf_file",
        type=str,
        required=True,
        help="The GGUF file to load.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="The tokenizer.model path.",
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
    )

    args = parser.parse_args()

    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
    encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")

    pt_model = load_gguf_q4_0(args.gguf_file)

    max_new_tokens = 100
    buffer = [tokenizer.decode(encoded.tolist())]
    period_id = tokenizer.encode(".")[0]
    done_generating = False

    def callback(x):
        nonlocal done_generating
        if done_generating:
            return
        # Decode with a leading period token, then drop its text, so the new piece
        # keeps its leading whitespace (SentencePiece strips it when decoding alone).
        buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
        if x.item() == tokenizer.eos_id():
            done_generating = True
        if len(buffer) == 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    generate(
        pt_model,
        encoded,
        max_new_tokens,
        interactive=False,
        callback=callback,
        temperature=1.0,
        top_k=10,
    )


if __name__ == "__main__":
    main()
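
# Example invocation (paths below are placeholders for local artifacts):
#
#   python generate.py \
#       --gguf_file path/to/model-q4_0.gguf \
#       --tokenizer_path path/to/tokenizer.model \
#       --prompt "Hello, my name is"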