xref: /aosp_15_r20/external/pytorch/torchgen/static_runtime/config.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
4
5
def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
    """Return the base operator name of ``g`` as a string.

    For a ``NativeFunctionsGroup`` this is ``base`` of its functional
    variant's name; for any other group (``NativeFunctionsViewGroup``) it is
    the view's ``root_name``.
    """
    if not isinstance(g, NativeFunctionsGroup):
        return str(g.view.root_name)
    operator_name = g.functional.func.name.name
    return str(operator_name.base)
11
12
# Base operator names (as produced by `func_name_base_str`) for which
# `is_hand_written` returns True. Kept in alphabetical order.
is_hand_written_ops_ = frozenset(
    {
        "abs", "add", "addmm", "all", "any", "argmin", "bmm",
        "clamp", "clamp_min", "cumsum", "detach", "div",
        "expand_as", "flatten", "fmod", "index_select",
        "leaky_relu", "linear", "log", "matmul", "mul",
        "narrow", "narrow_copy", "nonzero", "pow", "remainder",
        "reshape_as", "select", "sigmoid", "sign", "slice",
        "softmax", "split", "squeeze", "sub", "tanh",
        "transpose", "view", "where",
    }
)
56
57
def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
    """Return True if the base operator name of ``g`` is in ``is_hand_written_ops_``."""
    return func_name_base_str(g) in is_hand_written_ops_
61
62
# Per-op argument overrides that are the same regardless of the test-case
# `index` passed to `override_test_values`.
_INDEX_INDEPENDENT_OVERRIDES: dict[str, dict[str, str]] = {
    **dict.fromkeys(
        ("fft_fft", "fft_ifft", "fft_rfft", "fft_irfft", "fft_hfft", "fft_ihfft"),
        {"norm": '"forward"'},
    ),
    **dict.fromkeys(
        ("diagonal", "linalg_diagonal"),
        {"offset": "0", "dim1": "2", "dim2": "1"},
    ),
}

# Ops sharing one sized argument set; their `reduce` argument is patched
# separately, and only when the caller's arg_map already contains it.
_SCATTER_FAMILY = ("scatter", "scatter_add", "_scatter_reduce")

