1# Owner(s): ["oncall: fx"] 2 3import contextlib 4import pickle 5from io import BytesIO 6from unittest.mock import patch 7 8import torch 9import torch._export 10from torch import fx 11from torch.fx._lazy_graph_module import ( 12 _LazyGraphModule, 13 _make_graph_module, 14 _use_lazy_graph_module, 15) 16from torch.fx.experimental.proxy_tensor import make_fx 17from torch.package import PackageExporter, PackageImporter 18from torch.testing._internal.common_utils import run_tests, TestCase 19 20 21class TestLazyGraphModule(TestCase): 22 exit_stack = None 23 24 @classmethod 25 def setUpClass(cls): 26 cls.exit_stack = contextlib.ExitStack() 27 cls.exit_stack.enter_context(_use_lazy_graph_module(True)) 28 29 @classmethod 30 def tearDownClass(cls): 31 cls.exit_stack.close() 32 33 @staticmethod 34 def replace_sin_with_cos(gm): 35 for n in gm.graph.nodes: 36 if n.target == "sin": 37 n.target = "cos" 38 39 def test_replace_sin_with_cos(self): 40 def f(x): 41 return x.sin() 42 43 x = torch.randn(2, 3) 44 gm = fx.symbolic_trace(f) 45 self.replace_sin_with_cos(gm) 46 47 gm.recompile() 48 expected = x.cos() 49 actual = gm(x) 50 51 self.assertTrue(torch.allclose(expected, actual)) 52 code = gm.print_readable(False) 53 self.assertTrue("cos()" in code) 54 self.assertTrue(isinstance(gm, _LazyGraphModule)) 55 56 def test_call_forward_directly(self): 57 def f(x): 58 return x.sin() 59 60 x = torch.randn(2, 3) 61 gm = fx.symbolic_trace(f) 62 self.assertTrue(isinstance(gm, _LazyGraphModule)) 63 self.replace_sin_with_cos(gm) 64 gm.recompile() 65 expected = x.cos() 66 actual = gm.forward(x) 67 68 self.assertTrue(torch.allclose(expected, actual)) 69 70 def test_needs_recompile(self): 71 """ 72 Make sure needs_recompile() return the corrent state. 73 """ 74 75 def f(x): 76 return x.sin() 77 78 gm = fx.symbolic_trace(f) 79 self.assertTrue(isinstance(gm, _LazyGraphModule)) 80 self.assertTrue(gm._needs_recompile()) 81 gm(torch.randn(2, 3)) 82 self.assertFalse(gm._needs_recompile()) 83 84 def test_multi_recompile(self): 85 """ 86 Cover the case that multiple recompilation happens. 87 """ 88 89 def f(x): 90 return x.sin() 91 92 gm = fx.symbolic_trace(f) 93 self.assertTrue(isinstance(gm, _LazyGraphModule)) 94 self.assertTrue(gm._needs_recompile()) 95 x = torch.randn(2, 3) 96 # trigger the first recompilation 97 self.assertTrue(torch.allclose(x.sin(), gm(x))) 98 self.assertFalse(gm._needs_recompile()) 99 100 self.replace_sin_with_cos(gm) 101 self.assertFalse(gm._needs_recompile()) 102 gm.recompile() 103 self.assertTrue(gm._needs_recompile()) 104 # trigger the second recompilation 105 self.assertTrue(torch.allclose(x.cos(), gm(x))) 106 self.assertFalse(gm._needs_recompile()) 107 108 def test_accessing_code_cause_recompiling(self): 109 """ 110 Make sure we recompile if we have not done that yet when we access the code 111 property of a GraphModule. 112 """ 113 114 def f(x): 115 return x.sin() 116 117 gm = fx.symbolic_trace(f) 118 self.assertTrue(isinstance(gm, _LazyGraphModule)) 119 self.assertTrue(gm._needs_recompile()) 120 # should trigger a recompilation 121 code = gm.code 122 self.assertTrue("sin" in code) 123 self.assertFalse(gm._needs_recompile()) 124 125 def test_graph_module_str(self): 126 def f(x): 127 return x.sin() 128 129 gm = fx.symbolic_trace(f) 130 self.assertTrue(isinstance(gm, _LazyGraphModule)) 131 self.assertTrue("sin" in str(gm)) 132 133 def test_recapture_with_make_fx(self): 134 def f(x): 135 return x.sin() 136 137 gm = fx.symbolic_trace(f) 138 self.assertTrue(isinstance(gm, _LazyGraphModule)) 139 self.assertTrue(gm._needs_recompile()) 140 gm2 = make_fx(gm)(torch.randn(2, 3)) 141 self.assertTrue(isinstance(gm2, _LazyGraphModule)) 142 self.assertTrue(gm2._needs_recompile()) 143 144 # make_fx will cal foward method of gm. That clears the _needs_recompile() 145 # flag. 146 self.assertFalse(gm._needs_recompile()) 147 148 def test_recapture_with_symbolic_trace(self): 149 def f(x): 150 return x.sin() 151 152 gm = fx.symbolic_trace(f) 153 self.assertTrue(isinstance(gm, _LazyGraphModule)) 154 self.assertTrue(gm._needs_recompile()) 155 gm2 = fx.symbolic_trace(gm) 156 157 # the lazy recompilcation is already realized. We realize the 158 # recompilation in the beginning of symbolic_trace since symbolic_trace can not 159 # handle the tracing of lazy recompilation. 160 self.assertFalse(gm._needs_recompile()) 161 self.assertTrue(gm2._needs_recompile()) 162 163 def test_recapture_with_dynamo(self): 164 def f(x): 165 return x.sin() 166 167 gm = fx.symbolic_trace(f) 168 self.assertTrue(isinstance(gm, _LazyGraphModule)) 169 self.assertTrue(gm._needs_recompile()) 170 torch.compile(gm)(torch.rand(2, 3)) 171 172 # dynamo calls gm.forward with eval hook installed. That will trigger 173 # the real recompilation. 174 self.assertFalse(gm._needs_recompile()) 175 176 def test_save_lazy_foward(self): 177 """ 178 Save the lazy forward method and call it repeatly. Make sure we 179 don't recompile for each such call. 180 """ 181 182 def f(x): 183 return x.sin() 184 185 orig_gm_recompile = fx.GraphModule.recompile 186 recompile_count = 0 187 188 def mock_gm_recompile(self): 189 nonlocal recompile_count 190 recompile_count += 1 191 return orig_gm_recompile(self) 192 193 with patch.object(fx.GraphModule, "recompile", mock_gm_recompile): 194 gm = fx.symbolic_trace(f) 195 self.assertTrue(isinstance(gm, _LazyGraphModule)) 196 saved_fwd = gm.forward 197 198 x = torch.rand(2, 3) 199 for _ in range(10): 200 saved_fwd(x) 201 202 self.assertEqual(recompile_count, 1) 203 204 def test_pickle(self): 205 """ 206 Fx graph cache need the ability to pickle GraphModule/_LazyGraphModule. 207 """ 208 209 def f(x): 210 return x.sin() 211 212 gm = fx.symbolic_trace(f) 213 self.assertTrue(isinstance(gm, _LazyGraphModule)) 214 serialized = pickle.dumps(gm) 215 gm2 = pickle.loads(serialized) 216 self.assertTrue(isinstance(gm2, _LazyGraphModule)) 217 self.assertTrue("sin" in gm2.code) 218 219 def test_make_graph_module(self): 220 gm = fx.symbolic_trace(lambda x: x.sin()) 221 self.assertTrue(isinstance(gm, _LazyGraphModule)) 222 223 gm1 = _make_graph_module( 224 gm, gm.graph, class_name="MyGraphModule", graph_module_cls=fx.GraphModule 225 ) 226 self.assertFalse(isinstance(gm1, _LazyGraphModule)) 227 self.assertTrue(gm1.__class__.__name__ == "MyGraphModule") 228 229 gm2 = _make_graph_module(gm, gm.graph) 230 self.assertTrue(isinstance(gm2, _LazyGraphModule)) 231 self.assertTrue(gm2.__class__.__name__ == "GraphModule") 232 233 def test_package_fx_simple(self): 234 """ 235 Copied from test/package/test_package_fx.py to make sure LazyGraphModule 236 works with torch.package. 237 """ 238 239 class SimpleTest(torch.nn.Module): 240 def forward(self, x): 241 return torch.relu(x + 3.0) 242 243 st = SimpleTest() 244 traced = fx.symbolic_trace(st) 245 246 f = BytesIO() 247 with PackageExporter(f) as pe: 248 pe.save_pickle("model", "model.pkl", traced) 249 250 f.seek(0) 251 pi = PackageImporter(f) 252 loaded_traced = pi.load_pickle("model", "model.pkl") 253 input = torch.rand(2, 3) 254 self.assertEqual(loaded_traced(input), traced(input)) 255 256 def test_dynamo_innermost_fn(self): 257 """ 258 Repro for https://github.com/pytorch/pytorch/issues/121198 . 259 """ 260 261 def f(x): 262 return x * 2 263 264 gm = torch.fx.symbolic_trace(f) 265 lazy_gm = torch.fx._lazy_graph_module._LazyGraphModule.from_graphmodule(gm) 266 267 wrapped_forward = torch._dynamo.disable(gm.forward) 268 got_inner_forward = torch._dynamo.eval_frame.innermost_fn(wrapped_forward) 269 assert hasattr(got_inner_forward, "__self__") 270 271 wrapped_lazy_forward = torch._dynamo.disable(lazy_gm.forward) 272 got_lazy_inner_forward = torch._dynamo.eval_frame.innermost_fn( 273 wrapped_lazy_forward 274 ) 275 assert hasattr(got_lazy_inner_forward, "__self__") 276 277 278if __name__ == "__main__": 279 run_tests() 280