xref: /aosp_15_r20/external/pytorch/tools/test/test_codegen_model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: codegen"]
2
3import textwrap
4import unittest
5from typing import cast
6
7import expecttest
8import yaml
9
10import torchgen.dest as dest
11import torchgen.gen as gen
12from torchgen.gen import LineLoader, parse_native_yaml_struct
13from torchgen.model import (
14    Annotation,
15    CustomClassType,
16    DispatchKey,
17    NativeFunctionsGroup,
18    Type,
19)
20
21
class TestCodegenModel(expecttest.TestCase):
    """Error-path tests for native-function YAML parsing and ufunc codegen.

    Each test builds a small YAML snippet, runs it through
    ``parse_native_yaml_struct`` (and, for the ufunc tests, through the
    ``dest.compute_ufunc_*`` generators) and compares the headline of the
    resulting ``AssertionError`` against an inline expected string.
    """

    @staticmethod
    def _error_headline(error: AssertionError) -> str:
        """Return the leading part of ``error``'s message, re-wrapped.

        Everything from the first ``"  in "`` marker onwards (the context
        attached to the message) is dropped, and the remainder is re-flowed
        with ``textwrap.wrap`` and joined with newlines.

        ``str.partition`` is used instead of ``split(...)`` plus tuple
        unpacking: a bounded split may yield one or three pieces when the
        marker is absent or repeated, and the unpacking would then raise
        ``ValueError`` and hide the message under test.
        """
        headline, _, _ = str(error).partition("  in ")
        return "\n".join(textwrap.wrap(headline))

    def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that parsing ``yaml_str`` raises the expected error.

        Args:
            yaml_str: YAML text describing native function entries.
            expect: Expected error headline (see ``_error_headline``).

        Fails the test if ``parse_native_yaml_struct`` does not raise an
        ``AssertionError``.
        """
        # NOTE: yaml.load is only fed literals authored by the tests below.
        es = yaml.load(yaml_str, Loader=LineLoader)
        try:
            parse_native_yaml_struct(es, set())
        except AssertionError as e:
            # assertExpectedInline must be called from this frame so that
            # skip=1 keeps referring to the test method that called us
            # (presumably how expecttest locates the inline literal).
            self.assertExpectedInline(self._error_headline(e), expect, skip=1)
            return
        # Deliberately outside the try: self.fail raises an AssertionError
        # itself and must not be swallowed by the handler above.
        self.fail(msg="Did not raise when expected to")

    def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that ufunc codegen for ``yaml_str`` raises the expected error.

        ``yaml_str`` must parse successfully and describe exactly one
        structured group whose ``out`` variant has a ``ufunc_inner_loop``.

        Args:
            yaml_str: YAML text describing native function entries.
            expect: Expected error headline (see ``_error_headline``).

        Fails the test if none of the ``dest.compute_ufunc_*`` calls raises
        an ``AssertionError``.
        """
        # parse a single structured group out of the yaml to g
        es = yaml.load(yaml_str, Loader=LineLoader)
        parsed_yaml = parse_native_yaml_struct(es, set())
        native_functions, backend_indices = (
            parsed_yaml.native_functions,
            parsed_yaml.backend_indices,
        )
        grouped_native_functions = gen.get_grouped_native_functions(native_functions)
        assert len(grouped_native_functions) == 1
        g = grouped_native_functions[0]
        assert isinstance(g, NativeFunctionsGroup)
        assert g.out.ufunc_inner_loop
        # this is not ufunc codegen per se, but it does some basic sanity tests for
        # ufunc generation
        gen.compute_meta_function_declaration(g)
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CPU])
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CUDA])
        try:
            # the real kahuna
            dest.compute_ufunc_cpu(g)
            dest.compute_ufunc_cpu_kernel(g)
            dest.compute_ufunc_cuda(g)
        except AssertionError as e:
            # Same frame-count constraint as in assertParseErrorInline.
            self.assertExpectedInline(self._error_headline(e), expect, skip=1)
            return
        self.fail(msg="Did not raise when expected to")

    # NB: indent is hardcoded to be two here, so format your yaml accordingly
    binop_out = (
        "func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
    )
    ti_binop_out = f"""{binop_out}
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_binop = """func: binop(Tensor self, Tensor other) -> Tensor
  structured_delegate: binop.out