# Per-op argument overrides as an (index 0, index 1) pair. Each inner dict is
# applied with `dict.update`, so its key order is the order in which new keys
# get inserted into the caller's arg_map.
_SIZED_OVERRIDES: dict[str, tuple[dict[str, str], dict[str, str]]] = {
    "addr": (
        {"self": "at::rand({6, 6})", "vec1": "at::rand({6})", "vec2": "at::rand({6})"},
        {
            "self": "at::rand({22, 22})",
            "vec1": "at::rand({22})",
            "vec2": "at::rand({22})",
        },
    ),
    "mv": (
        {"self": "at::rand({6, 6})", "vec": "at::rand({6})"},
        {"self": "at::rand({22, 22})", "vec": "at::rand({22})"},
    ),
    "addbmm": (
        {"self": "at::rand({6, 6})"},
        {"self": "at::rand({22, 22})"},
    ),
    "cross": (
        {"self": "at::rand({3, 3, 3})", "other": "at::rand({3, 3, 3})"},
        {"self": "at::rand({22, 3, 22})", "other": "at::rand({22, 3, 22})"},
    ),
    "take": (
        {"index": "at::randint(0, 216, {20}, torch::kInt64)"},
        {"index": "at::randint(0, 1000, {100}, torch::kInt64)"},
    ),
    # NOTE(review): `self0` / `self1` presumably name tensors declared by the
    # generated test code — confirm against the test generator.
    "take_along_dim": (
        {"indices": "at::argsort(self0, 1, true)"},
        {"indices": "at::argsort(self1, 1, true)"},
    ),
    # NOTE(review): index 0 uses at::randn while index 1 uses at::rand; both
    # produce a boolean mask. Kept as-is so generated tests do not change.
    "masked_select": (
        {"mask": "at::randn({6, 6, 6}) > 0.5"},
        {"mask": "at::rand({22, 22, 22}) > 0.5"},
    ),
    **dict.fromkeys(
        ("orgqr", "ormqr"),
        ({"input2": "at::rand({6, 6})"}, {"input2": "at::rand({22, 22})"}),
    ),
    **dict.fromkeys(
        ("quantile", "nanquantile"),
        (
            {"q": "at::rand({6})", "interpolation": '"linear"'},
            {"q": "at::rand({22})", "interpolation": '"linear"'},
        ),
    ),
    "multi_margin_loss": (
        {
            "self": "at::rand({6, 6})",
            "target": "at::randint(6, {6}, torch::kInt64)",
            "weight": "at::rand({6})",
        },
        {
            "self": "at::rand({22, 22})",
            "target": "at::randint(22, {22}, torch::kInt64)",
            "weight": "at::rand({22})",
        },
    ),
    "multilabel_margin_loss": (
        {
            "self": "at::rand({6, 6})",
            "target": "at::randint(6, {6, 6}, torch::kInt64)",
        },
        {
            "self": "at::rand({22, 22})",
            "target": "at::randint(22, {22, 22}, torch::kInt64)",
        },
    ),
    "nll_loss": (
        {
            "self": "at::rand({6, 6})",
            "target": "at::randint(6, {6}, torch::kInt64)",
            "weight": "at::rand({6})",
        },
        {
            "self": "at::rand({22, 22})",
            "target": "at::randint(22, {22}, torch::kInt64)",
            "weight": "at::rand({22})",
        },
    ),
    "nll_loss2d": (
        {
            "self": "at::rand({6, 6, 6, 6})",
            "target": "at::randint(6, {6, 6, 6}, torch::kInt64)",
            "weight": "at::rand({6})",
        },
        {
            "self": "at::rand({22, 22, 22, 22})",
            "target": "at::randint(22, {22, 22, 22}, torch::kInt64)",
            "weight": "at::rand({22})",
        },
    ),
    "linalg_tensorinv": (
        {"self": "at::rand({6, 6, 6, 6})", "ind": "2"},
        {"self": "at::rand({22, 22, 22, 22})", "ind": "2"},
    ),
    "addmv": (
        {"self": "at::rand({2})", "mat": "at::rand({2, 2})", "vec": "at::rand({2})"},
        {"self": "at::rand({35})", "mat": "at::rand({35, 35})", "vec": "at::rand({35})"},
    ),
    "acosh": (
        {"self": "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"},
        {"self": "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"},
    ),
    "adaptive_max_pool2d_backward": (
        {
            "grad_output": "at::rand({2, 2, 2}, at::kFloat)",
            "self": "at::rand({2, 2, 2}, at::kFloat)",
            "indices": "at::randint(0, 1, {2, 2, 2}, at::kLong)",
        },
        {
            "grad_output": "at::rand({3, 3, 3}, at::kFloat)",
            "self": "at::rand({3, 3, 3}, at::kFloat)",
            "indices": "at::randint(0, 1, {3, 3, 3}, at::kLong)",
        },
    ),
    "adaptive_max_pool3d_backward": (
        {
            "grad_output": "at::rand({2, 2, 2, 2}, at::kFloat)",
            "self": "at::rand({2, 2, 2, 2}, at::kFloat)",
            "indices": "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)",
        },
        {
            "grad_output": "at::rand({3, 3, 3, 3}, at::kFloat)",
            "self": "at::rand({3, 3, 3, 3}, at::kFloat)",
            "indices": "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)",
        },
    ),
    "bitwise_left_shift": (
        {
            "self": "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)",
            "other": "at::randint(1, 26, {6, 6, 6}, at::kInt)",
        },
        {
            "self": "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)",
            "other": "at::randint(1, 26, {22, 22, 22}, at::kInt)",
        },
    ),
    "bitwise_right_shift": (
        {
            "self": "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)",
            "other": "at::randint(1, 22, {6, 6, 6}, at::kInt)",
        },
        {
            "self": "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)",
            "other": "at::randint(1, 22, {22, 22, 22}, at::kInt)",
        },
    ),
    "gather": (
        {
            "self": "at::randint(1, 100, {2,2,2}, at::kInt)",
            "dim": "1",
            "index": "at::randint(0, 1, {2,2,2}, torch::kInt64)",
            "sparse_grad": "false",
        },
        {
            "self": "at::randint(1, 100, {5,5,5}, at::kInt)",
            "dim": "1",
            "index": "at::randint(0, 4, {5,5,5}, torch::kInt64)",
            "sparse_grad": "false",
        },
    ),
    "gelu": (
        {"self": "at::rand({6, 6, 6})", "approximate": '"tanh"'},
        {"self": "at::rand({22, 22, 22})", "approximate": '"tanh"'},
    ),
    "gelu_backward": (
        {
            "grad_output": "at::rand({6, 6, 6})",
            "self": "at::rand({6, 6, 6})",
            "approximate": '"tanh"',
        },
        {
            "grad_output": "at::rand({22, 22, 22})",
            "self": "at::rand({22, 22, 22})",
            "approximate": '"tanh"',
        },
    ),
    "index_add": (
        {
            "self": "at::rand({2})",
            "dim": "0",
            "index": "at::randint(0, 1, {2}, at::kInt)",
            "source": "at::rand({2})",
            "alpha": "2",
        },
        {
            "self": "at::rand({16})",
            "dim": "0",
            "index": "at::randint(0, 10, {16}, at::kInt)",
            "source": "at::rand({16})",
            "alpha": "2",
        },
    ),
    "index_copy": (
        {
            "self": "at::rand({2})",
            "dim": "0",
            "index": "at::randint(0, 1, {2}, at::kLong)",
            "source": "at::rand({2})",
        },
        {
            "self": "at::rand({32})",
            "dim": "0",
            "index": "at::randint(0, 10, {32}, at::kLong)",
            "source": "at::rand({32})",
        },
    ),
    "linalg_cross": (
        {"self": "at::rand({6, 3, 6})", "other": "at::rand({6, 3, 6})", "dim": "1"},
        {"self": "at::rand({22, 3, 22})", "other": "at::rand({22, 3, 22})", "dim": "1"},
    ),
    "nll_loss_backward": (
        {
            "grad_output": "at::rand({})",
            "self": "at::rand({6})",
            "target": "at::randint(0, 5, {6}, torch::kInt64)",
            "weight": "at::rand({6})",
            "reduction": "1",
            "ignore_index": "1",
            "total_weight": "at::rand({})",
        },
        {
            "grad_output": "at::rand({})",
            "self": "at::rand({36})",
            "target": "at::randint(0, 11, {36}, torch::kInt64)",
            "weight": "at::rand({36})",
            "reduction": "1",
            "ignore_index": "1",
            "total_weight": "at::rand({})",
        },
    ),
    **dict.fromkeys(
        _SCATTER_FAMILY,
        (
            {
                "self": "at::randint(1, 100, {2,2,2}, torch::kInt64)",
                "index": "at::randint(0, 1, {2,2,2}, torch::kInt64)",
                "src": "at::randint(1, 100, {2,2,2}, torch::kInt64)",
            },
            {
                "self": "at::randint(1, 100, {5,5,5}, torch::kInt64)",
                "index": "at::randint(0, 1, {5,5,5}, torch::kInt64)",
                "src": "at::randint(1, 100, {5,5,5}, torch::kInt64)",
            },
        ),
    ),
    # `reduce` is listed before `index` so it is assigned first.
    "scatter_reduce": (
        {"reduce": '"mean"', "index": "at::randint(6, {6, 6, 6}, torch::kInt64)"},
        {"reduce": '"mean"', "index": "at::randint(22, {22, 22, 22}, torch::kInt64)"},
    ),
    "special_zeta": (
        {
            "self": "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})",
            "other": "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})",
        },
        {
            "self": "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})",
            "other": "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})",
        },
    ),
    "_convert_indices_from_csr_to_coo": (
        {
            "crow_indices": "torch::tensor({1}, torch::kInt32)",
            "col_indices": "torch::tensor({0, 1, 0}, torch::kInt32)",
            "out_int32": "false",
        },
        {
            "crow_indices": "torch::tensor({0}, torch::kInt32)",
            "col_indices": "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)",
            "out_int32": "false",
        },
    ),
    "_convert_indices_from_coo_to_csr": (
        {"self": "at::randint(0, 3, {2}, at::kInt)", "size": "10", "out_int32": "false"},
        {"self": "at::randint(0, 3, {12}, at::kInt)", "size": "24", "out_int32": "false"},
    ),
}


