1# Owner(s): ["module: onnx"] 2"""Simple API tests for the ONNX exporter.""" 3 4from __future__ import annotations 5 6import os 7 8import torch 9from torch.onnx._internal.exporter import testing as onnx_testing 10from torch.testing._internal import common_utils 11 12 13class SampleModel(torch.nn.Module): 14 def forward(self, x): 15 y = x + 1 16 z = y.relu() 17 return (y, z) 18 19 20class SampleModelTwoInputs(torch.nn.Module): 21 def forward(self, x, b): 22 y = x + b 23 z = y.relu() 24 return (y, z) 25 26 27class SampleModelForDynamicShapes(torch.nn.Module): 28 def forward(self, x, b): 29 return x.relu(), b.sigmoid() 30 31 32class TestExportAPIDynamo(common_utils.TestCase): 33 """Tests for the ONNX exporter API when dynamo=True.""" 34 35 def assert_export(self, *args, **kwargs): 36 onnx_program = torch.onnx.export(*args, **kwargs, dynamo=True) 37 assert onnx_program is not None 38 onnx_testing.assert_onnx_program(onnx_program) 39 40 def test_args_normalization_with_no_kwargs(self): 41 self.assert_export( 42 SampleModelTwoInputs(), 43 (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), 44 ) 45 46 def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): 47 self.assert_export( 48 SampleModelForDynamicShapes(), 49 (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), 50 dynamic_axes={ 51 "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, 52 "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, 53 }, 54 ) 55 56 def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): 57 self.assert_export( 58 SampleModelForDynamicShapes(), 59 (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), 60 dynamic_axes={ 61 "x": [0, 1, 2], 62 "b": [0, 1, 2], 63 }, 64 ) 65 66 def test_dynamic_axes_supports_partial_dynamic_shapes(self): 67 self.assert_export( 68 SampleModelForDynamicShapes(), 69 (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), 70 dynamic_axes={ 71 "b": [0, 1, 2], 72 }, 73 ) 74 75 def test_dynamic_axes_supports_output_names(self): 76 self.assert_export( 77 SampleModelForDynamicShapes(), 78 (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), 79 dynamic_axes={ 80 "b": [0, 1, 2], 81 }, 82 ) 83 onnx_program = torch.onnx.export( 84 SampleModelForDynamicShapes(), 85 ( 86 torch.randn(2, 2, 3), 87 torch.randn(2, 2, 3), 88 ), 89 input_names=["x", "b"], 90 output_names=["x_out", "b_out"], 91 dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]}, 92 dynamo=True, 93 ) 94 assert onnx_program is not None 95 onnx_testing.assert_onnx_program(onnx_program) 96 97 def test_saved_f_exists_after_export(self): 98 with common_utils.TemporaryFileName(suffix=".onnx") as path: 99 _ = torch.onnx.export( 100 SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True 101 ) 102 self.assertTrue(os.path.exists(path)) 103 104 def test_export_supports_script_module(self): 105 class ScriptModule(torch.nn.Module): 106 def forward(self, x): 107 return x 108 109 self.assert_export(torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),)) 110 111 def test_dynamic_shapes_with_fully_specified_axes(self): 112 exported_program = torch.export.export( 113 SampleModelForDynamicShapes(), 114 ( 115 torch.randn(2, 2, 3), 116 torch.randn(2, 2, 3), 117 ), 118 dynamic_shapes={ 119 "x": { 120 0: torch.export.Dim("customx_dim_0"), 121 1: torch.export.Dim("customx_dim_1"), 122 2: torch.export.Dim("customx_dim_2"), 123 }, 124 "b": { 125 0: torch.export.Dim("customb_dim_0"), 126 1: torch.export.Dim("customb_dim_1"), 127 2: torch.export.Dim("customb_dim_2"), 128 }, 129 }, 130 ) 131 132 self.assert_export(exported_program) 133 134 def test_partial_dynamic_shapes(self): 135 self.assert_export( 136 SampleModelForDynamicShapes(), 137 ( 138 torch.randn(2, 2, 3), 139 torch.randn(2, 2, 3), 140 ), 141 dynamic_shapes={ 142 "x": None, 143 "b": { 144 0: torch.export.Dim("customb_dim_0"), 145 1: torch.export.Dim("customb_dim_1"), 146 2: torch.export.Dim("customb_dim_2"), 147 }, 148 }, 149 ) 150 151 def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self): 152 os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1" 153 assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1" 154 155 class Nested(torch.nn.Module): 156 def forward(self, x): 157 (a0, a1), (b0, b1), (c0, c1, c2) = x 158 return a0 + a1 + b0 + b1 + c0 + c1 + c2 159 160 inputs = ( 161 (1, 2), 162 ( 163 torch.randn(4, 4), 164 torch.randn(4, 4), 165 ), 166 ( 167 torch.randn(4, 4), 168 torch.randn(4, 4), 169 torch.randn(4, 4), 170 ), 171 ) 172 173 onnx_program = torch.onnx.dynamo_export( 174 Nested(), 175 inputs, 176 export_options=torch.onnx.ExportOptions(dynamic_shapes=True), 177 ) 178 assert onnx_program is not None 179 onnx_testing.assert_onnx_program(onnx_program) 180 181 def test_refine_dynamic_shapes_with_onnx_export(self): 182 # NOTE: From test/export/test_export.py 183 184 # refine lower, upper bound 185 class TestRefineDynamicShapeModel(torch.nn.Module): 186 def forward(self, x, y): 187 if x.shape[0] >= 6 and y.shape[0] <= 16: 188 return x * 2.0, y + 1 189 190 inps = (torch.randn(16), torch.randn(12)) 191 dynamic_shapes = { 192 "x": (torch.export.Dim("dx"),), 193 "y": (torch.export.Dim("dy"),), 194 } 195 self.assert_export( 196 TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes 197 ) 198 199 def test_zero_output_aten_node(self): 200 class Model(torch.nn.Module): 201 def forward(self, x): 202 torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed") 203 return x + x 204 205 input = torch.randn(2) 206 self.assert_export(Model(), (input)) 207 208 209if __name__ == "__main__": 210 common_utils.run_tests() 211