from __future__ import annotations

from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup


def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
    """Return the base name of an operator group as a plain string.

    For a ``NativeFunctionsGroup`` the name is read from the group's
    functional variant; for anything else the view's ``root_name`` is used.
    """
    if not isinstance(g, NativeFunctionsGroup):
        return str(g.view.root_name)
    return str(g.functional.func.name.name.base)


# Base names that `is_hand_written` reports as hand-written.
is_hand_written_ops_ = frozenset(
    {
        "abs",
        "add",
        "addmm",
        "all",
        "any",
        "argmin",
        "bmm",
        "clamp",
        "clamp_min",
        "cumsum",
        "div",
        "fmod",
        "index_select",
        "leaky_relu",
        "linear",
        "log",
        "matmul",
        "mul",
        "narrow_copy",
        "nonzero",
        "pow",
        "remainder",
        "sigmoid",
        "sign",
        "sub",
        "tanh",
        "detach",
        "expand_as",
        "flatten",
        "narrow",
        "reshape_as",
        "select",
        "slice",
        "softmax",
        "split",
        "squeeze",
        "transpose",
        "view",
        "where",
    }
)


def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
    """Tell whether the group's base name is listed in `is_hand_written_ops_`."""
    return func_name_base_str(g) in is_hand_written_ops_


def _same_for_both(values: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]:
    """Build a table entry whose overrides do not depend on the index."""
    return (values, values)


# Entries shared verbatim by more than one operator.
_QR_INPUT2 = (
    {"input2": "at::rand({6, 6})"},
    {"input2": "at::rand({22, 22})"},
)
_QUANTILE = (
    {"q": "at::rand({6})", "interpolation": '"linear"'},
    {"q": "at::rand({22})", "interpolation": '"linear"'},
)
_CLASS_LOSS = (
    {
        "self": "at::rand({6, 6})",
        "target": "at::randint(6, {6}, torch::kInt64)",
        "weight": "at::rand({6})",
    },
    {
        "self": "at::rand({22, 22})",
        "target": "at::randint(22, {22}, torch::kInt64)",
        "weight": "at::rand({22})",
    },
)
_FFT_NORM = _same_for_both({"norm": '"forward"'})
_SCATTER = (
    {
        "self": "at::randint(1, 100, {2,2,2}, torch::kInt64)",
        "index": "at::randint(0, 1, {2,2,2}, torch::kInt64)",
        "src": "at::randint(1, 100, {2,2,2}, torch::kInt64)",
    },
    {
        "self": "at::randint(1, 100, {5,5,5}, torch::kInt64)",
        "index": "at::randint(0, 1, {5,5,5}, torch::kInt64)",
        "src": "at::randint(1, 100, {5,5,5}, torch::kInt64)",
    },
)
_DIAGONAL = _same_for_both({"offset": "0", "dim1": "2", "dim2": "1"})

# Operators whose "reduce" argument, when already present in the argument
# map, is additionally overwritten after the table lookup.
_SCATTER_OPS = ("scatter", "scatter_add", "_scatter_reduce")

