# Owner(s): ["module: fx"]

import copy
import unittest
from typing import Optional, Set, Type

import torch
import torch.fx
from torch.testing._internal.common_utils import IS_MACOS, TestCase


class TestDCE(TestCase):
    def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
        if node.is_impure():
            return True
        # Additionally treat add operators as impure, so tests can exercise a
        # custom impurity predicate.
        if node.target == torch.ops.aten.add:
            return True
        return False

    def _has_nodes_without_users(
        self, m: torch.fx.GraphModule, custom: bool = False
    ) -> bool:
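        # True if any node that DCE is allowed to remove (i.e. any node not
        # impure under the selected predicate) currently has no users.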
        for node in m.graph.nodes:
            if (not custom and node.is_impure()) or (
                custom and self._custom_is_impure_node(node)
            ):
                continue
            if len(node.users) == 0:
                return True
        return False

    def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
        count = 0
        for node in m.graph.nodes:
            if node.op == "placeholder":
                count += 1
        return count

    def _run_dce_and_test(
        self,
        m: torch.nn.Module,
        expect_dce_changes: bool,
        modules_to_be_leafs: Optional[Set[Type]] = None,
        custom: bool = False,
    ):
        class TestTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, qualname):
                if modules_to_be_leafs and type(m) in modules_to_be_leafs:
                    return True
                return super().is_leaf_module(m, qualname)

        traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
        print(str(traced.graph))

        # Verify there are nodes without users (if expected).
        has_nodes_without_users = self._has_nodes_without_users(traced, custom=custom)
        if expect_dce_changes:
            self.assertTrue(has_nodes_without_users)
        else:
            self.assertFalse(has_nodes_without_users)

        # Get the original number of placeholders to verify it doesn't change
        # during DCE.
        orig_num_phs = self._get_num_placeholders(traced)
        if custom:
            changed = traced.graph.eliminate_dead_code(
                is_impure_node=self._custom_is_impure_node
            )
        else:
            changed = traced.graph.eliminate_dead_code()

        self.assertTrue(changed if expect_dce_changes else not changed)

        # Verify there are no nodes without users after DCE is run.
        self.assertFalse(self._has_nodes_without_users(traced, custom=custom))
        new_num_phs = self._get_num_placeholders(traced)
        self.assertEqual(orig_num_phs, new_num_phs)

        traced.recompile()
        # Make sure we run and get the same results before/after DCE.
        inputs = [torch.tensor([1.5])] * new_num_phs
        inputs_copy = copy.deepcopy(inputs)
        self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))
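
        # For reference, the helper above wraps the typical standalone DCE
        # flow (a sketch for readers; not executed by this test):
        #
        #   traced = torch.fx.symbolic_trace(module)
        #   changed = traced.graph.eliminate_dead_code()  # True if any node was removed
        #   traced.recompile()  # regenerate forward() from the mutated graph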

    def test_simple(self):
        """
        Tests that a single dead node in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)
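
        # Roughly, tracing TestModule yields a graph like the following (node
        # names are illustrative, not guaranteed):
        #
        #   %x      : placeholder
        #   %add    : call_function[operator.add](%x, 1)   <- no users, removed
        #   %attr_1 : get_attr[attr_1]
        #   %add_1  : call_function[operator.add](%x, %attr_1)
        #   return %add_1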

    def test_dead_chain(self):
        """
        Tests that a dead chain of two nodes in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * 7
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_getattr(self):
        """
        Tests that a dead getattr in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * self.attr_1
                return x + 11

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_placeholder(self):
        """
        Tests that a placeholder in the graph is not DCE'd, as that would change
        the function signature.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x + 7

        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_dead_placeholder_with_user(self):
        """
        Tests that a placeholder in the graph is not DCE'd, as that would change
        the function signature. Also verifies that a dead node that uses the
        placeholder is DCE'd.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                a = y + 2
                return x + 7

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_keep_module_with_side_effects(self):
        """
        Test that DCE doesn't remove a module if it's specified as having side effects.
        """

        class ReLUImpure(torch.nn.ReLU):
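            # torch.fx.Node.is_impure() checks call_module targets for an
            # `_is_impure` attribute, so DCE treats calls to this module as
            # having side effects.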
            _is_impure = True

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = ReLUImpure()

            def forward(self, a: torch.Tensor) -> torch.Tensor:
                r = self.relu(a)
                return a * 2

        self._run_dce_and_test(
            TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
        )

    def test_keep_torch_assert(self):
        """
        Test that DCE doesn't remove torch._assert since it has side effects.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                torch._assert(torch.equal(a, a), "a must equal a")
                return a * 2

        # Note: we don't need to mark torch._assert as having side effects
        # because FX already knows that it does.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_impure_nodes_args(self):
        """
        Test that DCE doesn't remove call_function nodes with side effects.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                torch._ops.ops.aten.add_.Tensor(a, 1)
                return a * 2

        # The %add_ node should not be removed because it mutates its input
        # in place (the trailing underscore marks an in-place op).
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_impure_kwargs(self):
        """
        Test that DCE doesn't remove call_function nodes with side effects on kwargs.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                b = a + 1
                torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
                return a

        # The %add_out node should not be removed because it writes to `a`
        # through the out= kwarg.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_impure_custom(self):
        """
        Test that DCE doesn't remove nodes marked as impure by a custom function.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                b = a + 1
                c = torch._ops.ops.aten.add(b, b)
                return a

        # The %add node should not be removed because the custom predicate
        # treats torch.ops.aten.add as impure.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)

    @unittest.skipIf(IS_MACOS, "Not working on macOS")
    def test_keep_collectives(self):
        """
        Test that DCE doesn't remove collective ops even if their results are not used.
        """

        from torch.testing._internal.distributed.fake_pg import FakeStore
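        # FakeStore backs the "fake" process group backend, so collectives can
        # run under tracing without real worker processes.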

        class TestModule(torch.nn.Module):
            def forward(
                self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
            ) -> torch.Tensor:
                d = torch.ops.aten.mul.Tensor(a, b)
                e = torch.ops.aten.mul.Tensor(a, c)
                future = torch.ops._c10d_functional.all_reduce.default(e, "sum", "0")
                synced_e = torch.ops._c10d_functional.wait_tensor.default(
                    future
                )  # synced_e is not used
                return d

        torch.distributed.init_process_group(
            backend="fake",
            world_size=2,
            rank=0,
            store=FakeStore(),
        )
        # Collective nodes should not be removed because they have side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
        torch.distributed.destroy_process_group()

    @unittest.skipIf(IS_MACOS, "Not working on macOS")
    def test_keep_collectives_no_overload(self):
        """
        Test that DCE doesn't remove collective ops (no-overload version) even if
        their results are not used.
        """

        from torch.testing._internal.distributed.fake_pg import FakeStore

        class TestModule(torch.nn.Module):
            def forward(
                self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
            ) -> torch.Tensor:
                d = torch.ops.aten.mul(a, b)
                e = torch.ops.aten.mul(a, c)
                future = torch.ops._c10d_functional.all_reduce(e, "sum", "0")
                synced_e = torch.ops._c10d_functional.wait_tensor(
                    future
                )  # synced_e is not used
                return d

        torch.distributed.init_process_group(
            backend="fake",
            world_size=2,
            rank=0,
            store=FakeStore(),
        )
        # Collective nodes should not be removed because they have side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
        torch.distributed.destroy_process_group()
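

if __name__ == "__main__":
    # Conventional entry point for PyTorch test files; run_tests() is the
    # standard runner from common_utils. This is a sketch of the usual wiring,
    # and the upstream file may invoke these tests differently.
    from torch.testing._internal.common_utils import run_tests

    run_tests()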