xref: /aosp_15_r20/external/pytorch/test/export/test_pass_infra.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2import copy
3import unittest
4
5import torch
6from functorch.experimental import control_flow
7from torch._dynamo.eval_frame import is_dynamo_supported
8from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
9from torch.export import export
10from torch.fx.passes.infra.pass_base import PassResult
11from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
12
13
14@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
15class TestPassInfra(TestCase):
16    def test_export_pass_base(self) -> None:
17        class Foo(torch.nn.Module):
18            def forward(self, x):
19                y = torch.cat([x, x])
20                return torch.ops.aten.tensor_split.sections(y, 2)
21
22        f = Foo()
23
24        class NullPass(_ExportPassBaseDeprecatedDoNotUse):
25            pass
26
27        ep = export(f, (torch.ones(3, 2),))
28        old_nodes = ep.graph.nodes
29
30        ep = ep._transform_do_not_use(NullPass())
31        new_nodes = ep.graph.nodes
32
33        for node in new_nodes:
34            if node.op != "call_function":
35                continue
36            self.assertTrue(hasattr(node, "stack_trace"))
37            self.assertIsNotNone(node.stack_trace)
38
39        self.assertEqual(len(new_nodes), len(old_nodes))
40        for new_node, old_node in zip(new_nodes, old_nodes):
41            self.assertEqual(new_node.op, old_node.op)
42            self.assertEqual(new_node.target, old_node.target)
43
44    @unittest.skipIf(IS_WINDOWS, "Windows not supported")
45    def test_cond(self) -> None:
46        class M(torch.nn.Module):
47            def __init__(self) -> None:
48                super().__init__()
49
50            def forward(self, pred, x, y):
51                def true_fn(x, y):
52                    b = x.item()
53                    torch._check(b >= 2)
54                    torch._check(b <= 5)
55                    return x - y
56
57                def false_fn(x, y):
58                    c = y.item()
59                    torch._check(c >= 2)
60                    torch._check(c <= 5)
61                    return x + y
62
63                ret = control_flow.cond(pred, true_fn, false_fn, [x, y])
64                return ret
65
66        x = torch.tensor([2])
67        y = torch.tensor([5])
68        mod = M()
69        _ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use(
70            _ExportPassBaseDeprecatedDoNotUse()
71        )
72
73    def test_node_name_stability(self) -> None:
74        # Tests that graph nodes stay the same for nodes that are not touched
75        # during transformation
76        class CustomModule(torch.nn.Module):
77            def __init__(self) -> None:
78                super().__init__()
79
80                # Define a parameter
81                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
82
83                # Define two buffers
84                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
85                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
86
87            def forward(self, x1, x2):
88                # Use the parameter, buffers, and both inputs in the forward method
89                output = (
90                    x1 + self.my_parameter
91                ) * self.my_buffer1 + x2 * self.my_buffer2
92
93                # Mutate one of the buffers (e.g., increment it by 1)
94                self.my_buffer2.add_(1.0)
95
96                return output
97
98        inps = (torch.rand(1), torch.rand(1))
99        m = CustomModule()
100
101        ep_before = export(m, inps)
102
103        # No op transformation that doesn't perform any meaningful changes to node
104        ep_after = ep_before._transform_do_not_use(_ExportPassBaseDeprecatedDoNotUse())
105
106        for before_node, after_node in zip(ep_before.graph.nodes, ep_after.graph.nodes):
107            self.assertEqual(before_node.name, after_node.name)
108
109    def test_graph_signature_updated_after_transformation(self) -> None:
110        # Checks that pass infra correctly updates graph signature
111        # after transformations.
112        class CustomModule(torch.nn.Module):
113            def __init__(self) -> None:
114                super().__init__()
115
116                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
117
118                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
119                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
120
121            def forward(self, x1, x2):
122                # Use the parameter, buffers, and both inputs in the forward method
123                output = (
124                    x1 + self.my_parameter
125                ) * self.my_buffer1 + x2 * self.my_buffer2
126                return output
127
128        my_module = CustomModule()
129
130        # Test the custom module with two input tensors
131        input_tensor1 = torch.tensor(5.0)
132        input_tensor2 = torch.tensor(6.0)
133
134        ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
135        from torch.fx.passes.infra.pass_base import PassResult
136
137        def modify_input_output_pass(gm):
138            for node in gm.graph.nodes:
139                if node.op == "call_function":
140                    node.name = node.name + "_modified"
141            gm.recompile()
142            return PassResult(gm, True)
143
144        ep_after = ep_before._transform_do_not_use(modify_input_output_pass)
145        new_signature = ep_after.graph_signature
146
147        for node_name in new_signature.user_outputs:
148            self.assertTrue("_modified" in node_name)
149
150        old_signature = ep_before.graph_signature
151        self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)
152
153    def test_replace_hook_basic(self) -> None:
154        class CustomModule(torch.nn.Module):
155            def __init__(self) -> None:
156                super().__init__()
157
158                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
159
160                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
161                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
162
163            def forward(self, x1, x2):
164                # Use the parameter, buffers, and both inputs in the forward method
165                output = (
166                    x1 + self.my_parameter
167                ) * self.my_buffer1 + x2 * self.my_buffer2
168                return output
169
170        my_module = CustomModule()
171        inputs = (torch.tensor(6.0), torch.tensor(7.0))
172        ep_before = export(my_module, inputs)
173
174        def replace_pass(gm):
175            for node in gm.graph.nodes:
176                if node.op == "call_function":
177                    node.name = node.name + "_modified"
178            gm.recompile()
179            return PassResult(gm, True)
180
181        gm = copy.deepcopy(ep_before.graph_module)
182        sig = copy.deepcopy(ep_before.graph_signature)
183
184        with gm._set_replace_hook(sig.get_replace_hook()):
185            replace_pass(gm)
186
187        for node_name in sig.user_outputs:
188            self.assertTrue("_modified" in node_name)
189
190        old_signature = ep_before.graph_signature
191        self.assertNotEqual(sig.user_outputs, old_signature.user_outputs)
192
193
if __name__ == "__main__":
    # Hand off to PyTorch's test runner when this file is executed directly.
    run_tests()
196