1# Owner(s): ["oncall: export"] 2import copy 3import unittest 4 5import torch 6from functorch.experimental import control_flow 7from torch._dynamo.eval_frame import is_dynamo_supported 8from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse 9from torch.export import export 10from torch.fx.passes.infra.pass_base import PassResult 11from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase 12 13 14@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 15class TestPassInfra(TestCase): 16 def test_export_pass_base(self) -> None: 17 class Foo(torch.nn.Module): 18 def forward(self, x): 19 y = torch.cat([x, x]) 20 return torch.ops.aten.tensor_split.sections(y, 2) 21 22 f = Foo() 23 24 class NullPass(_ExportPassBaseDeprecatedDoNotUse): 25 pass 26 27 ep = export(f, (torch.ones(3, 2),)) 28 old_nodes = ep.graph.nodes 29 30 ep = ep._transform_do_not_use(NullPass()) 31 new_nodes = ep.graph.nodes 32 33 for node in new_nodes: 34 if node.op != "call_function": 35 continue 36 self.assertTrue(hasattr(node, "stack_trace")) 37 self.assertIsNotNone(node.stack_trace) 38 39 self.assertEqual(len(new_nodes), len(old_nodes)) 40 for new_node, old_node in zip(new_nodes, old_nodes): 41 self.assertEqual(new_node.op, old_node.op) 42 self.assertEqual(new_node.target, old_node.target) 43 44 @unittest.skipIf(IS_WINDOWS, "Windows not supported") 45 def test_cond(self) -> None: 46 class M(torch.nn.Module): 47 def __init__(self) -> None: 48 super().__init__() 49 50 def forward(self, pred, x, y): 51 def true_fn(x, y): 52 b = x.item() 53 torch._check(b >= 2) 54 torch._check(b <= 5) 55 return x - y 56 57 def false_fn(x, y): 58 c = y.item() 59 torch._check(c >= 2) 60 torch._check(c <= 5) 61 return x + y 62 63 ret = control_flow.cond(pred, true_fn, false_fn, [x, y]) 64 return ret 65 66 x = torch.tensor([2]) 67 y = torch.tensor([5]) 68 mod = M() 69 _ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use( 70 _ExportPassBaseDeprecatedDoNotUse() 71 ) 72 73 def test_node_name_stability(self) -> None: 74 # Tests that graph nodes stay the same for nodes that are not touched 75 # during transformation 76 class CustomModule(torch.nn.Module): 77 def __init__(self) -> None: 78 super().__init__() 79 80 # Define a parameter 81 self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) 82 83 # Define two buffers 84 self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) 85 self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) 86 87 def forward(self, x1, x2): 88 # Use the parameter, buffers, and both inputs in the forward method 89 output = ( 90 x1 + self.my_parameter 91 ) * self.my_buffer1 + x2 * self.my_buffer2 92 93 # Mutate one of the buffers (e.g., increment it by 1) 94 self.my_buffer2.add_(1.0) 95 96 return output 97 98 inps = (torch.rand(1), torch.rand(1)) 99 m = CustomModule() 100 101 ep_before = export(m, inps) 102 103 # No op transformation that doesn't perform any meaningful changes to node 104 ep_after = ep_before._transform_do_not_use(_ExportPassBaseDeprecatedDoNotUse()) 105 106 for before_node, after_node in zip(ep_before.graph.nodes, ep_after.graph.nodes): 107 self.assertEqual(before_node.name, after_node.name) 108 109 def test_graph_signature_updated_after_transformation(self) -> None: 110 # Checks that pass infra correctly updates graph signature 111 # after transformations. 112 class CustomModule(torch.nn.Module): 113 def __init__(self) -> None: 114 super().__init__() 115 116 self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) 117 118 self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) 119 self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) 120 121 def forward(self, x1, x2): 122 # Use the parameter, buffers, and both inputs in the forward method 123 output = ( 124 x1 + self.my_parameter 125 ) * self.my_buffer1 + x2 * self.my_buffer2 126 return output 127 128 my_module = CustomModule() 129 130 # Test the custom module with two input tensors 131 input_tensor1 = torch.tensor(5.0) 132 input_tensor2 = torch.tensor(6.0) 133 134 ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2)) 135 from torch.fx.passes.infra.pass_base import PassResult 136 137 def modify_input_output_pass(gm): 138 for node in gm.graph.nodes: 139 if node.op == "call_function": 140 node.name = node.name + "_modified" 141 gm.recompile() 142 return PassResult(gm, True) 143 144 ep_after = ep_before._transform_do_not_use(modify_input_output_pass) 145 new_signature = ep_after.graph_signature 146 147 for node_name in new_signature.user_outputs: 148 self.assertTrue("_modified" in node_name) 149 150 old_signature = ep_before.graph_signature 151 self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs) 152 153 def test_replace_hook_basic(self) -> None: 154 class CustomModule(torch.nn.Module): 155 def __init__(self) -> None: 156 super().__init__() 157 158 self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) 159 160 self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) 161 self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) 162 163 def forward(self, x1, x2): 164 # Use the parameter, buffers, and both inputs in the forward method 165 output = ( 166 x1 + self.my_parameter 167 ) * self.my_buffer1 + x2 * self.my_buffer2 168 return output 169 170 my_module = CustomModule() 171 inputs = (torch.tensor(6.0), torch.tensor(7.0)) 172 ep_before = export(my_module, inputs) 173 174 def replace_pass(gm): 175 for node in gm.graph.nodes: 176 if node.op == "call_function": 177 node.name = node.name + "_modified" 178 gm.recompile() 179 return PassResult(gm, True) 180 181 gm = copy.deepcopy(ep_before.graph_module) 182 sig = copy.deepcopy(ep_before.graph_signature) 183 184 with gm._set_replace_hook(sig.get_replace_hook()): 185 replace_pass(gm) 186 187 for node_name in sig.user_outputs: 188 self.assertTrue("_modified" in node_name) 189 190 old_signature = ep_before.graph_signature 191 self.assertNotEqual(sig.user_outputs, old_signature.user_outputs) 192 193 194if __name__ == "__main__": 195 run_tests() 196