"""

    ti_unop_out = """func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_unop = """func: unop(Tensor self) -> Tensor
  structured_delegate: unop.out
"""

    def test_nonstructured_ufunc(self) -> None:
        """A ``ufunc_inner_loop`` on a non-structured entry is a parse error."""
        yaml_str = f"""\
- {self.binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc must be structured""",
        )

    def test_overlapping_ufunc_and_dispatch(self) -> None:
        """A ufunc entry with an explicit CPU dispatch is a parse error."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
  dispatch:
    CPU: binop_cpu
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc should not have explicit dispatch entry for CPU""",
        )

    # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456
    @unittest.expectedFailure
    def test_scalaronly_shadowed(self) -> None:
        """Generic and ScalarOnly with the same name; marked expected-failure."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
""",
        )

    def test_conflicting_ufunc(self) -> None:
        """Generic and ScalarOnly with different ufunc names fail codegen."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop_scalar (Bool)
- {self.ti_binop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
ScalarOnly and Generic must have same ufunc name""",
        )

    def test_invalid_cudafunctoronself_for_binary_op(self) -> None:
        """CUDAFunctorOnSelf on the unary ``unop`` entry fails codegen.

        Despite the test name, the YAML below describes a unary op; the
        expected message is about a *non-binary* function.
        """
        yaml_str = f"""\
- {self.ti_unop_out}
  ufunc_inner_loop:
    Generic: unop (All)
    CUDAFunctorOnSelf: unop_self_cuda (All)
- {self.ti_unop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
cannot use CUDAFunctorOnSelf on non-binary function""",
        )

    def test_parse_custom_class_type(self) -> None:
        """``Type.parse`` maps a ``__torch__.torch.classes.`` name to CustomClassType."""
        custom_class_name = "namespace_foo.class_bar"
        custom_class_name_with_prefix = f"__torch__.torch.classes.{custom_class_name}"
        # cast only informs the type checker; the runtime check follows.
        custom_class_type = cast(
            CustomClassType, Type.parse(custom_class_name_with_prefix)
        )
        self.assertIsInstance(custom_class_type, CustomClassType)
        self.assertEqual(custom_class_name, custom_class_type.class_name)
        self.assertEqual(custom_class_name_with_prefix, str(custom_class_type))
160
161
class TestAnnotation(expecttest.TestCase):
    """Tests for ``Annotation.parse`` on alias-annotation strings.

    The cases check the parsed ``alias_set``, ``is_write`` and
    ``alias_set_after`` fields, plus the two inputs that are expected to be
    rejected with an ``AssertionError``.
    """

    def test_single_alias_no_write(self) -> None:
        """``"a"``: one alias, not a write, nothing after."""
        parsed = Annotation.parse("a")
        self.assertEqual(parsed.alias_set, ("a",))
        self.assertFalse(parsed.is_write)
        self.assertEqual(parsed.alias_set_after, ())

    def test_single_alias_is_write(self) -> None:
        """``"a!"``: one alias, marked as a write, nothing after."""
        parsed = Annotation.parse("a!")
        self.assertEqual(parsed.alias_set, ("a",))
        self.assertTrue(parsed.is_write)
        self.assertEqual(parsed.alias_set_after, ())

    def test_single_alias_is_write_to_wildcard(self) -> None:
        """``"a! -> *"``: written alias whose after-set is the wildcard."""
        parsed = Annotation.parse("a! -> *")
        self.assertEqual(parsed.alias_set, ("a",))
        self.assertTrue(parsed.is_write)
        self.assertEqual(parsed.alias_set_after, ("*",))

    def test_alias_set(self) -> None:
        """``"a|b"``: both names end up in the alias set, in order."""
        parsed = Annotation.parse("a|b")
        self.assertEqual(parsed.alias_set, ("a", "b"))

    def test_alias_set_is_write_raises_exception(self) -> None:
        """``"a|b!"``: a multi-element alias set cannot be written to."""
        self.assertRaisesRegex(
            AssertionError,
            r"alias set larger than 1 is not mutable",
            Annotation.parse,
            "a|b!",
        )

    def test_single_alias_is_write_to_alias_set(self) -> None:
        """``"a! -> a|b"``: written alias with a two-element after-set."""
        parsed = Annotation.parse("a! -> a|b")
        self.assertEqual(parsed.alias_set, ("a",))
        self.assertTrue(parsed.is_write)
        self.assertEqual(parsed.alias_set_after, ("a", "b"))

    def test_before_and_after_alias_set_larger_than_1_raises_exception(self) -> None:
        """``"a|b -> c|d"``: before- and after-sets cannot both exceed one."""
        self.assertRaisesRegex(
            AssertionError,
            r"before alias set and after alias set cannot be larger than 1 at the same time",
            Annotation.parse,
            "a|b -> c|d",
        )
203
204
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
207