"""Code generation helpers for Static Runtime operators and their tests.

This part of the module decides which native function groups can be
auto-generated, maps schema argument types to ``c10::IValue`` accessor calls,
and builds the C++ expressions used as inputs by the generated tests.
"""

from __future__ import annotations

import json
import logging
import math
from typing import Sequence

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config


logger: logging.Logger = logging.getLogger()


def has_alias(
    arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
) -> bool:
    """Return True if any argument's annotation declares a non-empty alias set.

    Arguments with no ``annotation`` attribute, or with a falsy annotation,
    are ignored.
    """
    # NOTE(review): a SelfArgument wraps its Argument in ``.argument``; if it has
    # no ``annotation`` attribute of its own, aliasing on ``self`` is not seen
    # here. Confirm that this is intended before relying on it.
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            continue
        if getattr(annotation, "alias_set", ()):
            return True
    return False


# Base op names for which is_supported() always returns False, grouped by reason.
BLOCKED_OPS = frozenset(
    (
        # non cpu ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these ones got added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_sparse_mask_projection",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
        "_nested_tensor_storage_offsets",
        "_nested_get_values",  # no CPU backend
        "_nested_get_values_copy",  # no CPU backend
        "_nested_view_from_jagged",  # testing needs to be patched
        "_nested_view_from_jagged_copy",  # testing needs to be patched
        "_nested_view_from_buffer",  # testing needs to be patched
        "_nested_view_from_buffer_copy",  # testing needs to be patched
        "_int_mm",  # testing needs to be patched
        "_to_sparse_csc",  # testing needs to be patched
        "_to_sparse_csr",  # testing needs to be patched
        "segment_reduce",  # testing needs to be patched
    )
)


def _schema_args_convertible(func: FunctionSchema) -> bool:
    """Return True if every argument of ``func`` has an IValue conversion.

    Logs ``func`` and returns False on the first argument without one.
    """
    if all(
        ivalue_type_conversion_method(arg.type)
        for arg in func.schema_order_arguments()
    ):
        return True
    # Type converting is unsupported yet.
    logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func)
    return False


def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
    """Return whether Static Runtime code can be auto-generated for ``g``.

    ``g`` is rejected (with an info log) when it is hand written, when its
    base op name is in ``BLOCKED_OPS``, or when an argument type has no
    IValue conversion. View groups must also return a plain ``at::Tensor``.
    Out-variant groups must satisfy these further conditions:

    * their functional variant's arguments are convertible;
    * when unstructured, the out schema ends in a single ``Tensor(a!) out``
      argument and its name ends in ``.out``;
    * they return ``at::Tensor &``;
    * no non-out argument carries an alias annotation.
    """
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func
    if config.is_hand_written(g):
        logger.info("HAND WRITTEN: %s", base_op_name)
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info("BLOCKED: %s", base_op_name)
        return False
    if not _schema_args_convertible(func):
        return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info("NON-TENSOR RET TYPE: %s", str(func))
            return False
        return True

    # For out variant ops, we need to check the arguments of its functional func.
    if not _schema_args_convertible(g.functional.func):
        return False

    if not g.structured:
        # In case of unstructured op, we check if it has out variant implementation.
        # The out variant implementation satisfies the minimum requirement that it
        # has the output tensor as the last parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            return False
    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info("NON_TENSOR RET TYPE: %s", func)
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info("INPUTS ALIAS: %s", base_op_name)
        return False
    return True


# For each supported base type: (conversion for the plain type, conversion for
# ``Optional[type]``). Each conversion is ``(bind_by_reference, accessor_call)``.
_IVALUE_TYPE_CONVERSION_METHODS: dict[
    BaseTy, tuple[tuple[bool, str], tuple[bool, str]]
] = {
    BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
    BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
    BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
    BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
    BaseTy.ScalarType: (
        (False, "toScalarType()"),
        (False, "toOptional<at::ScalarType>()"),
    ),
    BaseTy.str: (
        (False, "toStringView()"),
        (False, "toOptional<c10::string_view>()"),
    ),
}


