# xref: /aosp_15_r20/external/pytorch/torchgen/static_runtime/generator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from __future__ import annotations

import json
import logging
import math
from typing import Sequence

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config


logger: logging.Logger = logging.getLogger()


def has_alias(
    arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
) -> bool:
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            continue
        alias_set = getattr(annotation, "alias_set", ())
        if alias_set:
            return True
    return False


BLOCKED_OPS = frozenset(
    (
        # non-CPU ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training-related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these were added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_sparse_mask_projection",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
        "_nested_tensor_storage_offsets",
        "_nested_get_values",  # no CPU backend
        "_nested_get_values_copy",  # no CPU backend
        "_nested_view_from_jagged",  # testing needs to be patched
        "_nested_view_from_jagged_copy",  # testing needs to be patched
        "_nested_view_from_buffer",  # testing needs to be patched
        "_nested_view_from_buffer_copy",  # testing needs to be patched
        "_int_mm",  # testing needs to be patched
        "_to_sparse_csc",  # testing needs to be patched
        "_to_sparse_csr",  # testing needs to be patched
        "segment_reduce",  # testing needs to be patched
    )
)


def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
    base_op_name = ""
    func = None
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func
    if config.is_hand_written(g):
        logger.info("HAND WRITTEN: %s", base_op_name)
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info("BLOCKED: %s", base_op_name)
        return False
    for arg in func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # This type conversion is not supported yet.
            logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func)
            return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string; test the type directly instead.
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info("NON-TENSOR RET TYPE: %s", str(func))
            return False
        return True

    # For out-variant ops, we also need to check the arguments of the functional variant.
    for arg in g.functional.func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # This type conversion is not supported yet.
            logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func)
            return False

    if not g.structured:
        # For an unstructured op, check whether it has an out-variant implementation.
        # At a minimum, the out variant must take the output tensor as its last
        # parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            return False
    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info("NON-TENSOR RET TYPE: %s", func)
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of its inputs.
        logger.info("INPUTS ALIAS: %s", base_op_name)
        return False
    return True
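# A quick sketch of what the unstructured-op check in is_supported() accepts: for a
# hypothetical op "foo", the out-variant schema must end with the canonical out suffix,
# e.g.
#   foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
# Schemas whose return type is not a Tensor reference, or whose non-out inputs carry
# alias annotations, are rejected by the checks above.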


def ivalue_type_conversion_method(
    arg_type: BaseType | OptionalType | Type,
) -> tuple[bool, str] | None:
    """
    Return the `c10::IValue` method-call expression that converts the contained value
    to the type expected for `arg_type`, together with a flag indicating whether the
    result should be bound by const reference. For example, for
    `arg_type == BaseType(BaseTy.Tensor)` this function returns `(True, "toTensor()")`;
    the method string can be appended to the IValue's variable name to get a value of
    the expected type. Returns None for unsupported types.
    """
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

    if base_ty_object not in type_conversion_methods:
        return None
    methods = type_conversion_methods[base_ty_object]
    if isinstance(arg_type, BaseType):
        return methods[0]
    return methods[1]
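# Minimal usage sketch (the constructor calls mirror torchgen.model's dataclasses):
#   ivalue_type_conversion_method(BaseType(BaseTy.Tensor))
#       -> (True, "toTensor()")                 # bound as `const auto&`
#   ivalue_type_conversion_method(OptionalType(BaseType(BaseTy.Tensor)))
#       -> (False, "toOptional<at::Tensor>()")  # bound by value
# Any list or otherwise unsupported type returns None.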


should_use_int_tensor_ops_ = frozenset(
    (
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    )
)
should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))


def should_use_int_tensor(op_name: str) -> bool:
    return op_name in should_use_int_tensor_ops_


def should_use_complex_tensor(op_name: str) -> bool:
    return op_name in should_use_complex_tensor_ops_


test_tensor_dim_ops_1_ = frozenset(
    (
        "addmv",
        "index_add",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
        "nll_loss_backward",
        "dot",
        "vdot",
        "outer",
        "ger",
    )
)
test_tensor_dim_ops_2_ = frozenset(
    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
)


def test_tensor_dim(op_name: str) -> int:
    if op_name in test_tensor_dim_ops_1_:
        return 1
    if op_name in test_tensor_dim_ops_2_:
        return 2
    return 3


test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)


def test_tensor_shape(op_name: str) -> str:
    if op_name in test_tensor_shape_json:
        return test_tensor_shape_json[op_name]
    else:
        return ""


def test_value_expression(
    arg_type: BaseType | OptionalType | Type, index: int, op_name: str
) -> str:
    tensor_size_ex = test_tensor_shape(op_name)
    if tensor_size_ex == "":
        num_tensors = 16 if index == 0 else 64
        num_dim = test_tensor_dim(op_name)
        size_per_dim = math.ceil(num_tensors / float(num_dim))
        size_per_dim += size_per_dim % 2  # round odd sizes up to the next even number
        tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim))
    if should_use_int_tensor(op_name):
        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
    elif should_use_complex_tensor(op_name):
        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
    else:
        tensor_expression = f"at::rand({tensor_size_ex})"

    value_expressions = {
        BaseTy.Tensor: tensor_expression,
        BaseTy.int: "1",
        BaseTy.bool: "false",
        BaseTy.Scalar: "2",
        BaseTy.ScalarType: "at::ScalarType::Float",
        BaseTy.str: '"floor"',
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    else:
        assert isinstance(arg_type, OptionalType) and isinstance(
            arg_type.elem, BaseType
        )
        base_ty_object = arg_type.elem.name
    assert base_ty_object in value_expressions, "unexpected type"
    value_expression = value_expressions[base_ty_object]
    return value_expression
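# Worked example of the default sizing above (a sketch; no shape or config override assumed):
# for a 2-dim op such as "addmm" with index 0, num_tensors = 16 and num_dim = 2, so
# size_per_dim = ceil(16 / 2) = 8 (already even) and tensor_size_ex == "{8,8}", giving a
# tensor expression of "at::rand({8,8})". With index 1, num_tensors = 64 yields "{32,32}".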


def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    assert not schema.is_out_fn()
    schema_name = schema.name.name.base
    arg_map = {}
    for arg in schema.schema_order_arguments():
        test_value_exp = test_value_expression(arg.type, index, schema_name)
        arg_map[arg.name] = test_value_exp
    config.override_test_values(arg_map, schema_name, index)
    arg_populations = []
    for arg_name, arg_value in arg_map.items():
        arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
    return ";\n    ".join(arg_populations) + ";"
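# Output sketch for a hypothetical schema "foo(Tensor self, Tensor other, *, Scalar alpha=1)"
# with index 0 and no config override (shapes follow test_value_expression's defaults):
#   "auto self0 = at::rand({6,6,6});\n    auto other0 = at::rand({6,6,6});\n    auto alpha0 = 2;"
# The embedded "\n    " keeps each definition aligned inside the generated TEST body.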


def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    assert not schema.is_out_fn()
    return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())


generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}


def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> list[tuple[str, str | None]]:
    def ir_argument(arg: Argument) -> tuple[str, str | None]:
        t = arg.type
        add_optional = False
        if isinstance(t, OptionalType):
            t = t.elem
            add_optional = True
        assert isinstance(t, BaseType)
        type_str = None
        if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
            type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
        if type_str and add_optional:
            type_str = f"{type_str}?"
        return ("%" + arg.name, type_str)

    return [ir_argument(arg) for arg in schema.schema_order_arguments()]
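# Sketch of the IR-argument mapping for the same hypothetical
# "foo(Tensor self, Tensor other, *, Scalar alpha=1)" schema:
#   [("%self", "Tensor"), ("%other", "Tensor"), ("%alpha", "int")]
# Optional arguments get a trailing "?" (e.g. "Tensor?"); base types without an entry in
# the mapping above come back as None, so the graph declaration omits their type.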


def generate_arg_extraction(schema: FunctionSchema) -> str:
    arg_populations = []
    for i, arg in enumerate(schema.schema_order_arguments()):
        maybe_method = ivalue_type_conversion_method(arg.type)
        assert maybe_method
        is_reference, type_conversion_method = maybe_method
        reference = "&" if is_reference else ""
        arg_populations.append(
            f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
        )
    return ";\n    ".join(arg_populations) + ";"
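# Extraction sketch for the hypothetical "foo(Tensor self, Tensor other, *, Scalar alpha=1)"
# (the ";\n    " join indents continuation lines inside the generated lambda):
#   const auto& self = p_node->Input(0).toTensor();
#   const auto& other = p_node->Input(1).toTensor();
#   const auto alpha = p_node->Input(2).toScalar();
# Tensor arguments are bound by const reference; everything else is taken by value.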


def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.functional)
    if g.structured or kernel is None:
        return cpp.name(g.functional.func)
    return kernel.kernel


def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel
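# Resolution sketch for the two helpers above: structured groups (and groups with no
# backend kernel entry) fall back to the plain C++ name, e.g. "foo" / "foo_out" for a
# hypothetical "foo" group; an unstructured group with an explicit CPU kernel entry in
# native_functions.yaml uses that kernel's name instead.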


def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "cpu" if g.structured else "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
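# Call sketch: for a structured group whose kernel name resolves to "foo", the generated
# expression is roughly
#   at::cpu::foo(self,other,alpha)
# while unstructured groups dispatch into the at::native:: namespace instead.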


def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    schema = g.view.func
    kernel_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    if kernel:
        kernel_name = kernel.kernel
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    schema = g.out.func
    assert schema.is_out_fn()
    arg_names = []
    kernel_name = get_out_kernel_name(g, backend_index)
    if g.structured:
        # Structured ops take the output tensor(s) as the leading argument(s).
        arg_names = [out_arg.name for out_arg in schema.arguments.out]
    else:
        arg_names = []
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    if not g.structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)
    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
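# Call sketch: for a structured group whose out kernel resolves to "foo_out", the
# generated expression places the out tensor first, roughly
#   at::cpu::foo_out(out,self,other,alpha)
# For unstructured groups the single out argument goes last and the call targets
# at::native:: instead.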


no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
        "prod",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
    schema_str = str(schema)
    type_variant_op_name = schema_str[: schema_str.find("(")]
    return type_variant_op_name not in no_memory_resize_ops
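# For example, a schema printed as "dot(Tensor self, Tensor tensor) -> Tensor" yields the
# name "dot", which is listed above, so the generated test disables the resize check;
# any op whose dotted name is not in no_memory_resize_ops keeps check_resize enabled.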


def op_name_from_group(g: NativeFunctionsGroup) -> str:
    return g.functional.func.name.name.base


class GenOpDispatcher:
    def out_variant(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.out_variant_op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = op_name_from_group(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def view(
        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
    ) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = config.func_name_base_str(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def out_variant_op_generator(
        self, g: NativeFunctionsGroup, backend_index: BackendIndex
    ) -> str:
        functional = g.functional
        schema = str(functional.func)
        populated_argument = generate_arg_extraction(g.functional.func)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        assert len(g.out.func.arguments.out) == 1
        out_variable_name = str(g.out.func.arguments.out[0].name)
        out_variant_call = generate_out_variant_call(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          if (p_node->Output(0).isNone()) {{
            p_node->Output(0) = {functional_variant_call};
            return;
          }}
          auto& {out_variable_name} = p_node->Output(0).toTensor();
          fastResizeToZero({out_variable_name});
          {out_variant_call};
        }};
      }}"""
        return generated
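    # Note on the generated lambda above: on the first run the output IValue is still
    # None, so the functional variant allocates the result; on later runs the existing
    # out tensor is resized to zero and reused via the out-variant kernel, which lets
    # Static Runtime avoid reallocating outputs across iterations.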

    def view_op_generator(
        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
    ) -> str:
        schema = str(g.view.func)
        populated_argument = generate_arg_extraction(g.view.func)
        functional_variant_call = generate_call_to_view_ops(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
            p_node->Output(0) = {functional_variant_call};
        }};
      }}"""
        return generated


class GenOpTestCase:
    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.out_variant_op_test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
        schema = g.functional.func
        schema_str = str(schema)
        assert schema_str.find("(") > 0
        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
        op_name = op_name_from_group(g)
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        arg_declarations = ", ".join(
            (
                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
                for arg_name, arg_type in arg_types
            )
        )
        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
        assert (
            len(schema.returns) == 1
            and isinstance(schema.returns[0].type, BaseType)
            and schema.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        test_value_definitions2 = generate_test_value_definitions(schema, 1)
        test_value_names2 = generate_test_value_names(schema, 1)
        check_resize = "true" if should_check_resize(schema) else "false"
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

  {test_value_definitions2}
  std::vector<IValue> args2{{{test_value_names2}}};
  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

}}
"""
        return generated

    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
        schema = g.view.func
        schema_str = str(schema)
        assert schema_str.find("(") > 0
        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
        op_name = g.view.root_name
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        arg_declarations = ", ".join(
            (
                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
                for arg_name, arg_type in arg_types
            )
        )
        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
        assert (
            len(schema.returns) == 1
            and isinstance(schema.returns[0].type, BaseType)
            and schema.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args);
}}
"""

        return generated