# Owner(s): ["oncall: jit"]
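# Tests for the `torch.ops` dynamic operator registry: lazy namespace
# creation, schema/argument checking, and calling registered operators
# from eager, scripted, and traced code.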

import os
import sys
import unittest

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


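# Canonicalize a graph and render it as a string (`str(False)` omits source
# ranges) so the expected-output comparisons below stay stable.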
def canonical(graph):
    return torch._C._jit_pass_canonicalize(graph).str(False)


class TestCustomOperators(JitTestCase):
    def test_dynamic_op_registry(self):
        from torch._ops import _OpNamespace

        self.assertTrue(hasattr(torch, "ops"))

        if "_test" in torch.ops.__dict__:
            torch.ops.__dict__.pop("_test")

        # Don't use `hasattr()` because it will call `__getattr__`.
        self.assertNotIn("_test", torch.ops.__dict__)
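        # The first attribute access creates the namespace on the fly.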
        torch.ops._test
        self.assertIn("_test", torch.ops.__dict__)
        self.assertEqual(type(torch.ops._test), _OpNamespace)

        self.assertNotIn("leaky_relu", torch.ops._test.__dict__)
        op = torch.ops._test.leaky_relu
        self.assertTrue(callable(op))
        self.assertIn("leaky_relu", torch.ops._test.__dict__)
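        # A repeated lookup returns the op now cached on the namespace.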
        op2 = torch.ops._test.leaky_relu
        self.assertEqual(op, op2)

    def test_getting_invalid_attr(self):
        for attr in ["__origin__", "__self__"]:
            with self.assertRaisesRegexWithHighlight(
                AttributeError,
                f"Invalid attribute '{attr}' for '_OpNamespace' '_test'",
                "",
            ):
                getattr(torch.ops._test, attr)

    def test_simply_calling_an_operator(self):
        input = torch.randn(100)
        output = torch.ops.aten.relu(input)
        self.assertEqual(output, input.relu())

    def test_default_arguments_are_used(self):
        output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
        self.assertEqual(output, torch.tensor([-0.01, 1]))

    def test_passing_too_many_args(self):
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)",
            "",
        ):
            torch.ops.aten.relu(1, 2)

    def test_passing_too_few_args(self):
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"aten::relu\(\) is missing value for argument 'self'.", ""
        ):
            torch.ops.aten.relu()

    def test_passing_one_positional_but_not_the_second(self):
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"aten::type_as\(\) is missing value for argument 'other'.",
            "",
        ):
            torch.ops.aten.type_as(torch.ones(5, 5))

    def test_passing_unknown_kwargs(self):
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Unknown keyword argument 'foo' for operator '_test::leaky_relu'",
            "",
        ):
            torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))

    def test_passing_and_returning_lists(self):
        # Replace with actual test once we support lists.
        a, b = torch.rand(5), torch.rand(5)
        output = torch.ops._test.cat([a, b])
        output_ref = torch.cat([a, b])
        self.assertEqual(output, output_ref)

    def test_calling_scripted_custom_op(self):
        @torch.jit.script
        def func(x):
            return torch.ops.aten.relu(x)

        input = torch.ones(5, 5)
        self.assertEqual(func(input), input.relu())

    def test_calling_traced_custom_op(self):
        input = torch.ones(5, 5)
        func = torch.jit.trace(torch.ops.aten.relu, [input])
        self.assertEqual(func(input), input.relu())

    @unittest.skip(
        "Need to figure out default dtype differences between fbcode and OSS"
    )
    def test_script_graph_for_custom_ops_matches_traced_graph(self):
        input = torch.ones(5, 5)
        trace = torch.jit.trace(torch.ops.aten.relu, [input])
        self.assertExpectedInline(
            canonical(trace.graph),
            """\
graph(%0 : Float(5, 5)):
  %1 : Float(5, 5) = aten::relu(%0)
  return (%1)
""",
        )

    def test_script_graph_contains_custom_op(self):
        @torch.jit.script
        def func(x):
            return torch.ops.aten.relu(x)

        self.assertExpectedInline(
            canonical(func.graph),
            """\
graph(%x.1 : Tensor):
  %1 : Tensor = aten::relu(%x.1)
  return (%1)
""",
        )

    def test_generic_list(self):
        self.assertEqual(torch.ops._test.get_first([["hello"]]), "hello")

    # https://github.com/pytorch/pytorch/issues/80508
    def test_where_no_scalar(self):
        x = torch.rand(1, 3, 224, 224)
        torch.ops.aten.where(x > 0.5, -1.5, 1.5)  # does not raise