def ivalue_type_conversion_method(
    arg_type: BaseType | OptionalType | Type,
) -> tuple[bool, str] | None:
    """Return how to unpack a ``c10::IValue`` holding a value of ``arg_type``.

    The result is ``(is_reference, method)``. ``method`` is the accessor call to
    append to the IValue expression (without a leading dot). ``is_reference``
    says whether the extracted value should be bound by reference. For example,
    ``BaseTy.Tensor`` yields ``(True, "toTensor()")``, and ``Optional[Tensor]``
    yields ``(False, "toOptional<at::Tensor>()")``.

    Returns None for unsupported types: lists, ``Optional`` of a non-base type,
    and base types missing from the conversion table.
    """
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
        is_optional = False
    elif isinstance(arg_type, OptionalType) and isinstance(arg_type.elem, BaseType):
        base_ty_object = arg_type.elem.name
        is_optional = True
    else:
        # ListType (bare or inside Optional) is currently unsupported.
        return None

    methods = _IVALUE_TYPE_CONVERSION_METHODS.get(base_ty_object)
    if methods is None:
        return None
    return methods[1] if is_optional else methods[0]


# Ops whose generated tests use integer (resp. complex) tensor inputs.
should_use_int_tensor_ops_ = frozenset(
    (
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    )
)
should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))


def should_use_int_tensor(op_name: str) -> bool:
    """Return True if tests for ``op_name`` should use ``at::kInt`` tensors."""
    return op_name in should_use_int_tensor_ops_


def should_use_complex_tensor(op_name: str) -> bool:
    """Return True if tests for ``op_name`` should use complex-float tensors."""
    return op_name in should_use_complex_tensor_ops_


# Ops whose generated test tensors have 1 (resp. 2) dimensions; all others use 3.
test_tensor_dim_ops_1_ = frozenset(
    (
        "addmv",
        "index_add",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
        "nll_loss_backward",
        "dot",
        "vdot",
        "outer",
        "ger",
    )
)
test_tensor_dim_ops_2_ = frozenset(
    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
)


def test_tensor_dim(op_name: str) -> int:
    """Return the number of dimensions of test tensors for ``op_name`` (1, 2 or 3)."""
    if op_name in test_tensor_dim_ops_1_:
        return 1
    if op_name in test_tensor_dim_ops_2_:
        return 2
    return 3


# Explicit C++ shape literals for ops that need a specific test-tensor shape.
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)


def test_tensor_shape(op_name: str) -> str:
    """Return the explicit C++ shape literal for ``op_name``, or "" if none is set."""
    return test_tensor_shape_json.get(op_name, "")


def test_value_expression(
    arg_type: BaseType | OptionalType | Type, index: int, op_name: str
) -> str:
    """Return a C++ expression that creates a test value of ``arg_type``.

    Tensor arguments get a random tensor whose shape comes from
    ``test_tensor_shape``. When no explicit shape is set, the tensor has
    ``test_tensor_dim(op_name)`` dimensions. Each dimension has size
    ``ceil(N / ndim)``, rounded up to an even number, where N is 16 for
    ``index == 0`` and 64 otherwise. The dtype is int or complex float for the
    ops listed above, and the ``at::rand`` default otherwise. Non-tensor
    arguments get fixed literals.

    Asserts that ``arg_type`` is a ``BaseType`` (or ``Optional`` of one) with a
    known test value.
    """
    tensor_size_ex = test_tensor_shape(op_name)
    if tensor_size_ex == "":
        num_tensors = 16 if index == 0 else 64
        num_dim = test_tensor_dim(op_name)
        size_per_dim = math.ceil(num_tensors / num_dim)
        # Round odd sizes up to the next even number.
        size_per_dim += size_per_dim % 2
        dims = ",".join([str(size_per_dim)] * num_dim)
        tensor_size_ex = f"{{{dims}}}"
    if should_use_int_tensor(op_name):
        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
    elif should_use_complex_tensor(op_name):
        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
    else:
        tensor_expression = f"at::rand({tensor_size_ex})"

    value_expressions = {
        BaseTy.Tensor: tensor_expression,
        BaseTy.int: "1",
        BaseTy.bool: "false",
        BaseTy.Scalar: "2",
        BaseTy.ScalarType: "at::ScalarType::Float",
        BaseTy.str: '"floor"',
    }

    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    else:
        assert isinstance(arg_type, OptionalType) and isinstance(
            arg_type.elem, BaseType
        )
        base_ty_object = arg_type.elem.name
    assert base_ty_object in value_expressions, "not expected type"
    return value_expressions[base_ty_object]


