xref: /aosp_15_r20/external/pytorch/test/export/test_experimental.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2# flake8: noqa
3import unittest
4from typing import Dict, List, Tuple
5
6import torch
7import torch._dynamo
8from torch._dynamo.test_case import run_tests, TestCase
9from torch._export.wrappers import _mark_strict_experimental
10from torch._functorch.aot_autograd import aot_export_module
11from torch.export._trace import _convert_ts_to_export_experimental
12from torch.export.experimental import _export_forward_backward
13from torch.testing import FileCheck
14
15
16@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
17class TestExperiment(TestCase):
    def test_with_buffer_as_submodule(self):
        """Export a model whose buffer-owning submodule is strict-marked.

        The submodule ``B`` is decorated with ``_mark_strict_experimental``.
        The test checks the generated code of the outer graph and of the
        nested ``strict_graph_0`` against inline expectations, then re-exports
        with ``strict=True`` and compares the exported module's outputs with
        eager execution.
        """

        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(3))

            def forward(self, x):
                y = x + 2
                # In-place update of an intermediate tensor (not of the buffer).
                y.add_(4)
                # Mutating the buffer in place (the commented-out line below)
                # doesn't work today with HOO, so an out-of-place add is used.
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        # Non-strict export of the outer module. In the expected code below the
        # submodule call appears as a ``torch.ops.higher_order.strict_mode``
        # call that receives the sin result and the lifted buffer.
        ep = torch.export.export(M(), (inp,), strict=False)
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1));  strict_graph_0 = sin = b_submodule_buffer1 = None
    getitem_2 = strict_mode[0];  strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3);  x = None
    return (getitem_2, add)""",
        )

        # Expected code of the nested graph that holds B.forward's body; both
        # ``x + 2`` and ``y.add_(4)`` are expected as ``aten.add.Tensor`` calls.
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_0.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4);  add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6);  arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1);  add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2);  add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3);  add_3 = sum_3 = None
    return (add_4,)""",
        )

        # Re-export a fresh module instance, this time with strict=True, and
        # compare the exported module against eager execution of that instance.
        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)

        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)

        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        # Same comparison a second time.
        # NOTE(review): presumably guards against state leaking between
        # calls — confirm the repetition is intentional.
        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)

        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
