xref: /aosp_15_r20/external/pytorch/test/onnx/onnx_test_common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3from __future__ import annotations
4
5import contextlib
6import copy
7import dataclasses
8import io
9import logging
10import os
11import unittest
12import warnings
13from typing import (
14    Any,
15    Callable,
16    Collection,
17    Iterable,
18    List,
19    Mapping,
20    Optional,
21    Sequence,
22    Tuple,
23    Type,
24    Union,
25)
26
27import numpy as np
28import onnxruntime
29import pytest
30import pytorch_test_common
31
32import torch
33from torch import export as torch_export
34from torch.onnx import _constants, verification
35from torch.testing._internal import common_utils
36from torch.testing._internal.opinfo import core as opinfo_core
37from torch.types import Number
38
39
40_NumericType = Union[Number, torch.Tensor, np.ndarray]
41_ModelType = Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
42_InputArgsType = Optional[
43    Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]]
44]
45_OutputsType = Sequence[_NumericType]
46
47onnx_model_dir = os.path.join(
48    os.path.dirname(os.path.realpath(__file__)),
49    os.pardir,
50    "repos",
51    "onnx",
52    "onnx",
53    "backend",
54    "test",
55    "data",
56)
57
58
59pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
60
61
62pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")
63
64
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    """Export and verify a model using the configuration of `test_suite`.

    The suite's `opset_version` and `keep_initializers_as_inputs` are injected as
    export kwargs. Its optional `check_shape` / `check_dtype` attributes seed a
    `verification.VerificationOptions`. Any kwarg whose name is a field of
    `VerificationOptions` is then moved out of `kwargs` onto the options object,
    overriding the suite values. Everything else is forwarded to
    `verification.verify`.

    Returns:
        Whatever `verification.verify` returns.
    """
    options = verification.VerificationOptions()

    kwargs["opset_version"] = test_suite.opset_version
    kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
    for attr_name in ("check_shape", "check_dtype"):
        if hasattr(test_suite, attr_name):
            setattr(options, attr_name, getattr(test_suite, attr_name))

    option_names = {field.name for field in dataclasses.fields(options)}
    # Snapshot matching keys first, because `kwargs` is mutated inside the loop.
    for key in [name for name in kwargs if name in option_names]:
        setattr(options, key, kwargs.pop(key))

    return verification.verify(*args, options=options, **kwargs)
85
86
def assert_dynamic_shapes(
    onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: Optional[bool]
):
    """Assert whether the exported model has dynamic shapes or not.

    A graph-input dimension counts as dynamic when it has a non-empty symbolic
    ``dim_param`` and a ``dim_value`` of 0.

    Args:
        onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export.
        dynamic_shapes (Optional[bool]): Whether the exported model is expected to
            have dynamic shapes.
            When True, raises if graph inputs don't have at least one dynamic dimension.
            When False, raises if graph inputs have at least one dynamic dimension.
            When None, the check is skipped.

    Raises:
        AssertionError: If the exported model has dynamic shapes and dynamic_shapes
            is False, and vice-versa.
    """
    if dynamic_shapes is None:
        return

    model_proto = onnx_program.model_proto
    num_dynamic_dims = sum(
        1
        for inp in model_proto.graph.input
        for dim in inp.type.tensor_type.shape.dim
        if dim.dim_value == 0 and dim.dim_param != ""
    )
    # Raise explicitly rather than using `assert`, so the check survives `python -O`.
    if dynamic_shapes != (num_dynamic_dims > 0):
        raise AssertionError(
            "Dynamic shape check failed for graph inputs: "
            f"expected dynamic_shapes={dynamic_shapes}, "
            f"found {num_dynamic_dims} dynamic dimension(s)"
        )