def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    """Return C++ ``auto <arg><index> = <value>;`` statements for ``schema``'s args.

    Values come from ``test_value_expression`` and may then be overridden by
    ``config.override_test_values``. ``schema`` must not be an out variant.
    """
    assert not schema.is_out_fn()
    schema_name = schema.name.name.base
    arg_map = {
        arg.name: test_value_expression(arg.type, index, schema_name)
        for arg in schema.schema_order_arguments()
    }
    config.override_test_values(arg_map, schema_name, index)
    arg_populations = [
        f"auto {arg_name}{index} = {arg_value}"
        for arg_name, arg_value in arg_map.items()
    ]
    # The separator's whitespace only affects the layout of the emitted C++;
    # it matches the two-space indent of the test-body template.
    return ";\n  ".join(arg_populations) + ";"


def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    """Return the comma-separated ``<arg><index>`` names defined for ``schema``."""
    assert not schema.is_out_fn()
    return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
# Maps a schema base type to the type name used in generated TorchScript IR test
# graphs. Scalar and ScalarType are declared as ``int`` in the IR.
generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}


def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> list[tuple[str, str | None]]:
    """Return ``("%<name>", ir_type_or_None)`` for each schema-order argument.

    Optional arguments get a ``?`` suffix on their IR type. The type is None
    when the base type has no IR mapping. Asserts that every argument is a
    ``BaseType`` or ``Optional`` of one.
    """

    def ir_argument(arg: Argument) -> tuple[str, str | None]:
        t = arg.type
        add_optional = False
        if isinstance(t, OptionalType):
            t = t.elem
            add_optional = True
        assert isinstance(t, BaseType)
        type_str = generate_test_ir_arguments_base_ty_to_type_str_.get(t.name)
        if type_str and add_optional:
            type_str = f"{type_str}?"
        return ("%" + arg.name, type_str)

    return [ir_argument(arg) for arg in schema.schema_order_arguments()]


def generate_arg_extraction(schema: FunctionSchema) -> str:
    """Return C++ statements that unpack each schema argument from ``p_node``.

    Emits ``const auto[&] <name> = p_node->Input(i).<accessor>;`` per argument.
    The ``&`` and the accessor come from ``ivalue_type_conversion_method``,
    which must support every argument type (asserted).
    """
    arg_populations = []
    for i, arg in enumerate(schema.schema_order_arguments()):
        maybe_method = ivalue_type_conversion_method(arg.type)
        assert maybe_method
        is_reference, type_conversion_method = maybe_method
        reference = "&" if is_reference else ""
        arg_populations.append(
            f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
        )
    # The separator's whitespace only affects the layout of the emitted C++; it
    # matches the indentation of ``{populated_argument}`` in the op templates.
    return ";\n          ".join(arg_populations) + ";"


def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the kernel name for ``g``'s functional variant.

    Uses ``cpp.name`` for structured groups or when the backend has no kernel.
    """
    kernel = backend_index.get_kernel(g.functional)
    if g.structured or kernel is None:
        return cpp.name(g.functional.func)
    return kernel.kernel


def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the kernel name for ``g``'s out variant.

    Uses ``cpp.name`` for structured groups or when the backend has no kernel.
    """
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel


def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call to ``g``'s functional kernel.

    The call lives in ``at::cpu`` for structured groups and ``at::native``
    otherwise.
    """
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "cpu" if g.structured else "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    """Return the ``at::native`` C++ call for a view op.

    Prefers the backend's kernel name and falls back to ``cpp.name``.
    """
    schema = g.view.func
    kernel_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    if kernel:
        kernel_name = kernel.kernel
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call to ``g``'s out-variant kernel.

    Structured kernels (``at::cpu``) take the out tensors first, then the
    remaining arguments. Unstructured kernels (``at::native``) take the
    remaining arguments, then their single out tensor.
    """
    schema = g.out.func
    assert schema.is_out_fn()
    kernel_name = get_out_kernel_name(g, backend_index)
    # structured op starts with the output tensor argument.
    arg_names = (
        [out_arg.name for out_arg in schema.arguments.out] if g.structured else []
    )
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    if not g.structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)
    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"


# Overload-qualified op names whose generated tests pass check_resize=false.
no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
        "prod",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
    """Return False if the schema text before '(' is in ``no_memory_resize_ops``."""
    schema_str = str(schema)
    type_variant_op_name = schema_str[: schema_str.find("(")]
    return type_variant_op_name not in no_memory_resize_ops


def op_name_from_group(g: NativeFunctionsGroup) -> str:
    """Return the base op name of ``g``'s functional variant."""
    return g.functional.func.name.name.base


