# Owner(s): ["module: fx"]

import copy
import unittest
from typing import Optional, Set, Type

import torch
import torch.fx
from torch.testing._internal.common_utils import IS_MACOS, TestCase


class TestDCE(TestCase):
    def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
        if node.is_impure():
            return True
        # A custom check that additionally treats add operators as impure.
        if node.target == torch.ops.aten.add:
            return True
        return False

    def _has_nodes_without_users(
        self, m: torch.fx.GraphModule, custom: bool = False
    ) -> bool:
        for node in m.graph.nodes:
            if (not custom and node.is_impure()) or (
                custom and self._custom_is_impure_node(node)
            ):
                continue
            if len(node.users) == 0:
                return True
        return False

    def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
        count = 0
        for node in m.graph.nodes:
            if node.op == "placeholder":
                count += 1
        return count

    def _run_dce_and_test(
        self,
        m: torch.nn.Module,
        expect_dce_changes: bool,
        modules_to_be_leafs: Optional[Set[Type]] = None,
        custom: bool = False,
    ):
        class TestTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, qualname):
                if modules_to_be_leafs and type(m) in modules_to_be_leafs:
                    return True
                return super().is_leaf_module(m, qualname)

        traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
        print(str(traced.graph))

        # Verify there are nodes without users (if expected).
        has_nodes_without_users = self._has_nodes_without_users(traced, custom=custom)
        if expect_dce_changes:
            self.assertTrue(has_nodes_without_users)
        else:
            self.assertFalse(has_nodes_without_users)

        # Get the original number of placeholders to verify it doesn't change
        # during DCE.
        orig_num_phs = self._get_num_placeholders(traced)
        if custom:
            changed = traced.graph.eliminate_dead_code(
                is_impure_node=self._custom_is_impure_node
            )
        else:
            changed = traced.graph.eliminate_dead_code()

        self.assertTrue(changed if expect_dce_changes else not changed)

        # Verify there are no nodes without users after DCE is run.
        self.assertFalse(self._has_nodes_without_users(traced, custom=custom))
        new_num_phs = self._get_num_placeholders(traced)
        self.assertEqual(orig_num_phs, new_num_phs)

        traced.recompile()
        # Make sure we run and get the same results before/after DCE.
        inputs = [torch.tensor([1.5])] * new_num_phs
        inputs_copy = copy.deepcopy(inputs)
        self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))

    def test_simple(self):
        """
        Tests that a single dead node in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_chain(self):
        """
        Tests that a chain of two dead nodes in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * 7
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_getattr(self):
        """
        Tests that a dead get_attr node in the graph is DCE'd correctly.
120 """ 121 122 class TestModule(torch.nn.Module): 123 def __init__(self) -> None: 124 super().__init__() 125 self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) 126 127 def forward(self, x): 128 a = x + 1 129 b = a * self.attr_1 130 return x + 11 131 132 self._run_dce_and_test(TestModule(), expect_dce_changes=True) 133 134 def test_dead_placeholder(self): 135 """ 136 Tests that a placeholder in the graph is not DCE'd, as that would change 137 the function signature. 138 """ 139 140 class TestModule(torch.nn.Module): 141 def forward(self, x, y): 142 return x + 7 143 144 self._run_dce_and_test(TestModule(), expect_dce_changes=False) 145 146 def test_dead_placeholder_with_user(self): 147 """ 148 Tests that a placeholder in the graph is not DCE'd, as that would change 149 the function signature. Also verifies that a dead node that uses the 150 placeholder is DCE'd. 151 152 """ 153 154 class TestModule(torch.nn.Module): 155 def forward(self, x, y): 156 a = y + 2 157 return x + 7 158 159 self._run_dce_and_test(TestModule(), expect_dce_changes=True) 160 161 def test_keep_module_with_side_effects(self): 162 """ 163 Test that DCE doesn't remove a module if it's specified as having side effects. 164 """ 165 166 class ReLUImpure(torch.nn.ReLU): 167 _is_impure = True 168 169 class TestModule(torch.nn.Module): 170 def __init__(self) -> None: 171 super().__init__() 172 self.relu = ReLUImpure() 173 174 def forward(self, a: torch.Tensor) -> torch.Tensor: 175 r = self.relu(a) 176 return a * 2 177 178 self._run_dce_and_test( 179 TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure} 180 ) 181 182 def test_keep_torch_assert(self): 183 """ 184 Test that DCE doesn't remove torch._assert since it has side effects. 185 """ 186 187 class TestModule(torch.nn.Module): 188 def forward(self, a: torch.Tensor) -> torch.Tensor: 189 torch._assert(torch.equal(a, a), "a must equal a") 190 return a * 2 191 192 # Note: Don't need to specify torch._assert as having side effects 193 # because it's known to. 194 self._run_dce_and_test(TestModule(), expect_dce_changes=False) 195 196 def test_impure_nodes_args(self): 197 """ 198 Test that DCE doesn't remove call_function nodes with side effects. 199 """ 200 201 class TestModule(torch.nn.Module): 202 def forward(self, a: torch.Tensor) -> torch.Tensor: 203 torch._ops.ops.aten.add_.Tensor(a, 1) 204 return a * 2 205 206 # %add_ node should not be removed because it has side effects. 207 self._run_dce_and_test(TestModule(), expect_dce_changes=False) 208 209 def test_impure_kwargs(self): 210 """ 211 Test that DCE doesn't remove call_function nodes with side effects on kwargs. 212 """ 213 214 class TestModule(torch.nn.Module): 215 def forward(self, a: torch.Tensor) -> torch.Tensor: 216 b = a + 1 217 torch._ops.ops.aten.add.out(b, b, out=a, alpha=2) 218 return a 219 220 # %add_out node should not be removed because it has side effects. 221 self._run_dce_and_test(TestModule(), expect_dce_changes=False) 222 223 def test_impure_custom(self): 224 """ 225 Test that DCE doesn't remove nodes marked as impure by a custom function. 226 """ 227 228 class TestModule(torch.nn.Module): 229 def forward(self, a: torch.Tensor) -> torch.Tensor: 230 b = a + 1 231 c = torch._ops.ops.aten.add(b, b) 232 return a 233 234 # %add_out node should not be removed because it has side effects. 
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)

    @unittest.skipIf(IS_MACOS, "Not working on macOS")
    def test_keep_collectives(self):
        """
        Tests that DCE doesn't remove collective ops even when their results
        are not used.
        """

        from torch.testing._internal.distributed.fake_pg import FakeStore

        class TestModule(torch.nn.Module):
            def forward(
                self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
            ) -> torch.Tensor:
                d = torch.ops.aten.mul.Tensor(a, b)
                e = torch.ops.aten.mul.Tensor(a, c)
                future = torch.ops._c10d_functional.all_reduce.default(e, "sum", "0")
                synced_e = torch.ops._c10d_functional.wait_tensor.default(
                    future
                )  # synced_e is not used
                return d

        torch.distributed.init_process_group(
            backend="fake",
            world_size=2,
            rank=0,
            store=FakeStore(),
        )
        # Collective nodes should not be removed because they have side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
        torch.distributed.destroy_process_group()

    @unittest.skipIf(IS_MACOS, "Not working on macOS")
    def test_keep_collectives_no_overload(self):
        """
        Tests that DCE doesn't remove collective ops (the non-overload variants)
        even when their results are not used.
        """

        from torch.testing._internal.distributed.fake_pg import FakeStore

        class TestModule(torch.nn.Module):
            def forward(
                self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
            ) -> torch.Tensor:
                d = torch.ops.aten.mul(a, b)
                e = torch.ops.aten.mul(a, c)
                future = torch.ops._c10d_functional.all_reduce(e, "sum", "0")
                synced_e = torch.ops._c10d_functional.wait_tensor(
                    future
                )  # synced_e is not used
                return d

        torch.distributed.init_process_group(
            backend="fake",
            world_size=2,
            rank=0,
            store=FakeStore(),
        )
        # Collective nodes should not be removed because they have side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
        torch.distributed.destroy_process_group()
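
# A minimal sketch of the usual test entry point, assuming this file is run
# under PyTorch's common test harness. run_tests() is the standard runner
# exported by torch.testing._internal.common_utils; whether the original file
# ended with this block is an assumption.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()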