xref: /aosp_15_r20/external/pytorch/test/onnx/exporter/test_api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2"""Simple API tests for the ONNX exporter."""
3
4from __future__ import annotations
5
6import os
7
8import torch
9from torch.onnx._internal.exporter import testing as onnx_testing
10from torch.testing._internal import common_utils
11
12
class SampleModel(torch.nn.Module):
    """Single-input model that returns ``x + 1`` together with its ReLU."""

    def forward(self, x):
        shifted = x + 1
        # Both the pre-activation and the activation are graph outputs.
        return shifted, shifted.relu()
18
19
class SampleModelTwoInputs(torch.nn.Module):
    """Two-input model that returns ``x + b`` together with its ReLU."""

    def forward(self, x, b):
        total = x + b
        rectified = total.relu()
        return total, rectified
25
26
class SampleModelForDynamicShapes(torch.nn.Module):
    """Two-input model with one independent element-wise op per input.

    Each output depends on exactly one input: ReLU of ``x`` and sigmoid of ``b``.
    """

    def forward(self, x, b):
        activated = x.relu()
        squashed = b.sigmoid()
        return activated, squashed
30
31
class TestExportAPIDynamo(common_utils.TestCase):
    """Tests for the ONNX exporter API when dynamo=True."""

    def assert_export(self, *args, **kwargs):
        """Export via ``torch.onnx.export(..., dynamo=True)`` and check the result.

        Every positional and keyword argument is forwarded unchanged to
        ``torch.onnx.export``. ``dynamo=True`` is always added, so callers must
        not pass ``dynamo`` themselves. The returned program must be non-None
        and is then passed to ``onnx_testing.assert_onnx_program``.
        """
        onnx_program = torch.onnx.export(*args, **kwargs, dynamo=True)
        # A TestCase assertion, unlike a bare ``assert``, survives ``python -O``.
        self.assertIsNotNone(onnx_program)
        onnx_testing.assert_onnx_program(onnx_program)

    def test_args_normalization_with_no_kwargs(self):
        self.assert_export(
            SampleModelTwoInputs(),
            (torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
        )

    def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
        # The trailing dict in ``args`` supplies ``b`` as a keyword argument.
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
                "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
            },
        )

    def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
        # Axes given as a list (no custom names) must also be accepted.
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "x": [0, 1, 2],
                "b": [0, 1, 2],
            },
        )

    def test_dynamic_axes_supports_partial_dynamic_shapes(self):
        # Only ``b`` is marked dynamic; ``x`` keeps its static shape.
        self.assert_export(
            SampleModelForDynamicShapes(),
            (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
            dynamic_axes={
                "b": [0, 1, 2],
            },
        )

    def test_dynamic_axes_supports_output_names(self):
        # ``dynamic_axes`` may reference both an input name ("b") and an
        # output name ("b_out") declared via input_names/output_names.
        self.assert_export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            input_names=["x", "b"],
            output_names=["x_out", "b_out"],
            dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]},
        )

    def test_saved_f_exists_after_export(self):
        with common_utils.TemporaryFileName(suffix=".onnx") as path:
            _ = torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True
            )
            self.assertTrue(os.path.exists(path))

    def test_export_supports_script_module(self):
        class ScriptModule(torch.nn.Module):
            def forward(self, x):
                return x

        self.assert_export(torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),))

    def test_dynamic_shapes_with_fully_specified_axes(self):
        exported_program = torch.export.export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            dynamic_shapes={
                "x": {
                    0: torch.export.Dim("customx_dim_0"),
                    1: torch.export.Dim("customx_dim_1"),
                    2: torch.export.Dim("customx_dim_2"),
                },
                "b": {
                    0: torch.export.Dim("customb_dim_0"),
                    1: torch.export.Dim("customb_dim_1"),
                    2: torch.export.Dim("customb_dim_2"),
                },
            },
        )

        # An already-exported program is passed as the sole argument.
        self.assert_export(exported_program)

    def test_partial_dynamic_shapes(self):
        self.assert_export(
            SampleModelForDynamicShapes(),
            (
                torch.randn(2, 2, 3),
                torch.randn(2, 2, 3),
            ),
            dynamic_shapes={
                "x": None,
                "b": {
                    0: torch.export.Dim("customb_dim_0"),
                    1: torch.export.Dim("customb_dim_1"),
                    2: torch.export.Dim("customb_dim_2"),
                },
            },
        )

    def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self):
        from unittest import mock

        class Nested(torch.nn.Module):
            def forward(self, x):
                (a0, a1), (b0, b1), (c0, c1, c2) = x
                return a0 + a1 + b0 + b1 + c0 + c1 + c2

        inputs = (
            (1, 2),
            (
                torch.randn(4, 4),
                torch.randn(4, 4),
            ),
            (
                torch.randn(4, 4),
                torch.randn(4, 4),
                torch.randn(4, 4),
            ),
        )

        # Scope the flag to this test. Assigning os.environ directly would
        # leave it set for every test that runs later in the same process.
        with mock.patch.dict(os.environ, {"TORCH_ONNX_USE_EXPERIMENTAL_LOGIC": "1"}):
            onnx_program = torch.onnx.dynamo_export(
                Nested(),
                inputs,
                export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
            )
            self.assertIsNotNone(onnx_program)
            onnx_testing.assert_onnx_program(onnx_program)

    def test_refine_dynamic_shapes_with_onnx_export(self):
        # NOTE: From test/export/test_export.py

        # refine lower, upper bound
        class TestRefineDynamicShapeModel(torch.nn.Module):
            def forward(self, x, y):
                # Falls through (returns None) when the guard fails; the
                # sample inputs below (16 and 12 elements) satisfy it.
                if x.shape[0] >= 6 and y.shape[0] <= 16:
                    return x * 2.0, y + 1

        inps = (torch.randn(16), torch.randn(12))
        dynamic_shapes = {
            "x": (torch.export.Dim("dx"),),
            "y": (torch.export.Dim("dy"),),
        }
        self.assert_export(
            TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
        )

    def test_zero_output_aten_node(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed")
                return x + x

        # Pass a real 1-tuple: ``(x)`` is just ``x``, not a tuple of arguments.
        self.assert_export(Model(), (torch.randn(2),))
207
208
if __name__ == "__main__":
    # Executed directly (not imported): hand control to PyTorch's shared
    # test runner from common_utils.
    common_utils.run_tests()
211