class GenOpDispatcher:
    """Generates Static Runtime operator registrations (C++ source text)."""

    def out_variant(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        """Return a ``REGISTER_OPERATOR_FUNCTOR`` block covering all ``groups``.

        Each group contributes one schema-matching branch. The op name is taken
        from the first group. Returns "" when ``groups`` is empty.
        """
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.out_variant_op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = op_name_from_group(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def view(
        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
    ) -> str:
        """Return a ``REGISTER_NATIVE_OPERATOR_FUNCTOR`` block covering all ``groups``.

        The op name comes from ``config.func_name_base_str(groups[0])``.
        Returns "" when ``groups`` is empty.
        """
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = config.func_name_base_str(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def out_variant_op_generator(
        self, g: NativeFunctionsGroup, backend_index: BackendIndex
    ) -> str:
        """Return the schema-matching branch for one out-variant overload.

        The generated lambda stores the functional result when ``Output(0)`` is
        None. Otherwise it reuses the existing output tensor: it calls
        ``fastResizeToZero`` on it and then invokes the out variant.
        """
        schema = str(g.functional.func)
        populated_argument = generate_arg_extraction(g.functional.func)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        assert len(g.out.func.arguments.out) == 1
        out_variable_name = str(g.out.func.arguments.out[0].name)
        out_variant_call = generate_out_variant_call(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          if (p_node->Output(0).isNone()) {{
            p_node->Output(0) = {functional_variant_call};
            return;
          }}
          auto& {out_variable_name} = p_node->Output(0).toTensor();
          fastResizeToZero({out_variable_name});
          {out_variant_call};
        }};
      }}"""
        return generated

    def view_op_generator(
        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
    ) -> str:
        """Return the schema-matching branch for one view overload.

        The generated lambda assigns the view call's result to ``Output(0)``.
        """
        schema = str(g.view.func)
        populated_argument = generate_arg_extraction(g.view.func)
        functional_variant_call = generate_call_to_view_ops(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          p_node->Output(0) = {functional_variant_call};
        }};
      }}"""
        return generated


def _test_case_signature(schema: FunctionSchema, op_name: str) -> tuple[str, str, str]:
    """Return ``(type_variant_op_name, arg_declarations, arg_names)`` for a test.

    ``type_variant_op_name`` is the schema text before '(' with '.' replaced by
    '_'. ``arg_declarations`` and ``arg_names`` are the IR graph parameter list
    and the call argument list. Asserts that the name starts with ``op_name``
    and that ``schema`` returns exactly one Tensor.
    """
    schema_str = str(schema)
    assert schema_str.find("(") > 0
    type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
    assert type_variant_op_name.startswith(op_name)

    arg_types = generate_test_ir_arguments(schema)
    arg_declarations = ", ".join(
        arg_name if arg_type is None else f"{arg_name}: {arg_type}"
        for arg_name, arg_type in arg_types
    )
    arg_names = ", ".join(arg_name for arg_name, _ in arg_types)
    assert (
        len(schema.returns) == 1
        and isinstance(schema.returns[0].type, BaseType)
        and schema.returns[0].type.name is BaseTy.Tensor
    )
    return type_variant_op_name, arg_declarations, arg_names


class GenOpTestCase:
    """Generates C++ ``TEST(StaticRuntime, autogen_...)`` cases for generated ops."""

    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
        """Return newline-joined test cases for all ``groups`` ("" if empty)."""
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.out_variant_op_test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
        """Return newline-joined test cases for all view ``groups`` ("" if empty)."""
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
        """Return a test that runs ``g``'s functional op with two input sets.

        The first ``testStaticRuntime`` call uses input set 0. The second uses
        sets 0 and 1, with ``check_resize`` taken from ``should_check_resize``.
        """
        schema = g.functional.func
        op_name = op_name_from_group(g)
        type_variant_op_name, arg_declarations, arg_names = _test_case_signature(
            schema, op_name
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        test_value_definitions2 = generate_test_value_definitions(schema, 1)
        test_value_names2 = generate_test_value_names(schema, 1)
        check_resize = "true" if should_check_resize(schema) else "false"
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

  {test_value_definitions2}
  std::vector<IValue> args2{{{test_value_names2}}};
  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

}}
"""
        return generated

    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
        """Return a test that runs the view op ``g`` once with input set 0."""
        schema = g.view.func
        op_name = g.view.root_name
        type_variant_op_name, arg_declarations, arg_names = _test_case_signature(
            schema, op_name
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args);
}}
"""

        return generated