115
116
def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Build the name of one parameterized variant of a test class.

    Intended as the `class_name_func` argument of
    `parameterized.parameterized_class`. The name is the base class name followed
    by one `key_value` fragment per parameter, all joined by underscores.
    `idx` is accepted for that callback signature but not used.
    """
    fragments = [f"{key}_{value}" for key, value in input_dicts.items()]
    return cls.__name__ + "_" + "_".join(fragments)
125
126
class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
    """Base test case that exports PyTorch models to ONNX and checks them with ONNX Runtime.

    `run_test` drives the TorchScript-based exporter, configured by the class
    attributes below. `run_test_with_fx_to_onnx_exporter_and_onnx_runtime` drives
    the FX exporter. It also reads `self.model_type` and `self.dynamic_shapes`,
    which are not defined in this class. They are presumably provided by
    `pytorch_test_common.ExportTestCase` or a parameterized subclass (verify).
    """

    opset_version = _constants.ONNX_DEFAULT_OPSET
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    # When True, `run_test` exports through torch.jit.script instead of tracing.
    is_script = False
    # Copied onto verification.VerificationOptions by `run_model_test`.
    check_shape = True
    check_dtype = True

    def setUp(self):
        """Seed ONNX Runtime (and CUDA when available) and re-enable scripted runs."""
        super().setUp()
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        # NOTE(review): set process-wide and never restored; its consumer is not
        # visible in this file.
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        self.is_script_test_enabled = True

    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        """Export `model` with the TorchScript exporter and verify it with ONNX Runtime.

        When scripting is enabled (`self.is_script` and `self.is_script_test_enabled`),
        the model is scripted unless it already is, then verified with
        `flatten=False, ignore_none=False`. When `self.is_script` is False and the
        model is not already scripted, it is traced and verified. An
        already-scripted model with scripting disabled is not exercised at all.

        The exported ONNX model may have fewer inputs than the PyTorch model because
        of constant folding. This mostly happens in unit tests that use torch.size or
        tensor.shape, so the output depends only on input shapes, not values.
        `remained_onnx_input_idx` lists which PyTorch input indices remain in the
        ONNX model. It may also be a dict with separate "scripting" and "tracing"
        entries.
        """

        def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                verbose=verbose,
            )

        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            _run_test(
                script_model,
                scripting_remained_onnx_input_idx,
                flatten=False,
                ignore_none=False,
            )
        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)

    def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
        self,
        model: _ModelType,
        input_args: Sequence[_InputArgsType],
        *,
        input_kwargs: Optional[Mapping[str, _InputArgsType]] = None,
        rtol: Optional[float] = 1e-3,
        atol: Optional[float] = 1e-7,
        has_mutation: bool = False,
        additional_test_inputs: Optional[
            List[
                Union[
                    Tuple[Sequence[_InputArgsType], Mapping[str, _InputArgsType]],
                    Tuple[Sequence[_InputArgsType]],
                ]
            ]
        ] = None,
        skip_dynamic_shapes_check: bool = False,
    ):
        """Compare the results of PyTorch model with exported ONNX model

        Args:
            model (_ModelType): PyTorch model
            input_args (Sequence[_InputArgsType]): torch input arguments
            input_kwargs (Mapping[str, _InputArgsType]): torch input kwargs
            rtol (float, optional): relative tolerance. Defaults to 1e-3.
            atol (float, optional): absolute tolerance. Defaults to 1e-7.
            has_mutation (bool, optional): Whether the model mutates its input or state.
                `mutation` as `True` incurs extra overhead of cloning the inputs and model.
                Defaults to False.
            additional_test_inputs: Test the models with another dataset input, which
                is designed for dynamic axes testing. Defaults to None. It's a list of
                different input sets in tuples. Inside tuple, the first element is a tuple
                of args, and the second element is a dict of kwargs. Remember to put comma
                even if the following element is not provided.
                For example,
                additional_test_inputs = [((args1, args2), {"kwargs":1}), ((args1,),), ((), {"kwargs":1})]
                Only used when `self.dynamic_shapes` is truthy.
            skip_dynamic_shapes_check: Whether to skip dynamic shape check. Defaults to False.
                Must be used when tests do not produce dynamic shapes even when dynamic shape feature is enabled.
                This is needed because Torch Dynamo uses the dynamic_shapes flag as a hint, only.

        Raises:
            ValueError: If an entry of `additional_test_inputs` is empty or has more
                than two elements.
        """
        from torch._dynamo import config as _dynamo_config

        # avoid mutable data structure
        if input_kwargs is None:
            input_kwargs = {}

        if (
            has_mutation
            and self.model_type
            != pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
        ):
            ref_model = _try_clone_model(model)
            ref_input_args, ref_input_kwargs = _try_clone_inputs(
                input_args, input_kwargs
            )
        else:
            ref_model = model
            ref_input_args = input_args
            ref_input_kwargs = input_kwargs

        assert isinstance(ref_model, torch.nn.Module) or callable(
            ref_model
        ), "Model must be a torch.nn.Module or callable"
        if (
            self.model_type
            == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
        ):
            with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
                # torch.export.export expects `args` as a tuple. The kwargs must be
                # traced as well; otherwise the exported program's input spec does
                # not match the `**ref_input_kwargs` forwarded to dynamo_export below.
                ref_model = torch.export.export(
                    ref_model,
                    args=tuple(ref_input_args),
                    kwargs=ref_input_kwargs,
                )
            if (
                self.dynamic_shapes
            ):  # TODO: Support dynamic shapes for torch.export.ExportedProgram
                #       https://github.com/pytorch/pytorch/issues/113705
                pytest.xfail(
                    reason="torch.export.ExportedProgram does not support dynamic shapes"
                )

        # Feed args and kwargs into exporter.
        # Note that exporter should flatten kwargs into positional args the exported model;
        # since ONNX doesn't represent kwargs.
        with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
            onnx_program = torch.onnx.dynamo_export(
                ref_model,
                *ref_input_args,
                **ref_input_kwargs,
                export_options=torch.onnx.ExportOptions(
                    dynamic_shapes=self.dynamic_shapes,
                    diagnostic_options=torch.onnx.DiagnosticOptions(
                        verbosity_level=logging.DEBUG
                    ),
                ),
            )

        if not skip_dynamic_shapes_check:
            assert_dynamic_shapes(onnx_program, self.dynamic_shapes)

        if isinstance(ref_model, torch.export.ExportedProgram):
            ref_model = ref_model.module()

        _compare_pytorch_onnx_with_ort(
            onnx_program,
            ref_model,
            input_args,
            input_kwargs,
            atol,
            rtol,
            has_mutation=has_mutation,
        )
        # This confirms the exported mode accepts different input shapes
        # when dynamic shape is enabled.
        if additional_test_inputs and self.dynamic_shapes:
            for another_input in additional_test_inputs:
                # Each entry must be `(args,)` or `(args, kwargs)`; an empty entry
                # would otherwise fail below with an opaque IndexError.
                if not another_input or len(another_input) > 2:
                    raise ValueError(
                        f"test_inputs should only have tuple args and dictionary kwargs. But receives: {len(another_input)}"
                    )
                additional_input_args = another_input[0]
                additional_input_kwargs = (
                    another_input[1]
                    if len(another_input) == 2 and another_input[1] is not None
                    else {}
                )
                _compare_pytorch_onnx_with_ort(
                    onnx_program,
                    ref_model,
                    additional_input_args,
                    additional_input_kwargs,
                    atol,
                    rtol,
                    has_mutation=has_mutation,
                )
339
340
def run_ort(
    onnx_model: Union[str, torch.onnx.ONNXProgram],
    pytorch_inputs: Sequence[_InputArgsType],
) -> _OutputsType:
    """Run ONNX Runtime (CPU provider) on an ONNX model with the given torch inputs.

    Used in test_fx_to_onnx_with_onnxruntime.py

    Args:
        onnx_model (Union[str, torch.onnx.ONNXProgram]): Converted ONNX model. An
            ONNXProgram is serialized to bytes in memory; any other value is handed
            to `onnxruntime.InferenceSession` unchanged.
        pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs, matched
            positionally to the session inputs and converted with
            `torch.Tensor.numpy(..., force=True)`.

    Raises:
        AssertionError: ONNX and PyTorch should have the same input sizes

    Returns:
        _OutputsType: ONNX model predictions, as returned by `InferenceSession.run`.
    """
    if not isinstance(onnx_model, torch.onnx.ONNXProgram):
        model_source = onnx_model
    else:
        with io.BytesIO() as serialized:
            onnx_model.save(serialized)
            model_source = serialized.getvalue()

    # Only log errors, to keep ONNX Runtime from flooding the test output.
    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3  # Error
    session = onnxruntime.InferenceSession(
        model_source, providers=["CPUExecutionProvider"], sess_options=options
    )
    session_input_names = [node.name for node in session.get_inputs()]

    expected_count, received_count = len(session_input_names), len(pytorch_inputs)
    if expected_count != received_count:
        raise AssertionError(f"Expected {expected_count} inputs, got {received_count}")

    feeds = dict(
        zip(
            session_input_names,
            (torch.Tensor.numpy(tensor, force=True) for tensor in pytorch_inputs),
        )
    )
    return session.run(None, feeds)
384
385
386def _try_clone_model(model: _ModelType) -> _ModelType:
387    """Used for preserving original model in case forward mutates model states."""
388    try:
389        return copy.deepcopy(model)
390    except Exception:
391        warnings.warn(
392            "Failed to clone model. Model state might be mutated during verification."
393        )
394        return model
395
396
397def _try_clone_inputs(input_args, input_kwargs):
398    ref_input_args = copy.deepcopy(input_args)
399    ref_input_kwargs = copy.deepcopy(input_kwargs)
400    return ref_input_args, ref_input_kwargs
401
402
403def _compare_pytorch_onnx_with_ort(
404    onnx_program: torch.onnx.ONNXProgram,
405    model: _ModelType,
406    input_args: Sequence[_InputArgsType],
407    input_kwargs: Mapping[str, _InputArgsType],
408    atol: Optional[float] = None,
409    rtol: Optional[float] = None,
410    has_mutation: bool = False,
411):
412    if has_mutation:
413        ref_model = _try_clone_model(model)
414        ref_input_args, ref_input_kwargs = _try_clone_inputs(input_args, input_kwargs)
415    else:
416        ref_model = model
417        ref_input_args = input_args
418        ref_input_kwargs = input_kwargs
419
420    # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
421    # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
422    # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
423    # NOTE: `model_with_state_dict=ref_model` is specified to cover runs with FakeTensor support
424    ort_outputs = onnx_program(*input_args, **input_kwargs)
425    ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
426    ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)
427
428    if len(ref_outputs) != len(ort_outputs):
429        raise AssertionError(
430            f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
431        )
432
433    for ref_output, ort_output in zip(ref_outputs, ort_outputs):
434        torch.testing.assert_close(
435            ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
436        )
437
438
439# The min onnx opset version to test for
440MIN_ONNX_OPSET_VERSION = 9
441# The max onnx opset version to test for
442MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
443TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
444
445# The min onnx opset version to test for
446FX_MIN_ONNX_OPSET_VERSION = 18
447# The max onnx opset version to test for
448FX_MAX_ONNX_OPSET_VERSION = 18
449FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)
450
451BOOL_TYPES = (torch.bool,)
452
453INT_TYPES = (
454    # torch.int8,
455    # torch.int16,
456    torch.int32,
457    torch.int64,
458    # torch.uint8,
459)
460
461QINT_TYPES = (
462    torch.qint8,
463    torch.quint8,
464)
465
466FLOAT_TYPES = (
467    torch.float16,
468    torch.float32,
469    # torch.float64,  ORT doesn't support
470)
471
472COMPLEX_TYPES = (
473    # torch.complex32,  NOTE: torch.complex32 is experimental in torch
474    torch.complex64,
475    # torch.complex128,  ORT doesn't support
476)
477
478TESTED_DTYPES = (
479    # Boolean
480    torch.bool,
481    # Integers
482    *INT_TYPES,
483    # Floating types
484    *FLOAT_TYPES,
485    # Complex types
486    *COMPLEX_TYPES,
487)
488
489
@dataclasses.dataclass
class DecorateMeta:
    """Describes one OpInfo test case to skip or xfail.

    Adapted from functorch: functorch/test/common_utils.py
    """

    # Operator name; looked up against `OpInfo.name` by `add_decorate_info`.
    op_name: str
    # OpInfo variant name, looked up against `OpInfo.variant_test_name`.
    variant_name: str
    # Decorator applied to the matching test (e.g. unittest.expectedFailure).
    decorator: Callable
    # Opsets the entry applies to: ints or predicates such as opsets_before(11).
    # None means every opset.
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
    # Dtypes the decorator applies to.
    dtypes: Optional[Collection[torch.dtype]]
    # Human-readable reason for the skip/xfail.
    reason: str
    # Behavior of the test case: "skip" or "xfail".
    test_behavior: str
    # Predicate over a sample input; used when the entry is in SKIP_XFAIL_SUBTESTS.
    matcher: Optional[Callable[[Any], bool]] = None
    # Whether the behavior is enabled; usually used for onnx/ort version control.
    enabled_if: bool = True
    # The type of the torch model. Defaults to None.
    model_type: Optional[pytorch_test_common.TorchModelType] = None

    def contains_opset(self, opset: int) -> bool:
        """Return True if this entry applies to `opset`.

        Int specs match by equality; callable specs are invoked with the opset.
        """
        if self.opsets is None:
            return True
        for spec in self.opsets:
            hit = opset == spec if isinstance(spec, int) else spec(opset)
            if hit:
                return True
        return False
527
528
def xfail(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
    model_type: Optional[pytorch_test_common.TorchModelType] = None,
):
    """Build a `DecorateMeta` that marks an OpInfo test as an expected failure.

    The entry's decorator is `unittest.expectedFailure` and its test_behavior is
    "xfail".

    Args:
        op_name: The name of the operator.
        variant_name: The name of the OpInfo variant.
        reason: The reason for the failure.
        opsets: The opsets to expect the failure, e.g. [9, 10] or [opsets_before(11)].
            None applies to every opset.
        dtypes: The dtypes to expect the failure.
        matcher: A function that matches the test sample input. It is used only when
            xfail is in the SKIP_XFAIL_SUBTESTS list.
        enabled_if: Whether to enable xfail. Usually used on onnx/ort version control.
        model_type: The type of the torch model. Defaults to None.
    """
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.expectedFailure,
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
        test_behavior="xfail",
        matcher=matcher,
        enabled_if=enabled_if,
        model_type=model_type,
    )
565
566
def skip(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    enabled_if: bool = True,
    model_type: Optional[pytorch_test_common.TorchModelType] = None,
):
    """Build a `DecorateMeta` that skips an OpInfo test case we don't care about.

    Typically used when ONNX does not support the use case, or by design. The
    entry's decorator is `unittest.skip` with the reason prefixed by "Skip: ".

    Args:
        op_name: The name of the operator.
        variant_name: The name of the OpInfo variant.
        reason: The reason for skipping.
        opsets: The opsets to skip, e.g. [9, 10] or [opsets_before(11)].
            None applies to every opset.
        dtypes: The dtypes to skip.
        matcher: A function that matches the test sample input. It is used only when
            skip is in the SKIP_XFAIL_SUBTESTS list.
        enabled_if: Whether to enable skip. Usually used on onnx/ort version control.
        model_type: The type of the torch model. Defaults to None.
    """
    skip_decorator = unittest.skip(f"Skip: {reason}")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=skip_decorator,
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
        test_behavior="skip",
        matcher=matcher,
        enabled_if=enabled_if,
        model_type=model_type,
    )
605
606
def skip_slow(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    model_type: Optional[pytorch_test_common.TorchModelType] = None,
):
    """Build a `DecorateMeta` that marks a too-slow OpInfo test case with `slowTest`.

    The slowness needs further investigation. The `common_utils.slowTest`
    decorator is only active when `common_utils.TEST_WITH_SLOW` is false.

    Args:
        op_name: The name of the operator.
        variant_name: The name of the OpInfo variant.
        reason: Why the test is considered slow.
        opsets: The opsets this applies to, e.g. [9, 10] or [opsets_before(11)].
            None applies to every opset.
        dtypes: The dtypes this applies to.
        matcher: A function that matches the test sample input. It is used only when
            skip is in the SKIP_XFAIL_SUBTESTS list.
        model_type: The type of the torch model. Defaults to None.
    """
    active = not common_utils.TEST_WITH_SLOW
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=common_utils.slowTest,
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
        test_behavior="skip",
        matcher=matcher,
        enabled_if=active,
        model_type=model_type,
    )
643
644
def add_decorate_info(
    all_opinfos: Sequence[opinfo_core.OpInfo],
    test_class_name: str,
    base_test_name: str,
    opset: int,
    skip_or_xfails: Iterable[DecorateMeta],
):
    """Attach skip/xfail decorators from `skip_or_xfails` to the matching OpInfos.

    For every DecorateMeta that applies to `opset`, the OpInfo with the same
    (name, variant) gets an `opinfo_core.DecorateInfo` appended to its `decorators`
    tuple. The DecorateInfo is scoped to `test_class_name` / `base_test_name`, the
    meta's dtypes, and its `enabled_if` flag. OpInfos are mutated in place.

    Args:
        all_opinfos: All OpInfos.
        test_class_name: The name of the test class.
        base_test_name: The name of the test method.
        opset: The opset to decorate for.
        skip_or_xfails: DecorateMeta's. Entries must have `model_type` None.

    Returns:
        An identity decorator, so this call can be stacked on the test method.
    """
    opinfo_by_key = {
        (info.name, info.variant_test_name): info for info in all_opinfos
    }
    # Entries whose opsets do not include `opset` do not apply and are ignored.
    applicable = (meta for meta in skip_or_xfails if meta.contains_opset(opset))
    for meta in applicable:
        opinfo = opinfo_by_key.get((meta.op_name, meta.variant_name))
        assert (
            opinfo is not None
        ), f"Couldn't find OpInfo for {meta}. Did you need to specify variant_name?"
        assert meta.model_type is None, (
            f"Tested op: {meta.op_name} in wrong position! "
            "If model_type needs to be specified, it should be "
            "put under SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE."
        )
        opinfo.decorators = (
            *opinfo.decorators,
            opinfo_core.DecorateInfo(
                meta.decorator,
                test_class_name,
                base_test_name,
                dtypes=meta.dtypes,
                active_if=meta.enabled_if,
            ),
        )

    def wrapped(fn):
        # Identity: the decoration happens on the OpInfos above, not on `fn`.
        return fn

    return wrapped
691
692
def opsets_before(opset: int) -> Callable[[int], bool]:
    """Build a predicate that is True for opsets strictly lower than `opset`.

    Meant for the `opsets` argument of `xfail`/`skip`, e.g. `[opsets_before(11)]`.
    """

    def compare(other_opset: int):
        return opset > other_opset

    return compare
700
701
def opsets_after(opset: int) -> Callable[[int], bool]:
    """Build a predicate that is True for opsets strictly greater than `opset`.

    Meant for the `opsets` argument of `xfail`/`skip`, e.g. `[opsets_after(13)]`.
    """

    def compare(other_opset: int):
        return opset < other_opset

    return compare
709
710
def reason_onnx_script_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Build a reason string: ONNX Script does not support `operator` on `dtypes`.

    When `dtypes` is empty or None, the generic word 'dtypes' is used instead.
    """
    subject = dtypes if dtypes else "dtypes"
    return f"{operator} on {subject} not supported by ONNX script"