85
86    def test_mark_strict_with_container_type(self):
87        @_mark_strict_experimental
88        class B(torch.nn.Module):
89            def __init__(self) -> None:
90                super().__init__()
91
92            def forward(self, x):
93                x0 = x[0][0]
94                return x0.sum()
95
96        class M(torch.nn.Module):
97            def __init__(self) -> None:
98                super().__init__()
99                self.submodule = B()
100
101            def forward(self, x):
102                return self.submodule(x)
103
104        inp = ((torch.randn(3),),)
105        with self.assertRaisesRegex(
106            RuntimeError, "strict_mode HOO doesn't work unless"
107        ):
108            ep = torch.export.export(M(), inp, strict=False)
109
110    def test_torchscript_module_export(self):
111        class M(torch.nn.Module):
112            def forward(self, x):
113                return x.cos() + x.sin()
114
115        model_to_trace = M()
116        inps = (torch.randn(4, 4),)
117        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
118
119        exported_module = _convert_ts_to_export_experimental(
120            traced_module_by_torchscript, inps
121        )
122
123        self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps)))
124
125    def test_torchscript_module_export_single_input(self):
126        class M(torch.nn.Module):
127            def forward(self, x):
128                return x.cos() + x.sin()
129
130        model_to_trace = M()
131        inps = torch.randn(4, 4)
132        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
133
134        exported_module = _convert_ts_to_export_experimental(
135            traced_module_by_torchscript, inps
136        )
137
138        self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps)))
139
140    def test_torchscript_module_export_various_inputs_with_annotated_input_names(self):
141        def _check_equality_and_annotations(m_func, inps):
142            # Original module.
143            model_to_trace = m_func()
144
145            # ExportedProgram from TorchScript module.
146            traced_module_by_torchscript = torch.jit.trace(
147                m_func(), example_inputs=inps
148            )
149            exported_module = _convert_ts_to_export_experimental(
150                traced_module_by_torchscript, inps
151            )
152
153            # ExportedProgram from original module.
154            original_exported_module = torch.export.export(m_func(), inps)
155
156            # Check whether input annotations are the same as tracing the original module.
157            orig_ph_name_list = [
158                n.name
159                for n in original_exported_module.graph.nodes
160                if n.op == "placeholder"
161            ]
162            ph_name_list = [
163                n.name for n in exported_module.graph.nodes if n.op == "placeholder"
164            ]
165            self.assertEqual(orig_ph_name_list, ph_name_list)
166
167            # Check results equality.
168            self.assertTrue(
169                torch.allclose(exported_module(*inps), model_to_trace(*inps))
170            )
171
172        # Tuple
173        class MTuple(torch.nn.Module):
174            def forward(self, x: Tuple[torch.Tensor]):
175                return x[0] + x[1]
176
177        _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),))
178
179        # List
180        class MList(torch.nn.Module):
181            def forward(self, x: List[torch.Tensor]):
182                return x[0] + x[1]
183
184        _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],))
185
186        # Dict
187        class MDict(torch.nn.Module):
188            def forward(self, x: Dict[str, torch.Tensor]):
189                return x["0"] + x["1"]
190
191        _check_equality_and_annotations(
192            MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
193        )
194
    def test_joint_basic(self) -> None:
        """Check the joint forward/backward graph of a small loss module.

        A module made of ``Linear(3, 3)``, a softmax and ``CrossEntropyLoss``
        is exported with ``pre_dispatch=True`` and passed to
        ``_export_forward_backward``. The generated code is compared against
        an inline expectation, both directly and after ``run_decompositions()``.
        """

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x):
                # The target tensor is created inside forward; in the expected
                # code below it shows up as the ``c_lifted_tensor_0`` input.
                return self.loss(
                    self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
                )

        m = Module()
        example_inputs = (torch.randn(3),)
        # Run the module eagerly once before exporting it.
        m(*example_inputs)
        ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True)
        # NOTE(review): the three outputs in the expected code are presumably
        # the loss followed by the gradients of linear.weight and linear.bias
        # — confirm against _export_forward_backward.
        joint_ep = _export_forward_backward(ep)
        self.assertExpectedInline(
            str(joint_ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]);  x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]);  p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute);  p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]);  addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False);  view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    alias_1 = torch.ops.aten.alias.default(alias);  alias = None
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
    alias_2 = torch.ops.aten.alias.default(clone);  clone = None
    alias_3 = torch.ops.aten.alias.default(alias_2);  alias_2 = None
    alias_4 = torch.ops.aten.alias.default(alias_3);  alias_3 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False);  _softmax = None
    alias_5 = torch.ops.aten.alias.default(_log_softmax)
    alias_6 = torch.ops.aten.alias.default(alias_5);  alias_5 = None
    mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4);  _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []);  mul = None
    neg = torch.ops.aten.neg.default(sum_1);  sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1);  neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1);  full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1);  div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]);  neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4);  expand = alias_4 = None
    alias_7 = torch.ops.aten.alias.default(alias_6);  alias_6 = None
    alias_8 = torch.ops.aten.alias.default(alias_7);  alias_7 = None
    exp = torch.ops.aten.exp.default(alias_8);  alias_8 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2);  exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
    alias_9 = torch.ops.aten.alias.default(alias_1);  alias_1 = None
    alias_10 = torch.ops.aten.alias.default(alias_9);  alias_9 = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10);  sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3);  alias_10 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4);  mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]);  sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view);  permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]);  mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True);  view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]);  sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]);  permute_2 = None
    return (div, permute_3, view_3)""",
        )
        # Decompose the joint program; the expected text below is the same as
        # the expectation for the undecomposed joint graph above.
        ep = joint_ep.run_decompositions()
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]);  x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]);  p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute);  p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]);  addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False);  view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    alias_1 = torch.ops.aten.alias.default(alias);  alias = None
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
    alias_2 = torch.ops.aten.alias.default(clone);  clone = None
    alias_3 = torch.ops.aten.alias.default(alias_2);  alias_2 = None
    alias_4 = torch.ops.aten.alias.default(alias_3);  alias_3 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False);  _softmax = None
    alias_5 = torch.ops.aten.alias.default(_log_softmax)
    alias_6 = torch.ops.aten.alias.default(alias_5);  alias_5 = None
    mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4);  _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []);  mul = None
    neg = torch.ops.aten.neg.default(sum_1);  sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1);  neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1);  full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1);  div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]);  neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4);  expand = alias_4 = None
    alias_7 = torch.ops.aten.alias.default(alias_6);  alias_6 = None
    alias_8 = torch.ops.aten.alias.default(alias_7);  alias_7 = None
    exp = torch.ops.aten.exp.default(alias_8);  alias_8 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2);  exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
    alias_9 = torch.ops.aten.alias.default(alias_1);  alias_1 = None
    alias_10 = torch.ops.aten.alias.default(alias_9);  alias_9 = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10);  sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3);  alias_10 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4);  mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]);  sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view);  permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]);  mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True);  view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]);  sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]);  permute_2 = None
    return (div, permute_3, view_3)""",
        )
309
310    def test_joint_dynamic(self) -> None:
311        from torch.export import Dim
312
313        class Module(torch.nn.Module):
314            def __init__(self) -> None:
315                super().__init__()
316                self.y = torch.nn.Parameter(torch.randn(3))
317
318            def forward(self, x):
319                x = torch.ones(x.shape[0], 3)
320                return (self.y + x).sum()
321
322        m = Module()
323        example_inputs = (torch.randn(3),)
324        m(*example_inputs)
325        ep = torch.export._trace._export(
326            m, example_inputs, pre_dispatch=True, dynamic_shapes={"x": {0: Dim("x0")}}
327        )
328        joint_ep = _export_forward_backward(ep)
329
330
# When executed as a script, hand control to ``run_tests`` (imported above
# from ``torch._dynamo.test_case``).
if __name__ == "__main__":
    run_tests()
333