# Owner(s): ["module: codegen"]

import textwrap
import unittest
from typing import cast

import expecttest
import yaml

import torchgen.dest as dest
import torchgen.gen as gen
from torchgen.gen import LineLoader, parse_native_yaml_struct
from torchgen.model import (
    Annotation,
    CustomClassType,
    DispatchKey,
    NativeFunctionsGroup,
    Type,
)


class TestCodegenModel(expecttest.TestCase):
    """Error-path tests for parsing native-function YAML and ufunc codegen."""

    @staticmethod
    def _strip_context(e: AssertionError) -> str:
        """Return the assertion message without trailing context, re-wrapped.

        Everything from the first ``" in "`` onward is treated as context and
        dropped.  ``str.partition`` is used instead of unpacking the result of
        ``split``, so a message containing the separator zero times or more
        than once no longer raises ``ValueError`` and hides the real message.
        """
        msg, _, _ = str(e).partition(" in ")
        return "\n".join(textwrap.wrap(msg))

    def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that parsing ``yaml_str`` raises ``AssertionError`` matching ``expect``.

        Args:
            yaml_str: YAML text describing native function entries.
            expect: Expected error message, compared after the context is
                stripped and the message is wrapped with ``textwrap.wrap``.
        """
        # The YAML here is always a literal defined in this file.
        es = yaml.load(yaml_str, Loader=LineLoader)
        try:
            parse_native_yaml_struct(es, set())
        except AssertionError as e:
            # NOTE: keep this call directly in this method; skip=1 is relative
            # to this frame (presumably so the inline expectation is attributed
            # to the calling test -- confirm against expecttest).
            self.assertExpectedInline(self._strip_context(e), expect, skip=1)
            return
        self.fail(msg="Did not raise when expected to")

    def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that ufunc codegen for ``yaml_str`` raises ``AssertionError``.

        ``yaml_str`` must parse successfully and group into exactly one
        ``NativeFunctionsGroup`` whose ``out`` variant has ``ufunc_inner_loop``.

        Args:
            yaml_str: YAML text describing a single structured group.
            expect: Expected error message, compared after the context is
                stripped and the message is wrapped with ``textwrap.wrap``.
        """
        # parse a single structured group out of the yaml to g
        es = yaml.load(yaml_str, Loader=LineLoader)
        parsed_yaml = parse_native_yaml_struct(es, set())
        native_functions, backend_indices = (
            parsed_yaml.native_functions,
            parsed_yaml.backend_indices,
        )
        grouped_native_functions = gen.get_grouped_native_functions(native_functions)
        assert len(grouped_native_functions) == 1
        g = grouped_native_functions[0]
        assert isinstance(g, NativeFunctionsGroup)
        assert g.out.ufunc_inner_loop
        # this is not ufunc codegen per se, but it does some basic sanity tests for
        # ufunc generation
        gen.compute_meta_function_declaration(g)
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CPU])
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CUDA])
        try:
            # the real kahuna
            dest.compute_ufunc_cpu(g)
            dest.compute_ufunc_cpu_kernel(g)
            dest.compute_ufunc_cuda(g)
        except AssertionError as e:
            # NOTE: keep this call directly in this method; skip=1 is relative
            # to this frame (see assertParseErrorInline).
            self.assertExpectedInline(self._strip_context(e), expect, skip=1)
            return
        self.fail(msg="Did not raise when expected to")

    # NB: indent is hardcoded to be two here, so format your yaml accordingly
    binop_out = (
        "func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
    )
    ti_binop_out = f"""{binop_out}
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_binop = """func: binop(Tensor self, Tensor other) -> Tensor
  structured_delegate: binop.out
"""

    ti_unop_out = """func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_unop = """func: unop(Tensor self) -> Tensor
  structured_delegate: unop.out
"""

    def test_nonstructured_ufunc(self) -> None:
        """``ufunc_inner_loop`` on an entry without ``structured`` is a parse error."""
        yaml_str = f"""\
- {self.binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc must be structured""",
        )

    def test_overlapping_ufunc_and_dispatch(self) -> None:
        """A ufunc entry with an explicit CPU dispatch entry is a parse error."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
  dispatch:
    CPU: binop_cpu
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc should not have explicit dispatch entry for CPU""",
        )

    # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456
    @unittest.expectedFailure
    def test_scalaronly_shadowed(self) -> None:
        """Generic and ScalarOnly over the same dtype; marked as an expected failure."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
""",
        )

    def test_conflicting_ufunc(self) -> None:
        """ScalarOnly and Generic with different ufunc names fail in codegen."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop_scalar (Bool)
- {self.ti_binop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
ScalarOnly and Generic must have same ufunc name""",
        )

    def test_invalid_cudafunctoronself_for_binary_op(self) -> None:
        """``CUDAFunctorOnSelf`` on a unary (non-binary) function fails in codegen."""
        yaml_str = f"""\
- {self.ti_unop_out}
  ufunc_inner_loop:
    Generic: unop (All)
    CUDAFunctorOnSelf: unop_self_cuda (All)
- {self.ti_unop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
cannot use CUDAFunctorOnSelf on non-binary function""",
        )

    def test_parse_custom_class_type(self) -> None:
        """``Type.parse`` on a ``__torch__.torch.classes.``-prefixed name round-trips."""
        custom_class_name = "namespace_foo.class_bar"
        custom_class_name_with_prefix = f"__torch__.torch.classes.{custom_class_name}"
        custom_class_type = cast(
            CustomClassType, Type.parse(custom_class_name_with_prefix)
        )
        self.assertIsInstance(custom_class_type, CustomClassType)
        self.assertEqual(custom_class_name, custom_class_type.class_name)
        self.assertEqual(custom_class_name_with_prefix, str(custom_class_type))


class TestAnnotation(expecttest.TestCase):
    """Tests for ``Annotation.parse`` on alias annotations."""

    def test_single_alias_no_write(self) -> None:
        """``"a"`` parses to a single, non-written alias with no after-set."""
        a = Annotation.parse("a")
        self.assertEqual(a.alias_set, ("a",))
        self.assertFalse(a.is_write)
        self.assertEqual(a.alias_set_after, ())

    def test_single_alias_is_write(self) -> None:
        """``"a!"`` parses to a single written alias with no after-set."""
        a = Annotation.parse("a!")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ())

    def test_single_alias_is_write_to_wildcard(self) -> None:
        """``"a! -> *"`` parses to a written alias whose after-set is the wildcard."""
        a = Annotation.parse("a! -> *")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ("*",))

    def test_alias_set(self) -> None:
        """``"a|b"`` parses to a two-element alias set."""
        a = Annotation.parse("a|b")
        self.assertEqual(a.alias_set, ("a", "b"))

    def test_alias_set_is_write_raises_exception(self) -> None:
        """A written alias set with more than one member is rejected."""
        with self.assertRaisesRegex(
            AssertionError, r"alias set larger than 1 is not mutable"
        ):
            Annotation.parse("a|b!")

    def test_single_alias_is_write_to_alias_set(self) -> None:
        """``"a! -> a|b"`` parses to a written alias with a two-element after-set."""
        a = Annotation.parse("a! -> a|b")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ("a", "b"))

    def test_before_and_after_alias_set_larger_than_1_raises_exception(self) -> None:
        """Before- and after-sets cannot both have more than one member."""
        with self.assertRaisesRegex(
            AssertionError,
            r"before alias set and after alias set cannot be larger than 1 at the same time",
        ):
            Annotation.parse("a|b -> c|d")


if __name__ == "__main__":
    unittest.main()