# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

    To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

    Read more on Running and writing tests:
        https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
    TESTED_OPS lists. See the "Modify this section" marker below.

"""

from __future__ import annotations

import copy
from typing import Optional, Tuple

import onnx_test_common
import parameterized

# For readability, these two are allowed to be imported as functions
from onnx_test_common import skip, xfail

import torch
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1.  Add "ceil" to TESTED_OPS, then run pytest.
# 2.  If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS
#     (see the illustrative sketch below).
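#
# For illustration only, a hypothetical new entry would follow the same shape
# as the real entries in EXPECTED_SKIPS_OR_FAILS below:
#
#     xfail(
#         "ceil",
#         dtypes=onnx_test_common.BOOL_TYPES,
#         reason=onnx_test_common.reason_onnx_does_not_support("Ceil"),
#     ),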

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between ONNX and PyTorch
# TODO: https://github.com/pytorch/pytorch/issues/102211
TESTED_OPS: frozenset[str] = frozenset(
    [
        "atan",
        "atan2",
        # "atleast_1d",  # How to support list input?
        # "atleast_2d",
        # "atleast_3d",
        "broadcast_to",
        "ceil",
        "expand",
        "flatten",
        "hstack",
        "logical_not",
        # "logit",
        "nn.functional.scaled_dot_product_attention",
        "repeat",
        "round",
        # "scatter_add",
        # "scatter_reduce",
        "sqrt",
        "stft",
        "t",
        "tile",
        "unflatten",
        "vstack",
    ]
)

# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for ONNX export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
#     1. If a test now fails because it xpasses (i.e. previous errors have
#        been fixed), remove the corresponding xfail.
#     2. If a test does not fail consistently, use skip.
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
    skip(
        "atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail(
        "atan2", dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])
    ),
    xfail(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
    ),
    skip("hstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
    xfail(
        "logit",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Log", "bool, int"),
    ),
    skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
    skip("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
    xfail("round", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_0", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_neg_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail("scatter_reduce", variant_name="mean",
          reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
    skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="mean",
        reason="ONNX doesn't support reduce='mean' option",
    ),
    skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
    skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
    xfail("stft",
          reason=onnx_test_common.reason_onnx_runtime_does_not_support("STFT", "Regression on ORT=1.15 4 percent difference")),
    skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
    xfail("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
    skip("vstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
)
# fmt: on

SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "nn.functional.scaled_dot_product_attention",
        matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
        reason="dropout is random so the results do not match",
    ),
    skip(
        "repeat",
        reason="Empty repeats value leads to an invalid graph",
        matcher=lambda sample: not sample.args[0],
    ),
    skip(
        "scatter_reduce",
        # ONNX does not have an include_self parameter; its default behavior matches include_self=True
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support include_self=False option",
    ),
    skip(
        "stft",
        reason="ONNX STFT does not support complex results",
        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
    ),
    skip(
        "tile",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
        or not sample.input.shape,
        reason="Logic not implemented for size 0 inputs in op.Reshape",
    ),
    skip(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
)


# END OF SECTION TO MODIFY #####################################################

OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


class SingleOpModel(torch.nn.Module):
    """Test model to wrap around a single op for export."""

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        return self.operator(*args, **self.kwargs)
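
# Illustrative usage only (not executed by the test suite): wrapping a single
# torch op so it can be exported as a standalone model. torch.ceil and the
# output filename here are arbitrary examples.
#
#     model = SingleOpModel(torch.ceil, {})
#     torch.onnx.export(model, (torch.randn(2, 3),), "ceil.onnx")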


def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Returns the test behavior ("skip" or "xfail") and reason if the sample should be skipped or xfailed."""
    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None
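
# For example (purely illustrative): a "stft" sample created with
# return_complex=True matches the "stft" matcher in SKIP_XFAIL_SUBTESTS above,
# so this helper returns that entry's test behavior (skip) together with the
# reason "ONNX STFT does not support complex results".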


def _get_test_class_name(cls, num, params_dict) -> str:
    del cls  # unused
    del num  # unused
    return params_dict["name"]


@parameterized.parameterized_class(
    [
        {
            "name": f"TestOnnxModelOutputConsistency_opset{opset}",
            "opset_version": opset,
        }
        for opset in onnx_test_common.TESTED_OPSETS
    ],
    class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
    """

    opset_version = -1

    @common_device_type.ops(
        [op for op in OPS_DB if op.name in TESTED_OPS],
        allowed_dtypes=onnx_test_common.INT_TYPES
        + onnx_test_common.FLOAT_TYPES
        + onnx_test_common.BOOL_TYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter."""
        # device is provided by instantiate_device_type_tests, but we only want to run on CPU.
        assert device == "cpu"

        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=False,
        )

        for i, cpu_sample in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
            # Provide the repr to subtest because tensors are not serializable in parallel test runs
            with self.subTest(
                opset=self.opset_version,
                sample_num=i,
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
                test_behavior, reason = _should_skip_xfail_test_sample(
                    op.name, cpu_sample
                )
                with onnx_test_common.normal_xfail_skip_test_behaviors(
                    test_behavior, reason
                ):
                    model = SingleOpModel(op, cpu_sample.kwargs)
                    model.eval()

                    if dtype in (torch.float32, torch.float64):
                        # Relax atol and rtol based on empirical results.
                        # The current most relaxed values are for aten::stft.
                        rtol = 1e-5
                        atol = 2e-5
                    else:
                        rtol = None
                        atol = None
                    # Run the test
                    self.run_test(model, inputs, rtol=rtol, atol=atol)


for opset in onnx_test_common.TESTED_OPSETS:
    # The name needs to match the parameterized_class name.
    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
    onnx_test_common.add_decorate_info(
        OPS_DB,
        test_class_name,
        "test_output_match",
        opset=opset,
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
    common_device_type.instantiate_device_type_tests(
        globals()[test_class_name], globals(), only_for="cpu"
    )


if __name__ == "__main__":
    common_utils.run_tests()