# op name -> (overrides applied when index == 0, overrides applied otherwise).
# The key order inside each dict is the order in which the entries are written
# into the caller's map.
_TEST_VALUE_OVERRIDES: dict[str, tuple[dict[str, str], dict[str, str]]] = {
    "addr": (
        {
            "self": "at::rand({6, 6})",
            "vec1": "at::rand({6})",
            "vec2": "at::rand({6})",
        },
        {
            "self": "at::rand({22, 22})",
            "vec1": "at::rand({22})",
            "vec2": "at::rand({22})",
        },
    ),
    "mv": (
        {"self": "at::rand({6, 6})", "vec": "at::rand({6})"},
        {"self": "at::rand({22, 22})", "vec": "at::rand({22})"},
    ),
    "addbmm": (
        {"self": "at::rand({6, 6})"},
        {"self": "at::rand({22, 22})"},
    ),
    "cross": (
        {"self": "at::rand({3, 3, 3})", "other": "at::rand({3, 3, 3})"},
        {"self": "at::rand({22, 3, 22})", "other": "at::rand({22, 3, 22})"},
    ),
    "take": (
        {"index": "at::randint(0, 216, {20}, torch::kInt64)"},
        {"index": "at::randint(0, 1000, {100}, torch::kInt64)"},
    ),
    "take_along_dim": (
        {"indices": "at::argsort(self0, 1, true)"},
        {"indices": "at::argsort(self1, 1, true)"},
    ),
    "masked_select": (
        {"mask": "at::randn({6, 6, 6}) > 0.5"},
        {"mask": "at::rand({22, 22, 22}) > 0.5"},
    ),
    "orgqr": _QR_INPUT2,
    "ormqr": _QR_INPUT2,
    "quantile": _QUANTILE,
    "nanquantile": _QUANTILE,
    "multi_margin_loss": _CLASS_LOSS,
    "multilabel_margin_loss": (
        {
            "self": "at::rand({6, 6})",
            "target": "at::randint(6, {6, 6}, torch::kInt64)",
        },
        {
            "self": "at::rand({22, 22})",
            "target": "at::randint(22, {22, 22}, torch::kInt64)",
        },
    ),
    "nll_loss": _CLASS_LOSS,
    "nll_loss2d": (
        {
            "self": "at::rand({6, 6, 6, 6})",
            "target": "at::randint(6, {6, 6, 6}, torch::kInt64)",
            "weight": "at::rand({6})",
        },
        {
            "self": "at::rand({22, 22, 22, 22})",
            "target": "at::randint(22, {22, 22, 22}, torch::kInt64)",
            "weight": "at::rand({22})",
        },
    ),
    "fft_fft": _FFT_NORM,
    "fft_ifft": _FFT_NORM,
    "fft_rfft": _FFT_NORM,
    "fft_irfft": _FFT_NORM,
    "fft_hfft": _FFT_NORM,
    "fft_ihfft": _FFT_NORM,
    "linalg_tensorinv": (
        {"self": "at::rand({6, 6, 6, 6})", "ind": "2"},
        {"self": "at::rand({22, 22, 22, 22})", "ind": "2"},
    ),
    "addmv": (
        {
            "self": "at::rand({2})",
            "mat": "at::rand({2, 2})",
            "vec": "at::rand({2})",
        },
        {
            "self": "at::rand({35})",
            "mat": "at::rand({35, 35})",
            "vec": "at::rand({35})",
        },
    ),
    "acosh": (
        {"self": "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"},
        {"self": "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"},
    ),
    "adaptive_max_pool2d_backward": (
        {
            "grad_output": "at::rand({2, 2, 2}, at::kFloat)",
            "self": "at::rand({2, 2, 2}, at::kFloat)",
            "indices": "at::randint(0, 1, {2, 2, 2}, at::kLong)",
        },
        {
            "grad_output": "at::rand({3, 3, 3}, at::kFloat)",
            "self": "at::rand({3, 3, 3}, at::kFloat)",
            "indices": "at::randint(0, 1, {3, 3, 3}, at::kLong)",
        },
    ),
    "adaptive_max_pool3d_backward": (
        {
            "grad_output": "at::rand({2, 2, 2, 2}, at::kFloat)",
            "self": "at::rand({2, 2, 2, 2}, at::kFloat)",
            "indices": "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)",
        },
        {
            "grad_output": "at::rand({3, 3, 3, 3}, at::kFloat)",
            "self": "at::rand({3, 3, 3, 3}, at::kFloat)",
            "indices": "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)",
        },
    ),
    "bitwise_left_shift": (
        {
            "self": "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)",
            "other": "at::randint(1, 26, {6, 6, 6}, at::kInt)",
        },
        {
            "self": "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)",
            "other": "at::randint(1, 26, {22, 22, 22}, at::kInt)",
        },
    ),
    "bitwise_right_shift": (
        {
            "self": "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)",
            "other": "at::randint(1, 22, {6, 6, 6}, at::kInt)",
        },
        {
            "self": "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)",
            "other": "at::randint(1, 22, {22, 22, 22}, at::kInt)",
        },
    ),
    "gather": (
        {
            "self": "at::randint(1, 100, {2,2,2}, at::kInt)",
            "dim": "1",
            "index": "at::randint(0, 1, {2,2,2}, torch::kInt64)",
            "sparse_grad": "false",
        },
        {
            "self": "at::randint(1, 100, {5,5,5}, at::kInt)",
            "dim": "1",
            "index": "at::randint(0, 4, {5,5,5}, torch::kInt64)",
            "sparse_grad": "false",
        },
    ),
    "gelu": (
        {"self": "at::rand({6, 6, 6})", "approximate": '"tanh"'},
        {"self": "at::rand({22, 22, 22})", "approximate": '"tanh"'},
    ),
    "gelu_backward": (
        {
            "grad_output": "at::rand({6, 6, 6})",
            "self": "at::rand({6, 6, 6})",
            "approximate": '"tanh"',
        },
        {
            "grad_output": "at::rand({22, 22, 22})",
            "self": "at::rand({22, 22, 22})",
            "approximate": '"tanh"',
        },
    ),
    "index_add": (
        {
            "self": "at::rand({2})",
            "dim": "0",
            "index": "at::randint(0, 1, {2}, at::kInt)",
            "source": "at::rand({2})",
            "alpha": "2",
        },
        {
            "self": "at::rand({16})",
            "dim": "0",
            "index": "at::randint(0, 10, {16}, at::kInt)",
            "source": "at::rand({16})",
            "alpha": "2",
        },
    ),
    "index_copy": (
        {
            "self": "at::rand({2})",
            "dim": "0",
            "index": "at::randint(0, 1, {2}, at::kLong)",
            "source": "at::rand({2})",
        },
        {
            "self": "at::rand({32})",
            "dim": "0",
            "index": "at::randint(0, 10, {32}, at::kLong)",
            "source": "at::rand({32})",
        },
    ),
    "linalg_cross": (
        {
            "self": "at::rand({6, 3, 6})",
            "other": "at::rand({6, 3, 6})",
            "dim": "1",
        },
        {
            "self": "at::rand({22, 3, 22})",
            "other": "at::rand({22, 3, 22})",
            "dim": "1",
        },
    ),
    "nll_loss_backward": (
        {
            "grad_output": "at::rand({})",
            "self": "at::rand({6})",
            "target": "at::randint(0, 5, {6}, torch::kInt64)",
            "weight": "at::rand({6})",
            "reduction": "1",
            "ignore_index": "1",
            "total_weight": "at::rand({})",
        },
        {
            "grad_output": "at::rand({})",
            "self": "at::rand({36})",
            "target": "at::randint(0, 11, {36}, torch::kInt64)",
            "weight": "at::rand({36})",
            "reduction": "1",
            "ignore_index": "1",
            "total_weight": "at::rand({})",
        },
    ),
    "scatter": _SCATTER,
    "scatter_add": _SCATTER,
    "_scatter_reduce": _SCATTER,
    # "reduce" is listed first so it is written before "index".
    "scatter_reduce": (
        {
            "reduce": '"mean"',
            "index": "at::randint(6, {6, 6, 6}, torch::kInt64)",
        },
        {
            "reduce": '"mean"',
            "index": "at::randint(22, {22, 22, 22}, torch::kInt64)",
        },
    ),
    "special_zeta": (
        {
            "self": "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})",
            "other": "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})",
        },
        {
            "self": "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})",
            "other": "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})",
        },
    ),
    "_convert_indices_from_csr_to_coo": (
        {
            "crow_indices": "torch::tensor({1}, torch::kInt32)",
            "col_indices": "torch::tensor({0, 1, 0}, torch::kInt32)",
            "out_int32": "false",
        },
        {
            "crow_indices": "torch::tensor({0}, torch::kInt32)",
            "col_indices": (
                "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
            ),
            "out_int32": "false",
        },
    ),
    "_convert_indices_from_coo_to_csr": (
        {
            "self": "at::randint(0, 3, {2}, at::kInt)",
            "size": "10",
            "out_int32": "false",
        },
        {
            "self": "at::randint(0, 3, {12}, at::kInt)",
            "size": "24",
            "out_int32": "false",
        },
    ),
    "diagonal": _DIAGONAL,
    "linalg_diagonal": _DIAGONAL,
}


def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
    """Overwrite selected entries of ``arg_map`` in place for a known operator.

    Args:
        arg_map: Mapping from argument name to a value string (for example
            ``"at::rand({6, 6})"``); it is mutated, nothing is returned.
        op_name: Operator name used to look up the overrides. Names without
            an entry leave ``arg_map`` untouched.
        index: Must be 0 or 1; selects which of the two value sets to apply.

    Raises:
        AssertionError: If ``index`` is neither 0 nor 1 (checked before the
            operator lookup, so it fires for unknown operators too).
    """
    assert index in (0, 1)

    variants = _TEST_VALUE_OVERRIDES.get(op_name)
    if variants is None:
        return

    first, second = variants
    arg_map.update(first if index == 0 else second)

    # The scatter family only touches "reduce" when the caller's map already
    # carries that argument; the replacement value depends on the operator.
    if op_name in _SCATTER_OPS and "reduce" in arg_map:
        arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'