1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5import unittest 6 7import torch 8 9 10# Make the helper files in test/ importable 11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12sys.path.append(pytorch_test_dir) 13from torch.testing._internal.jit_utils import JitTestCase 14 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24def canonical(graph): 25 return torch._C._jit_pass_canonicalize(graph).str(False) 26 27 28class TestCustomOperators(JitTestCase): 29 def test_dynamic_op_registry(self): 30 from torch._ops import _OpNamespace 31 32 self.assertTrue(hasattr(torch, "ops")) 33 34 if "_test" in torch.ops.__dict__: 35 torch.ops.__dict__.pop("_test") 36 37 # Don't use `hasattr()` because it will call `__getattr__`. 38 self.assertNotIn("_test", torch.ops.__dict__) 39 torch.ops._test 40 self.assertIn("_test", torch.ops.__dict__) 41 self.assertEqual(type(torch.ops._test), _OpNamespace) 42 43 self.assertNotIn("leaky_relu", torch.ops._test.__dict__) 44 op = torch.ops._test.leaky_relu 45 self.assertTrue(callable(op)) 46 self.assertIn("leaky_relu", torch.ops._test.__dict__) 47 op2 = torch.ops._test.leaky_relu 48 self.assertEqual(op, op2) 49 50 def test_getting_invalid_attr(self): 51 for attr in ["__origin__", "__self__"]: 52 with self.assertRaisesRegexWithHighlight( 53 AttributeError, 54 f"Invalid attribute '{attr}' for '_OpNamespace' '_test'", 55 "", 56 ): 57 getattr(torch.ops._test, attr) 58 59 def test_simply_calling_an_operator(self): 60 input = torch.randn(100) 61 output = torch.ops.aten.relu(input) 62 self.assertEqual(output, input.relu()) 63 64 def test_default_arguments_are_used(self): 65 output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0])) 66 self.assertEqual(output, torch.tensor([-0.01, 1])) 67 68 def test_passing_too_many_args(self): 69 with self.assertRaisesRegexWithHighlight( 70 RuntimeError, 71 r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)", 72 "", 73 ): 74 torch.ops.aten.relu(1, 2) 75 76 def test_passing_too_few_args(self): 77 with self.assertRaisesRegexWithHighlight( 78 RuntimeError, r"aten::relu\(\) is missing value for argument 'self'.", "" 79 ): 80 torch.ops.aten.relu() 81 82 def test_passing_one_positional_but_not_the_second(self): 83 with self.assertRaisesRegexWithHighlight( 84 RuntimeError, 85 r"aten::type_as\(\) is missing value for argument 'other'.", 86 "", 87 ): 88 torch.ops.aten.type_as(torch.ones(5, 5)) 89 90 def test_passing_unknown_kwargs(self): 91 with self.assertRaisesRegexWithHighlight( 92 RuntimeError, 93 "Unknown keyword argument 'foo' for operator '_test::leaky_relu'", 94 "", 95 ): 96 torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5)) 97 98 def test_passing_and_returning_lists(self): 99 # Replace with actual test once we support lists. 100 a, b = torch.rand(5), torch.rand(5) 101 output = torch.ops._test.cat([a, b]) 102 output_ref = torch.cat([a, b]) 103 self.assertEqual(output, output_ref) 104 105 def test_calling_scripted_custom_op(self): 106 @torch.jit.script 107 def func(x): 108 return torch.ops.aten.relu(x) 109 110 input = torch.ones(5, 5) 111 self.assertEqual(func(input), input.relu()) 112 113 def test_calling_traced_custom_op(self): 114 input = torch.ones(5, 5) 115 func = torch.jit.trace(torch.ops.aten.relu, [input]) 116 self.assertEqual(func(input), input.relu()) 117 118 @unittest.skip( 119 "Need to figure out default dtype differences between fbcode and oss" 120 ) 121 def test_script_graph_for_custom_ops_matches_traced_graph(self): 122 input = torch.ones(5, 5) 123 trace = torch.jit.trace(torch.ops.aten.relu, [input]) 124 self.assertExpectedInline( 125 canonical(trace.graph), 126 """\ 127graph(%0 : Float(5, 5)): 128 %1 : Float(5, 5) = aten::relu(%0) 129 return (%1) 130""", 131 ) 132 133 def test_script_graph_contains_custom_op(self): 134 @torch.jit.script 135 def func(x): 136 return torch.ops.aten.relu(x) 137 138 self.assertExpectedInline( 139 canonical(func.graph), 140 """\ 141graph(%x.1 : Tensor): 142 %1 : Tensor = aten::relu(%x.1) 143 return (%1) 144""", 145 ) 146 147 def test_generic_list(self): 148 self.assertEqual(torch.ops._test.get_first([["hello"]]), "hello") 149 150 # https://github.com/pytorch/pytorch/issues/80508 151 def test_where_no_scalar(self): 152 x = torch.rand(1, 3, 224, 224) 153 torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise 154