716
717
def reason_onnx_runtime_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Build a reason string: ONNX Runtime does not support `operator` on `dtypes`.

    When `dtypes` is empty or None, the generic word 'dtypes' is used instead.
    """
    subject = dtypes if dtypes else "dtypes"
    return f"{operator} on {subject} not supported by ONNX Runtime"
723
724
def reason_onnx_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Build a reason string: the ONNX spec does not support `operator` on `dtypes`.

    When `dtypes` is empty or None, the phrase 'certain dtypes' is used instead.
    """
    subject = dtypes if dtypes else "certain dtypes"
    return f"{operator} on {subject} not supported by the ONNX Spec"
730
731
def reason_dynamo_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Build a reason string: Dynamo does not support `operator` on `dtypes`.

    When `dtypes` is empty or None, the phrase 'certain dtypes' is used instead.
    """
    subject = dtypes if dtypes else "certain dtypes"
    return f"{operator} on {subject} not supported by the Dynamo Spec"
739
740
def reason_jit_tracer_error(info: str) -> str:
    """Build a reason string describing a JIT tracer error on `info`."""
    return "JIT tracer error on {}".format(info)
744
745
def reason_flaky() -> str:
    """Return the reason string used for flaky tests."""
    reason = "flaky test"
    return reason
749
750
@contextlib.contextmanager
def normal_xfail_skip_test_behaviors(
    test_behavior: Optional[str] = None, reason: Optional[str] = None
):
    """Apply a DecorateMeta's skip/xfail behavior around the body of a test.

    Args:
        test_behavior (Optional[str]): From DecorateMeta; 'skip', 'xfail', or None.
            - 'skip': the test is skipped before the body runs.
            - 'xfail': an exception in the body marks the test as xfailed. A body
              that completes without raising fails the test as unexpectedly passed.
            - None (or any other value): exceptions from the body propagate unchanged.
        reason (Optional[str]): The reason for the failure or skip.

    Raises:
        Exception: Any exception raised by the body when it is not an expected failure.
    """
    # Skip as soon as possible, because the body might even segfault.
    if test_behavior == "skip":
        pytest.skip(reason=reason)

    try:
        yield
    # Catch broadly: enumerating every exception type the tests can raise would
    # require going over all test cases.
    except Exception:  # pylint: disable=broad-exception-caught
        if test_behavior == "xfail":
            pytest.xfail(reason=reason)
        # Anything that is not an expected failure must surface, including when
        # test_behavior holds an unrecognized value.
        raise
    else:
        if test_behavior == "xfail":
            pytest.fail("Test unexpectedly passed")
781