# Owner(s): ["oncall: fx"]

import contextlib
import pickle
from io import BytesIO
from unittest.mock import patch

import torch
import torch._export
from torch import fx
from torch.fx._lazy_graph_module import (
    _LazyGraphModule,
    _make_graph_module,
    _use_lazy_graph_module,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests, TestCase


class TestLazyGraphModule(TestCase):
    exit_stack = None

    @classmethod
    def setUpClass(cls):
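        # Enable lazy graph module mode for every test in this class; the
        # context manager is kept open on an ExitStack and closed in tearDownClass.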
        cls.exit_stack = contextlib.ExitStack()
        cls.exit_stack.enter_context(_use_lazy_graph_module(True))

    @classmethod
    def tearDownClass(cls):
        cls.exit_stack.close()

    @staticmethod
    def replace_sin_with_cos(gm):
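        # Retarget every "sin" node in the graph to "cos" in place. Callers are
        # expected to call gm.recompile() afterwards.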
        for n in gm.graph.nodes:
            if n.target == "sin":
                n.target = "cos"

    def test_replace_sin_with_cos(self):
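        """
        Editing node targets and calling recompile() should be reflected in the
        generated code, and the traced module should be a _LazyGraphModule.
        """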
        def f(x):
            return x.sin()

        x = torch.randn(2, 3)
        gm = fx.symbolic_trace(f)
        self.replace_sin_with_cos(gm)

        gm.recompile()
        expected = x.cos()
        actual = gm(x)

        self.assertTrue(torch.allclose(expected, actual))
        code = gm.print_readable(False)
        self.assertTrue("cos()" in code)
        self.assertTrue(isinstance(gm, _LazyGraphModule))

    def test_call_forward_directly(self):
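        """
        Calling the forward method directly (rather than gm(x)) should also
        return the result of the recompiled graph.
        """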
        def f(x):
            return x.sin()

        x = torch.randn(2, 3)
        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.replace_sin_with_cos(gm)
        gm.recompile()
        expected = x.cos()
        actual = gm.forward(x)

        self.assertTrue(torch.allclose(expected, actual))

    def test_needs_recompile(self):
        """
        Make sure needs_recompile() returns the correct state.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        gm(torch.randn(2, 3))
        self.assertFalse(gm._needs_recompile())

    def test_multi_recompile(self):
        """
        Cover the case where multiple recompilations happen.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        x = torch.randn(2, 3)
        # trigger the first recompilation
        self.assertTrue(torch.allclose(x.sin(), gm(x)))
        self.assertFalse(gm._needs_recompile())

        self.replace_sin_with_cos(gm)
        self.assertFalse(gm._needs_recompile())
        gm.recompile()
        self.assertTrue(gm._needs_recompile())
        # trigger the second recompilation
        self.assertTrue(torch.allclose(x.cos(), gm(x)))
        self.assertFalse(gm._needs_recompile())

    def test_accessing_code_cause_recompiling(self):
        """
        Make sure we recompile (if we have not done so yet) when we access the
        code property of a GraphModule.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        # should trigger a recompilation
        code = gm.code
        self.assertTrue("sin" in code)
        self.assertFalse(gm._needs_recompile())

    def test_graph_module_str(self):
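        """
        str() of a lazy graph module should reflect the traced function and
        contain "sin".
        """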
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue("sin" in str(gm))

    def test_recapture_with_make_fx(self):
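        """
        Re-tracing a lazy graph module with make_fx should produce another lazy
        graph module.
        """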
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        gm2 = make_fx(gm)(torch.randn(2, 3))
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
        self.assertTrue(gm2._needs_recompile())

        # make_fx will call the forward method of gm. That clears the
        # _needs_recompile() flag.
        self.assertFalse(gm._needs_recompile())

    def test_recapture_with_symbolic_trace(self):
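        """
        symbolic_trace on a lazy graph module realizes the pending recompilation
        of the input module and returns a new lazy graph module that still needs
        recompilation.
        """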
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        gm2 = fx.symbolic_trace(gm)

        # The lazy recompilation is already realized. We realize the
        # recompilation at the beginning of symbolic_trace since symbolic_trace
        # cannot handle the tracing of lazy recompilation.
        self.assertFalse(gm._needs_recompile())
        self.assertTrue(gm2._needs_recompile())

    def test_recapture_with_dynamo(self):
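        """
        torch.compile on a lazy graph module should trigger the real
        recompilation when the compiled module is called.
        """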
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        torch.compile(gm)(torch.rand(2, 3))

        # Dynamo calls gm.forward with the eval hook installed. That will
        # trigger the real recompilation.
        self.assertFalse(gm._needs_recompile())

    def test_save_lazy_forward(self):
        """
        Save the lazy forward method and call it repeatedly. Make sure we
        don't recompile for each such call.
        """

        def f(x):
            return x.sin()

        orig_gm_recompile = fx.GraphModule.recompile
        recompile_count = 0

        def mock_gm_recompile(self):
            nonlocal recompile_count
            recompile_count += 1
            return orig_gm_recompile(self)

        with patch.object(fx.GraphModule, "recompile", mock_gm_recompile):
            gm = fx.symbolic_trace(f)
            self.assertTrue(isinstance(gm, _LazyGraphModule))
            saved_fwd = gm.forward

            x = torch.rand(2, 3)
            for _ in range(10):
                saved_fwd(x)

        self.assertEqual(recompile_count, 1)

    def test_pickle(self):
        """
        The FX graph cache needs the ability to pickle GraphModule/_LazyGraphModule.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        serialized = pickle.dumps(gm)
        gm2 = pickle.loads(serialized)
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
        self.assertTrue("sin" in gm2.code)

    def test_make_graph_module(self):
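        """
        _make_graph_module should respect an explicit class_name/graph_module_cls,
        and with the lazy mode enabled for this class it should default to
        creating a _LazyGraphModule.
        """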
        gm = fx.symbolic_trace(lambda x: x.sin())
        self.assertTrue(isinstance(gm, _LazyGraphModule))

        gm1 = _make_graph_module(
            gm, gm.graph, class_name="MyGraphModule", graph_module_cls=fx.GraphModule
        )
        self.assertFalse(isinstance(gm1, _LazyGraphModule))
        self.assertTrue(gm1.__class__.__name__ == "MyGraphModule")

        gm2 = _make_graph_module(gm, gm.graph)
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
        self.assertTrue(gm2.__class__.__name__ == "GraphModule")

    def test_package_fx_simple(self):
        """
        Copied from test/package/test_package_fx.py to make sure LazyGraphModule
        works with torch.package.
        """

        class SimpleTest(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x + 3.0)

        st = SimpleTest()
        traced = fx.symbolic_trace(st)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)

        f.seek(0)
        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        input = torch.rand(2, 3)
        self.assertEqual(loaded_traced(input), traced(input))

    def test_dynamo_innermost_fn(self):
        """
        Repro for https://github.com/pytorch/pytorch/issues/121198 .
        """

        def f(x):
            return x * 2

        gm = torch.fx.symbolic_trace(f)
        lazy_gm = torch.fx._lazy_graph_module._LazyGraphModule.from_graphmodule(gm)

        wrapped_forward = torch._dynamo.disable(gm.forward)
        got_inner_forward = torch._dynamo.eval_frame.innermost_fn(wrapped_forward)
        assert hasattr(got_inner_forward, "__self__")

        wrapped_lazy_forward = torch._dynamo.disable(lazy_gm.forward)
        got_lazy_inner_forward = torch._dynamo.eval_frame.innermost_fn(
            wrapped_lazy_forward
        )
        assert hasattr(got_lazy_inner_forward, "__self__")


if __name__ == "__main__":
    run_tests()