xref: /aosp_15_r20/external/pytorch/test/export/test_hop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2# flake8: noqa
3import copy
4import io
5import unittest
6
7import torch
8import torch._dynamo as torchdynamo
9import torch.utils._pytree as pytree
10from torch._dynamo.test_case import TestCase
11from torch.export import export, load, save
12from torch.export._trace import _export
13from torch.testing._internal.common_device_type import (
14    instantiate_device_type_tests,
15    ops,
16)
17from torch.testing._internal.common_utils import (
18    IS_WINDOWS,
19    run_tests,
20    TestCase as TorchTestCase,
21)
22from torch.testing._internal.hop_db import (
23    hop_db,
24    hop_that_doesnt_have_opinfo_test_allowlist,
25)
26
27
28hop_tests = []
29
30for op_info in hop_db:
31    op_info_hop_name = op_info.name
32    if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist:
33        continue
34    hop_tests.append(op_info)
35
36
37class TestHOPGeneric(TestCase):
38    def test_all_hops_have_op_info(self):
39        from torch._ops import _higher_order_ops
40
41        hops_that_have_op_info = set([k.name for k in hop_db])
42        all_hops = _higher_order_ops.keys()
43
44        missing_ops = []
45
46        for op in all_hops:
47            if (
48                op not in hops_that_have_op_info
49                and op not in hop_that_doesnt_have_opinfo_test_allowlist
50            ):
51                missing_ops.append(op)
52
53        self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}")
54
55
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestHOP(TestCase):
    """Exports each higher-order op's OpInfo samples through the various
    torch.export entry points (plain export, pre-dispatch export, retrace,
    serialize/deserialize round trip) and checks that the exported module
    produces the same outputs as eager execution.
    """

    def _compare(self, eager_model, export, args, kwargs):
        """Run ``eager_model`` and the exported program on deep copies of
        ``(args, kwargs)`` and assert the flattened outputs match exactly.

        Inputs are deep-copied per run so that in-place mutation by either
        execution cannot leak into the other.
        """
        eager_args = copy.deepcopy(args)
        eager_kwargs = copy.deepcopy(kwargs)
        export_args = copy.deepcopy(args)
        export_kwargs = copy.deepcopy(kwargs)

        flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs))
        flat_loaded_outputs = pytree.tree_leaves(
            export.module()(*export_args, **export_kwargs)
        )

        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
            self.assertEqual(type(orig), type(loaded))
            self.assertEqual(orig, loaded)

    def _iter_export_inputs(self, op, device, dtype):
        """Yield ``(model, args, kwargs)`` for each OpInfo sample of ``op``.

        Shared by all test methods below: wraps the op in a fresh
        ``nn.Module`` per sample and normalizes the sample's positional
        inputs into a single args tuple.
        """

        class Foo(torch.nn.Module):
            def forward(self, *args):
                return op.op(*args)

        for inp in op.sample_inputs(device, dtype, requires_grad=True):
            # sample.input may be a single tensor or a tuple of positionals.
            inputs = inp.input if isinstance(inp.input, tuple) else (inp.input,)
            yield Foo(), (*inputs, *inp.args), inp.kwargs

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_aot_export(self, device, dtype, op):
        """Plain torch.export round trip matches eager."""
        for model, args, kwargs in self._iter_export_inputs(op, device, dtype):
            ep = export(model, args, kwargs)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_pre_dispatch_export(self, device, dtype, op):
        """Pre-dispatch export (private _export entry point) matches eager."""
        for model, args, kwargs in self._iter_export_inputs(op, device, dtype):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_retrace_export(self, device, dtype, op):
        """Pre-dispatch export followed by decomposition still matches eager."""
        for model, args, kwargs in self._iter_export_inputs(op, device, dtype):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_serialize_export(self, device, dtype, op):
        """Export -> decompose -> save -> load round trip matches eager."""
        for model, args, kwargs in self._iter_export_inputs(op, device, dtype):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()
            buffer = io.BytesIO()
            save(ep, buffer)
            buffer.seek(0)
            ep = load(buffer)
            if "while_loop" in str(op):
                # while_loop's arguments are cast into a list after
                # deserialization, but while_loop expects them to still be a
                # tuple, so this round trip is a known failure.
                with self.assertRaisesRegex(
                    RuntimeError, "carried_inputs must be a tuple"
                ):
                    self._compare(model, ep, args, kwargs)
            else:
                self._compare(model, ep, args, kwargs)
148
# Generate per-device variants of TestHOP (e.g. TestHOPCPU, TestHOPCUDA) into
# this module's namespace so the test runner discovers them.
instantiate_device_type_tests(TestHOP, globals())

if __name__ == "__main__":
    run_tests()
153