def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
    """Overwrite entries of ``arg_map`` with hand-picked values for ``op_name``.

    ``arg_map`` maps argument names (e.g. ``"self"``) to C++ expression
    strings (e.g. ``"at::rand({6, 6})"``) and is updated in place. Ops with no
    override entry leave ``arg_map`` untouched.

    Args:
        arg_map: Argument-name -> C++ expression mapping, mutated in place.
        op_name: Operator name used to select the overrides.
        index: Which of the two override sets to apply, 0 or 1. In the tables
            above, index 0 uses smaller tensors than index 1.

    Raises:
        ValueError: If ``index`` is not 0 or 1.
    """
    # An explicit check, unlike `assert`, still runs under `python -O`, where an
    # out-of-range index would otherwise silently pick the index-1 overrides.
    if index not in (0, 1):
        raise ValueError(f"index must be 0 or 1, got {index!r}")

    shared = _INDEX_INDEPENDENT_OVERRIDES.get(op_name)
    if shared is not None:
        arg_map.update(shared)
        return

    sized = _SIZED_OVERRIDES.get(op_name)
    if sized is None:
        return
    # `dict.update` copies values, so the shared table is never mutated.
    arg_map.update(sized[0] if index == 0 else sized[1])

    # `reduce` is only overridden when the caller's arg_map already has it.
    if op_name in _SCATTER_FAMILY and "reduce" in arg_map:
        arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
388        return
389