# Owner(s): ["module: onnx"]

"""
Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
          --no-onnx: no onnx python dependency
          --produce-onnx-test-data: generate onnx test data
          --accept: accept onnx updates and overwrite models
"""

import glob
import inspect
import io
import itertools
import operator
import os
import shutil
import tempfile

# Full diff for expect files
import unittest

from pytorch_test_common import (
    BATCH_SIZE,
    flatten,
    RNN_HIDDEN_SIZE,
    RNN_INPUT_SIZE,
    RNN_SEQUENCE_LENGTH,
)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
from torch.autograd import Function, Variable
from torch.nn import functional, Module
from torch.onnx._internal import diagnostics
from torch.onnx.symbolic_helper import (
    _get_tensor_dim_size,
    _get_tensor_sizes,
    parse_args,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


unittest.TestCase.maxDiff = None

_onnx_test = False  # flag to produce onnx test cases.
_onnx_dep = True  # flag to import onnx package.


def export_to_pbtxt(model, inputs, *args, **kwargs):
    return torch.onnx.export_to_pretty_string(
        model, inputs, *args, google_printer=True, **kwargs
    )


def export_to_pb(model, inputs, *args, **kwargs):
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()


class FuncModule(Module):
    def __init__(self, f, params=None):
        if params is None:
            params = ()
        super().__init__()
        self.f = f
        self.params = nn.ParameterList(list(params))

    def forward(self, *args):
        return self.f(*itertools.chain(args, self.params))

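# A minimal usage sketch of FuncModule with the export helpers above (the
# tensors are illustrative; nothing here runs at import time):
#
#   m = FuncModule(torch.add)
#   m.eval()
#   pbtxt = export_to_pbtxt(m, (torch.randn(2, 3), torch.randn(2, 3)))
#   pb = export_to_pb(m, (torch.randn(2, 3), torch.randn(2, 3)))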

class TestOperators(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        diagnostics.engine.clear()

    def assertONNX(self, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        m.eval()
        # Pop subname before exporting so it is not forwarded to the exporter
        # as an unsupported export kwarg.
        subname = kwargs.pop("subname", None)
        onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
        self.assertExpected(onnx_model_pbtxt, subname)
        if _onnx_dep:
            onnx_model_pb = export_to_pb(m, args, **kwargs)
            import onnx
            import onnx.checker
            import onnx.numpy_helper
            import onnx_test_common

            model_def = onnx.ModelProto.FromString(onnx_model_pb)
            onnx.checker.check_model(model_def)
            if _onnx_test:
                test_function = inspect.stack()[1][0].f_code.co_name
                test_name = test_function[0:4] + "_operator" + test_function[4:]
                output_dir = os.path.join(
                    onnx_test_common.pytorch_operator_dir, test_name
                )
                # Assumptions:
                #     1) the old test data has been deleted before the test runs.
                #     2) each test calls assertONNX only once; otherwise the
                #        data would be overwritten.
                assert not os.path.exists(output_dir), f"{output_dir} should not exist!"
                os.makedirs(output_dir)
                with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
                    file.write(model_def.SerializeToString())
                data_dir = os.path.join(output_dir, "test_data_set_0")
                os.makedirs(data_dir)
                if isinstance(args, Variable):
                    args = (args,)
                for index, var in enumerate(flatten(args)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(
                        os.path.join(data_dir, f"input_{index}.pb"), "wb"
                    ) as file:
                        file.write(tensor.SerializeToString())
                outputs = m(*args)
                if isinstance(outputs, Variable):
                    outputs = (outputs,)
                for index, var in enumerate(flatten(outputs)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(
                        os.path.join(data_dir, f"output_{index}.pb"), "wb"
                    ) as file:
                        file.write(tensor.SerializeToString())

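    # assertONNX above compares the pretty-printed graph against a checked-in
    # expect file (run with --accept to rewrite those files, per the module
    # docstring); with --produce-onnx-test-data it also serializes model.onnx
    # plus input_*.pb/output_*.pb pairs under pytorch_operator_dir.
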
    def assertONNXRaises(self, err, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        self.assertExpectedRaises(err, lambda: export_to_pbtxt(m, args, **kwargs))

    def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        with self.assertRaisesRegex(err, reg):
            export_to_pbtxt(m, args, **kwargs)

    def test_basic(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))

    def test_view(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.view(1, 1), x)

    def test_index(self):
        x = torch.tensor([[0.0]], requires_grad=True)
        self.assertONNX(lambda x: x[0], x)

    def test_type_as(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.type_as(x), x)

    def test_addconstant(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: x + 1, x)

    def test_add_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_left_broadcast(self):
        x = torch.randn(3, requires_grad=True).double()
        y = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(2, 1, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_right_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_singleton_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(1, 3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_rsub(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: 1 - x, (x,))

    def test_mul_bool(self):
        x = torch.tensor([True, False, True, False])
        y = torch.tensor([True, True, False, False])
        self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))

    def test_mul_fp_bool(self):
        x = torch.tensor([9.4, 1.7, 3.6])
        y = torch.tensor([True, True, False])
        self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))

    def test_transpose(self):
        x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
        self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x)

    def test_chunk(self):
        x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
        self.assertONNX(lambda x: x.chunk(2), x)

    def test_split(self):
        x = torch.tensor(
            [[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
        )
        self.assertONNX(lambda x: torch.split(x, 2, 1), x)

    def test_split_with_sizes(self):
        x = torch.tensor(
            [[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
        )
        self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)

    def test_concat2(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),))

    def test_mm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(torch.mm, (m1, m2))

    def test_addmm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        m3 = torch.randn(4, requires_grad=True)
        self.assertONNX(
            lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3)
        )

    def test_permute2(self):
        x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
        self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x)

    def test_pad(self):
        x = torch.tensor(
            [[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True
        )
        self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)

    def test_params(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
            x,
            params=(y,),
            keep_initializers_as_inputs=True,
        )

    def test_params_onnx_irv4(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
            x,
            params=(y,),
            keep_initializers_as_inputs=False,
        )

    def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                raise AssertionError

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        # NB: Don't use expect test here, the type error wobbles depending
        # on Python version
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_pbtxt(FuncModule(MyFun().apply), (x, y))

    # TODO: Do an nn style test for these
    def test_batchnorm(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x, keep_initializers_as_inputs=True)

    def test_batchnorm_onnx_irv4(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x)

    def test_batchnorm_1d(self):
        x = torch.ones(2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm1d(2), x, keep_initializers_as_inputs=True)

    def test_batchnorm_training(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(
            nn.BatchNorm2d(2),
            x,
            training=torch.onnx.TrainingMode.TRAINING,
            keep_initializers_as_inputs=True,
        )

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(
            nn.Conv2d(16, 13, 3, bias=False), x, keep_initializers_as_inputs=True
        )

    def test_conv_onnx_irv4(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)

    def test_conv_onnx_irv4_opset8(self):
        # This test checks that for opset 8 (or lower), keep_initializers_as_inputs=False
        # is ignored and initializers are still listed as ONNX graph inputs, in
        # accordance with ONNX IR v3 semantics (which apply to opset versions <= 8).
        x = torch.ones(1, 2, 5, 7, requires_grad=True)
        conv_node = nn.Conv2d(2, 4, 3, bias=False)
        conv_node.weight.data.fill_(1.0)
        self.assertONNX(
            conv_node, x, opset_version=8, keep_initializers_as_inputs=False
        )

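    # A hedged sketch of how the opset-8 behavior above could be double-checked
    # by hand (not executed as part of the test; conv_node and x as in
    # test_conv_onnx_irv4_opset8, standard onnx protobuf fields):
    #
    #   model = onnx.ModelProto.FromString(
    #       export_to_pb(conv_node, x, opset_version=8)
    #   )
    #   graph_input_names = {i.name for i in model.graph.input}
    #   # Under IR v3 semantics the weight initializer also appears among the
    #   # graph inputs, despite keep_initializers_as_inputs=False.
    #   assert all(init.name in graph_input_names for init in model.graph.initializer)
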
    def test_conv_variable_length(self):
        x = torch.ones(5, 3, 6, 6, requires_grad=True)
        model = torch.nn.Conv2d(3, 2, 3)

        dynamic_axes = {
            "input_1": [0, 2, 3],
            "output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"},
        }
        model_proto_file = tempfile.NamedTemporaryFile()
        torch.onnx.export(
            model,
            x,
            model_proto_file.name,
            verbose=True,
            input_names=["input_1"],
            output_names=["output_1"],
            dynamic_axes=dynamic_axes,
        )

        import onnx

        onnx_model = onnx.load(model_proto_file.name)
        onnx.checker.check_model(onnx_model)

        # Assert that default dynamic axes names are generated when custom names are not provided.
        assert (
            onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
            == "input_1_dynamic_axes_1"
        )
        assert (
            onnx_model.graph.input[0].type.tensor_type.shape.dim[2].dim_param
            == "input_1_dynamic_axes_2"
        )
        assert (
            onnx_model.graph.input[0].type.tensor_type.shape.dim[3].dim_param
            == "input_1_dynamic_axes_3"
        )

        # Assert that the custom names are applied when provided.
        assert (
            onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param
            == "output_1_variable_dim_0"
        )
        assert (
            onnx_model.graph.output[0].type.tensor_type.shape.dim[1].dim_param
            == "output_1_variable_dim_1"
        )

    def test_convtranspose(self):
        x = torch.ones(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(
            nn.ConvTranspose2d(
                3, 3, 3, stride=3, bias=False, padding=1, output_padding=2
            ),
            x,
            keep_initializers_as_inputs=True,
        )

    def test_maxpool(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2), x)

    def test_maxpool_dilations(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(2, stride=1, dilation=2), x, opset_version=10)

    def test_avg_pool2d(self):
        x = torch.randn(20, 16, 50, 32)
        self.assertONNX(nn.AvgPool2d(3, stride=2), x)

    def test_maxpool_indices(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)

    def test_at_op(self):
        x = torch.randn(3, 4)

        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(
            MyModule(),
            x,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_clip(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x)

    def test_clip_min(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(min=-0.1), x)

    def test_clip_max(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(max=0.1), x)

    def test_hardtanh(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)

    def test_full(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full(x.shape, 2.0), x)

    def test_full_like(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full_like(x, 2), x)

    def test_max(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.max(x, y), (x, y))

    def test_min(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.min(x, y), (x, y))

    def test_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x), x)

    def test_reduced_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2), x)

    def test_reduced_mean_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=(2, 3), keepdim=True), x)

    def test_mean_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dtype=torch.double), x)

    def test_reduced_mean_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=0, dtype=torch.double), x)

    def test_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x), x)

    def test_sum_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dtype=torch.double), x)

    def test_reduced_sum_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=0, dtype=torch.double), x)

    def test_reduced_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=(1, 2)), x)

    def test_reduced_sum_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x)

    def test_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x), x)

    def test_reduced_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2), x)

    def test_reduced_prod_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x)

    def test_prod_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dtype=torch.double), x)

    def test_reduced_prod_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=0, dtype=torch.double), x)

    def test_sqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sqrt(x), x)

    def test_rsqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.rsqrt(x), x)

    def test_equal(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.eq, (x, y))

    def test_lt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.lt, (x, y))

    def test_gt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.gt, (x, y))

    def test_le(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(operator.le, (x, y))

    def test_ge(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(operator.ge, (x, y))

    def test_exp(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.exp(), x)

    def test_sin(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sin(), x)

    def test_cos(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.cos(), x)

    def test_tan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.tan(), x)

    def test_asin(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.asin(), x)

    def test_acos(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.acos(), x)

    def test_slice(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[:, 1:2], x)

    def test_slice_dynamic(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[x.size(0) :, x.size(1) - 3], x, opset_version=10)

    def test_sign(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sign(), x)

    def test_narrow(self):
        x = torch.randn(3, 3, requires_grad=True)
        self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)

    def test_atan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.atan(), x)

    def test_view_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)

    def test_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x), x)

    def test_flatten2D(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x, 1), x)

    def test_isnan(self):
        x = torch.tensor([1, float("nan"), 2])
        self.assertONNX(lambda x: torch.isnan(x), x)

    def test_argmax(self):
        x = torch.randn(4, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.argmax(x, dim=1), x)

    def test_logsoftmax(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.LogSoftmax(dim=3), x)

    def test_pow(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: x.pow(y), (x, y))

    def test_elu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.ELU(), x)

    def test_selu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.SELU(), x)

    def test_repeat(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_repeat_dim_overflow(self):
        x = torch.randn(1, 2, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_norm_p1(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=1, dim=2), (x))

    def test_norm_p2(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=2, dim=2), (x))

    def test_upsample_nearest_scale(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(
                x, scale_factor=2.0, mode="nearest", recompute_scale_factor=False
            ),
            x,
        )

    def test_upsample_nearest_scale_default_scale_factor(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(x, scale_factor=2.0, mode="nearest"), x
        )

    def test_upsample_nearest_size(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(x, size=16, mode="nearest"), x
        )

    def test_unsqueeze(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)

    def test_batchnorm_noaffine(self):
        x = torch.randn(128, 128, 1, 1, requires_grad=True)
        self.assertONNX(
            nn.BatchNorm2d(128, affine=False, momentum=0.3),
            x,
            keep_initializers_as_inputs=True,
        )

    def test_embedding_bags(self):
        emb_bag = nn.EmbeddingBag(10, 8)
        input = torch.tensor([1, 2, 3, 4]).long()
        offset = torch.tensor([0]).long()
        self.assertONNX(
            emb_bag,
            (input, offset),
            keep_initializers_as_inputs=True,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_implicit_expand(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x + 1, x)

    def test_reduce_sum_negative_indices(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sum(-1), x)

    def test_randn(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)

    def test_rand(self):
        x = torch.rand(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.rand(1, 2, 3, 4) + x, x)

    def test_rrelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.RReLU(), x)

    def test_prelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.PReLU(2), x, keep_initializers_as_inputs=True)

    def test_log_sigmoid(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.LogSigmoid(), x)

    def test_linear(self):
        x = torch.randn(3, 4)
        self.assertONNX(
            torch.nn.Linear(4, 5, bias=True), x, keep_initializers_as_inputs=True
        )

    def test_empty_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.empty_like(x), x)

    def test_zeros_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.zeros_like(x), x)

    def test_ones_like(self):
        x = torch.randn(6, 10, requires_grad=True)
        self.assertONNX(lambda x: torch.ones_like(x), x)

    def test_expand(self):
        x = torch.randn(6, 1, requires_grad=True)
        self.assertONNX(lambda x: x.expand(4, 6, 2), x)

    def test_ne(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: torch.ne(x, y), (x, y))

    def test_reducemax(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.max(x), x)

    def test_reducemin(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.min(x), x)

    def test_erf(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: x.erf(), x)

    def test_dropout(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)

    def test_dropout_default(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(
                functional.dropout(
                    x,
                )
            ),
            x,
        )

    def test_dropout_training(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x)),
            x,
            training=torch.onnx.TrainingMode.TRAINING,
        )

    def test_dropout_opset12(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x, training=False)),
            x,
            opset_version=12,
        )

    def test_dropout_training_opset12(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x)),
            x,
            opset_version=12,
            training=torch.onnx.TrainingMode.TRAINING,
        )

    def test_nonzero(self):
        x = torch.tensor(
            [[[2.0, 2.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]], requires_grad=True
        )
        self.assertONNX(lambda x: torch.nonzero(x), x)

    def test_gather(self):
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.assertONNX(lambda data, index: data.gather(1, index), (data, index))

    def test_gather_opset11(self):
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.assertONNX(
            lambda data, index: data.gather(1, index), (data, index), opset_version=11
        )

    def test_scatter_add(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
        )

    def test_scatter_add_opset11(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
            opset_version=11,
        )

    def test_scatter_add_opset16(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
            opset_version=16,
        )

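    # Note the duplicate indices in the opset-16 variant above: from opset 16
    # onward ONNX ScatterElements supports a reduction attribute (e.g.
    # reduction="add"), which is presumably what this variant exercises.
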
    def test_master_opset(self):
        x = torch.randn(2, 3).float()
        y = torch.randn(2, 3).float()
        self.assertONNX(operator.add, (x, y), opset_version=10)

    def test_std(self):
        x = torch.randn(2, 3, 4).float()
        self.assertONNX(
            lambda x: torch.std(x, dim=(0, 1), unbiased=True, keepdim=True), x
        )

    def test_cumsum(self):
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)

    def test_dict(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(
                    x_in[list(x_in.keys())[0]],  # noqa: RUF015
                    list(x_in.keys())[0],  # noqa: RUF015
                )
                return x_out

        x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
        self.assertONNX(MyModel(), (x, {}))

    def test_dict_str(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0)
                return x_out

        x = {"test_key_in": torch.randn(1, 2, 3)}
        self.assertONNX(MyModel(), (x, {}))

    def test_arange_dynamic(self):
        class TestModel(torch.nn.Module):
            def forward(self, input):
                return torch.arange(input.shape[0], input.shape[0] + 5, 0.5)

        input = torch.randn(5, 3, 2)
        self.assertONNX(TestModel(), input, opset_version=11)

    def test_bitshift(self):
        class BitshiftModel(torch.nn.Module):
            def forward(self, input):
                return input >> 1, input >> 2

        input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        self.assertONNX(BitshiftModel(), input, opset_version=11)

    def test_bitwise_and(self):
        class BitwiseAndModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.bitwise_and(input, other), input & 2

        input = torch.randint(0, 100, (2, 3, 4), dtype=torch.uint8)
        other = torch.randint(-50, 50, (2, 3, 4), dtype=torch.int8)
        self.assertONNX(BitwiseAndModel(), (input, other), opset_version=18)

    def test_layer_norm_aten(self):
        model = torch.nn.LayerNorm([10, 10])
        x = torch.randn(20, 5, 10, 10)
        self.assertONNX(
            model,
            x,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_pixel_shuffle(self):
        x = torch.randn(2, 8, 3, 4).float()
        self.assertONNX(
            lambda x: torch.pixel_shuffle(x, upscale_factor=2), x, opset_version=11
        )

    def test_frobenius_norm(self):
        x = torch.randn(2, 3, 4).float()
        self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x)

    def test_unfold(self):
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unfold(dimension=2, size=2, step=2), x)

    def test_remainder(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.assertONNX(lambda x, y: torch.remainder(x, y), (x, y))

    def test_fmod(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.assertONNX(lambda x, y: torch.fmod(x, y), (x, y), opset_version=10)

    def test_gelu(self):
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.functional.gelu(x), x)

    def test_unique(self):
        x = torch.randint(3, (2, 3, 4, 5)).float()
        self.assertONNX(
            lambda x: torch.unique(
                x, dim=0, sorted=True, return_inverse=False, return_counts=True
            ),
            x,
            opset_version=11,
        )

    def test_meshgrid(self):
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z))

    def test_meshgrid_indexing(self):
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.assertONNX(
            lambda x, y, z: torch.meshgrid(x, y, z, indexing="xy"),
            (x, y, z),
            opset_version=9,
        )

    def test_topk(self):
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.assertONNX(lambda x, k: torch.topk(x, k), (x, k), opset_version=10)

    def test_topk_smallest_unsorted(self):
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.assertONNX(
            lambda x, k: torch.topk(x, k, largest=False, sorted=False),
            (x, k),
            opset_version=11,
        )

    def test_baddbmm(self):
        x = torch.randn(10, 3, 5)
        b1 = torch.randn(10, 3, 4)
        b2 = torch.randn(10, 4, 5)
        self.assertONNX(lambda x, b1, b2: torch.baddbmm(x, b1, b2), (x, b1, b2))

    def test_round(self):
        x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
        self.assertONNX(lambda x: torch.round(x), x, opset_version=11)

    def test_dim(self):
        x = torch.ones((2, 2), requires_grad=True)
        self.assertONNX(lambda x: torch.scalar_tensor(x.dim()), x)

    @skipIfNoLapack
    def test_det(self):
        x = torch.randn(2, 3, 5, 5, device=torch.device("cpu"))
        self.assertONNX(lambda x: torch.det(x), x, opset_version=11)
        self.assertONNX(lambda x: torch.linalg.det(x), x, opset_version=11)

    def test_softmaxcrossentropy(self):
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)

    def test_softmaxcrossentropy_ignore_index(self):
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        self.assertONNX(
            torch.nn.CrossEntropyLoss(ignore_index=1), (x, y), opset_version=12
        )

    def test_softmaxcrossentropy_weights(self):
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        self.assertONNX(
            torch.nn.CrossEntropyLoss(weight=torch.randn(5)), (x, y), opset_version=12
        )

    def test_softmaxcrossentropy_3d(self):
        x = torch.randn(3, 5, 2)
        y = torch.empty(3, 2, dtype=torch.long).random_(5)
        self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)

    def test_softmaxcrossentropy_3d_none(self):
        x = torch.randn(3, 5, 2)
        y = torch.empty(3, 2, dtype=torch.long).random_(5)
        self.assertONNX(
            torch.nn.CrossEntropyLoss(reduction="none"), (x, y), opset_version=12
        )

    def test_softmaxcrossentropy_4d(self):
        x = torch.randn(3, 5, 2, 1)
        y = torch.empty(3, 2, 1, dtype=torch.long).random_(5)
        self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)

    def test_lstm_none_sequence_lens(self):
        """Test symbolic shape inference for LSTM when the input sequence_lens = None."""
        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)

        class LSTMModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.rnn = torch.nn.LSTM(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
                )

            def forward(self, x, h0, c0):
                a, b = self.rnn(x, (h0, c0))
                return torch.ones(b[0].shape)

        self.assertONNX(
            LSTMModel(),
            (input, h0, c0),
            input_names=["x", "y"],
            dynamic_axes={"x": {0: "batch"}},
            opset_version=12,
        )

    def test_dynamic_axes_add(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(2, 1, requires_grad=True)
        self.assertONNX(
            lambda x, y: torch.add(x, y),
            (m1, m2),
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}},
            opset_version=12,
        )

    def test_dynamic_axes_add_inputs_same_symbolic_shape(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        self.assertONNX(
            lambda x: torch.add(x, x),
            (m1,),
            input_names=["input_1"],
            dynamic_axes={"input_1": {1: "dim_1"}},
            opset_version=12,
        )

    def test_dynamic_axes_matmul(self):
        m1 = torch.randn(2, 2, 4, requires_grad=True)
        m2 = torch.randn(2, 4, 3, requires_grad=True)
        self.assertONNX(
            lambda x, y: torch.matmul(x, y),
            (m1, m2),
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}},
            opset_version=12,
        )

    def test_dynamic_axes_reduce_mean(self):
        m1 = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.mean(x, dim=1),
            (m1),
            input_names=["input"],
            dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}},
            opset_version=12,
        )

    def test_dynamic_axes_unchange(self):
        """Test ProcessUnchangeNode in symbolic shape inference."""
        m1 = torch.randn(2, 3, requires_grad=True)
        self.assertONNX(
            lambda x: torch.softmax(x, dim=0),
            (m1,),
            input_names=["input"],
            dynamic_axes={"input": {1: "dim_1"}},
            opset_version=12,
        )

    def test_aten_embedding_1(self):
        _onnx_opset_version = 12

        @parse_args("v", "v", "i", "b", "b")
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
            custom_attributes_json = (
                "{"
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                "}"
            )
            output = g.at(
                "embedding",
                weight,
                indices,
                custom_attributes_json_s=custom_attributes_json,
            )
            return output

        torch.onnx.register_custom_op_symbolic(
            "::embedding", embedding, _onnx_opset_version
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)

        torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)

    # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
    def test_aten_embedding_2(self):
        _onnx_opset_version = 12

        @parse_args("v", "v", "i", "b", "b")
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
            custom_attributes_json = (
                "{"
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                "}"
            )
            output = g.at(
                "embedding",
                weight,
                indices,
                custom_attributes_json_s=custom_attributes_json,
            )

            # do shape inference and set it via setType
            indices_shape = _get_tensor_sizes(indices)
            if indices_shape is not None and hasattr(weight.type(), "with_sizes"):
                output_type = weight.type().with_sizes(
                    indices_shape + [_get_tensor_dim_size(weight, 1)]
                )
                output.setType(output_type)
            return output

        torch.onnx.register_custom_op_symbolic(
            "::embedding", embedding, _onnx_opset_version
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(
            model,
            (x, y),
            opset_version=_onnx_opset_version,
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": {0: "dim_0"}, "input_2": {0: "dim_1", 1: "dim_2"}},
            keep_initializers_as_inputs=False,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

        torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
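
    # The only difference from test_aten_embedding_1 is the setType call in the
    # symbolic above: it hands the exporter a concrete output type (the indices
    # shape plus the embedding dimension), so shape inference can continue
    # through the custom aten::embedding node.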

    # Without shapeValueMap, the onnx graph looks like:
    # graph(%0 : Float(*, 1, 128, 1, strides=[128, 128, 1, 1], requires_grad=0, device=cpu)):
    #   %2 : Long(4, strides=[1], device=cpu) = onnx::Shape(%0)
    #   %4 : Long(device=cpu) = onnx::Constant[value={0}]()
    #   %5 : Long(device=cpu) = onnx::Gather[axis=0](%2, %4)
    #   %6 : Long(device=cpu) = onnx::Constant[value={1}]()
    #   %7 : Long(device=cpu) = onnx::Constant[value={2}]()
    #   %8 : Long(device=cpu) = onnx::Constant[value={-1}]()
    #   %9 : int[] = prim::ListConstruct(%5, %6, %7, %8)
    #   %10 : Float(*, *, *, *, strides=[128, 128, 64, 1], requires_grad=0, device=cpu) = onnx::Reshape(%0, %9)
    #   ...
    # With shapeValueMap, it becomes:
    #   ...
    #   %10 : Float(*, 1, 2, 64, strides=[128, 128, 64, 1], requires_grad=0, device=cpu) = onnx::Reshape(%0, %9)
    #   ...
    def test_shape_value_map(self):
        class RSoftMax(torch.nn.Module):
            def __init__(self, radix, cardinality):
                super().__init__()
                self.radix = radix
                self.cardinality = cardinality

            def forward(self, x):
                batch = x.size(0)
                x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
                x = F.softmax(x, dim=1)
                x = x.reshape(batch, -1)
                return x

        radix = 2
        cardinality = 1
        x = torch.randn(10, 1, 128, 1)
        self.assertONNX(
            RSoftMax(radix, cardinality),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": {0: "dim_0"}},
        )


if __name__ == "__main__":
    no_onnx_dep_flag = "--no-onnx"
    _onnx_dep = no_onnx_dep_flag not in common_utils.UNITTEST_ARGS
    if no_onnx_dep_flag in common_utils.UNITTEST_ARGS:
        common_utils.UNITTEST_ARGS.remove(no_onnx_dep_flag)
    onnx_test_flag = "--produce-onnx-test-data"
    _onnx_test = onnx_test_flag in common_utils.UNITTEST_ARGS
    if onnx_test_flag in common_utils.UNITTEST_ARGS:
        common_utils.UNITTEST_ARGS.remove(onnx_test_flag)
    if _onnx_test:
        _onnx_dep = True
        import onnx_test_common

        for d in glob.glob(
            os.path.join(onnx_test_common.pytorch_operator_dir, "test_operator_*")
        ):
            shutil.rmtree(d)
    common_utils.run_tests()
1294