1# Owner(s): ["module: onnx"] 2 3from __future__ import annotations 4 5import contextlib 6import copy 7import dataclasses 8import io 9import logging 10import os 11import unittest 12import warnings 13from typing import ( 14 Any, 15 Callable, 16 Collection, 17 Iterable, 18 List, 19 Mapping, 20 Optional, 21 Sequence, 22 Tuple, 23 Type, 24 Union, 25) 26 27import numpy as np 28import onnxruntime 29import pytest 30import pytorch_test_common 31 32import torch 33from torch import export as torch_export 34from torch.onnx import _constants, verification 35from torch.testing._internal import common_utils 36from torch.testing._internal.opinfo import core as opinfo_core 37from torch.types import Number 38 39 40_NumericType = Union[Number, torch.Tensor, np.ndarray] 41_ModelType = Union[torch.nn.Module, Callable, torch_export.ExportedProgram] 42_InputArgsType = Optional[ 43 Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]] 44] 45_OutputsType = Sequence[_NumericType] 46 47onnx_model_dir = os.path.join( 48 os.path.dirname(os.path.realpath(__file__)), 49 os.pardir, 50 "repos", 51 "onnx", 52 "onnx", 53 "backend", 54 "test", 55 "data", 56) 57 58 59pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted") 60 61 62pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator") 63 64 65def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs): 66 options = verification.VerificationOptions() 67 68 kwargs["opset_version"] = test_suite.opset_version 69 kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs 70 if hasattr(test_suite, "check_shape"): 71 options.check_shape = test_suite.check_shape 72 if hasattr(test_suite, "check_dtype"): 73 options.check_dtype = test_suite.check_dtype 74 75 names = {f.name for f in dataclasses.fields(options)} 76 keywords_to_pop = [] 77 for k, v in kwargs.items(): 78 if k in names: 79 setattr(options, k, v) 80 keywords_to_pop.append(k) 81 for k in keywords_to_pop: 82 kwargs.pop(k) 83 84 return verification.verify(*args, options=options, **kwargs) 85 86 87def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: bool): 88 """Assert whether the exported model has dynamic shapes or not. 89 90 Args: 91 onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export. 92 dynamic_shapes (bool): Whether the exported model has dynamic shapes or not. 93 When True, raises if graph inputs don't have at least one dynamic dimension 94 When False, raises if graph inputs have at least one dynamic dimension. 95 96 Raises: 97 AssertionError: If the exported model has dynamic shapes and dynamic_shapes is False and vice-versa. 98 """ 99 100 if dynamic_shapes is None: 101 return 102 103 model_proto = onnx_program.model_proto 104 # Process graph inputs 105 dynamic_inputs = [] 106 for inp in model_proto.graph.input: 107 dynamic_inputs += [ 108 dim 109 for dim in inp.type.tensor_type.shape.dim 110 if dim.dim_value == 0 and dim.dim_param != "" 111 ] 112 assert dynamic_shapes == ( 113 len(dynamic_inputs) > 0 114 ), "Dynamic shape check failed for graph inputs" 115 116 117def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): 118 """Combine class name with the parameterized arguments. 119 120 This function is passed to `parameterized.parameterized_class` as the 121 `class_name_func` argument. 122 """ 123 suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items()) 124 return f"{cls.__name__}_{suffix}" 125 126 127class _TestONNXRuntime(pytorch_test_common.ExportTestCase): 128 opset_version = _constants.ONNX_DEFAULT_OPSET 129 keep_initializers_as_inputs = True # For IR version 3 type export. 130 is_script = False 131 check_shape = True 132 check_dtype = True 133 134 def setUp(self): 135 super().setUp() 136 onnxruntime.set_seed(0) 137 if torch.cuda.is_available(): 138 torch.cuda.manual_seed_all(0) 139 os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0" 140 self.is_script_test_enabled = True 141 142 # The exported ONNX model may have less inputs than the pytorch model because of const folding. 143 # This mostly happens in unit test, where we widely use torch.size or torch.shape. 144 # So the output is only dependent on the input shape, not value. 145 # remained_onnx_input_idx is used to indicate which pytorch model input idx is remained in ONNX model. 146 def run_test( 147 self, 148 model, 149 input_args, 150 input_kwargs=None, 151 rtol=1e-3, 152 atol=1e-7, 153 do_constant_folding=True, 154 dynamic_axes=None, 155 additional_test_inputs=None, 156 input_names=None, 157 output_names=None, 158 fixed_batch_size=False, 159 training=torch.onnx.TrainingMode.EVAL, 160 remained_onnx_input_idx=None, 161 verbose=False, 162 ): 163 def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True): 164 return run_model_test( 165 self, 166 m, 167 input_args=input_args, 168 input_kwargs=input_kwargs, 169 rtol=rtol, 170 atol=atol, 171 do_constant_folding=do_constant_folding, 172 dynamic_axes=dynamic_axes, 173 additional_test_inputs=additional_test_inputs, 174 input_names=input_names, 175 output_names=output_names, 176 fixed_batch_size=fixed_batch_size, 177 training=training, 178 remained_onnx_input_idx=remained_onnx_input_idx, 179 flatten=flatten, 180 ignore_none=ignore_none, 181 verbose=verbose, 182 ) 183 184 if isinstance(remained_onnx_input_idx, dict): 185 scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"] 186 tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"] 187 else: 188 scripting_remained_onnx_input_idx = remained_onnx_input_idx 189 tracing_remained_onnx_input_idx = remained_onnx_input_idx 190 191 is_model_script = isinstance( 192 model, (torch.jit.ScriptModule, torch.jit.ScriptFunction) 193 ) 194 195 if self.is_script_test_enabled and self.is_script: 196 script_model = model if is_model_script else torch.jit.script(model) 197 _run_test( 198 script_model, 199 scripting_remained_onnx_input_idx, 200 flatten=False, 201 ignore_none=False, 202 ) 203 if not is_model_script and not self.is_script: 204 _run_test(model, tracing_remained_onnx_input_idx) 205 206 def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( 207 self, 208 model: _ModelType, 209 input_args: Sequence[_InputArgsType], 210 *, 211 input_kwargs: Optional[Mapping[str, _InputArgsType]] = None, 212 rtol: Optional[float] = 1e-3, 213 atol: Optional[float] = 1e-7, 214 has_mutation: bool = False, 215 additional_test_inputs: Optional[ 216 List[ 217 Union[ 218 Tuple[Sequence[_InputArgsType], Mapping[str, _InputArgsType]], 219 Tuple[Sequence[_InputArgsType]], 220 ] 221 ] 222 ] = None, 223 skip_dynamic_shapes_check: bool = False, 224 ): 225 """Compare the results of PyTorch model with exported ONNX model 226 227 Args: 228 model (_ModelType): PyTorch model 229 input_args (Sequence[_InputArgsType]): torch input arguments 230 input_kwargs (Mapping[str, _InputArgsType]): torch input kwargs 231 rtol (float, optional): relative tolerance. Defaults to 1e-3. 232 atol (float, optional): absolute tolerance. Defaults to 1e-7. 233 has_mutation (bool, optional): Whether the model mutates its input or state. 234 `mutation` as `True` incurs extra overhead of cloning the inputs and model. 235 Defaults to False. 236 additional_test_inputs: Test the models with another dataset input, which 237 is designed for dynamic axes testing. Defaults to None. It's a list of 238 different input sets in tuples. Inside tuple, the first element is a tuple 239 of args, and the second element is a dict of kwargs. Remember to put comma 240 even if the following element is not provided. 241 For example, 242 additional_test_inputs = [((args1, args2), {"kwargs":1}), ((args1,),), ((), {"kwargs":1})] 243 skip_dynamic_shapes_check: Whether to skip dynamic shape check. Defaults to False. 244 Must be used when tests do not produce dynamic shapes even when dynamic shape feature is enabled. 245 This is needed because Torch Dynamo uses the dynamic_shapes flag as a hint, only. 246 247 """ 248 from torch._dynamo import config as _dynamo_config 249 250 # avoid mutable data structure 251 if input_kwargs is None: 252 input_kwargs = {} 253 254 if ( 255 has_mutation 256 and self.model_type 257 != pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM 258 ): 259 ref_model = _try_clone_model(model) 260 ref_input_args, ref_input_kwargs = _try_clone_inputs( 261 input_args, input_kwargs 262 ) 263 else: 264 ref_model = model 265 ref_input_args = input_args 266 ref_input_kwargs = input_kwargs 267 268 assert isinstance(ref_model, torch.nn.Module) or callable( 269 ref_model 270 ), "Model must be a torch.nn.Module or callable" 271 if ( 272 self.model_type 273 == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM 274 ): 275 with _dynamo_config.patch(do_not_emit_runtime_asserts=True): 276 ref_model = torch.export.export(ref_model, args=ref_input_args) 277 if ( 278 self.dynamic_shapes 279 ): # TODO: Support dynamic shapes for torch.export.ExportedProgram 280 # https://github.com/pytorch/pytorch/issues/113705 281 pytest.xfail( 282 reason="torch.export.ExportedProgram does not support dynamic shapes" 283 ) 284 285 # Feed args and kwargs into exporter. 286 # Note that exporter should flatten kwargs into positional args the exported model; 287 # since ONNX doesn't represent kwargs. 288 with _dynamo_config.patch(do_not_emit_runtime_asserts=True): 289 onnx_program = torch.onnx.dynamo_export( 290 ref_model, 291 *ref_input_args, 292 **ref_input_kwargs, 293 export_options=torch.onnx.ExportOptions( 294 dynamic_shapes=self.dynamic_shapes, 295 diagnostic_options=torch.onnx.DiagnosticOptions( 296 verbosity_level=logging.DEBUG 297 ), 298 ), 299 ) 300 301 if not skip_dynamic_shapes_check: 302 assert_dynamic_shapes(onnx_program, self.dynamic_shapes) 303 304 if isinstance(ref_model, torch.export.ExportedProgram): 305 ref_model = ref_model.module() 306 307 _compare_pytorch_onnx_with_ort( 308 onnx_program, 309 ref_model, 310 input_args, 311 input_kwargs, 312 atol, 313 rtol, 314 has_mutation=has_mutation, 315 ) 316 # This confirms the exported mode accepts different input shapes 317 # when dynamic shape is enabled. 318 if additional_test_inputs and self.dynamic_shapes: 319 for another_input in additional_test_inputs: 320 if len(another_input) > 2: 321 raise ValueError( 322 f"test_inputs should only have tuple args and dictionary kwargs. But receives: {len(another_input)}" 323 ) 324 additional_input_args = another_input[0] 325 additional_input_kwargs = ( 326 another_input[1] 327 if len(another_input) == 2 and another_input[1] is not None 328 else {} 329 ) 330 _compare_pytorch_onnx_with_ort( 331 onnx_program, 332 ref_model, 333 additional_input_args, 334 additional_input_kwargs, 335 atol, 336 rtol, 337 has_mutation=has_mutation, 338 ) 339 340 341def run_ort( 342 onnx_model: Union[str, torch.onnx.ONNXProgram], 343 pytorch_inputs: Sequence[_InputArgsType], 344) -> _OutputsType: 345 """Run ORT on the given ONNX model and inputs 346 347 Used in test_fx_to_onnx_with_onnxruntime.py 348 349 Args: 350 onnx_model (Union[str, torch.onnx.ONNXProgram]): Converter ONNX model 351 pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs 352 353 Raises: 354 AssertionError: ONNX and PyTorch should have the same input sizes 355 356 Returns: 357 _OutputsType: ONNX model predictions 358 """ 359 if isinstance(onnx_model, torch.onnx.ONNXProgram): 360 buffer = io.BytesIO() 361 onnx_model.save(buffer) 362 ort_model = buffer.getvalue() 363 else: 364 ort_model = onnx_model 365 366 # Suppress floods of warnings from ONNX Runtime 367 session_options = onnxruntime.SessionOptions() 368 session_options.log_severity_level = 3 # Error 369 session = onnxruntime.InferenceSession( 370 ort_model, providers=["CPUExecutionProvider"], sess_options=session_options 371 ) 372 input_names = [ort_input.name for ort_input in session.get_inputs()] 373 374 if len(input_names) != len(pytorch_inputs): 375 raise AssertionError( 376 f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}" 377 ) 378 379 ort_input = { 380 k: torch.Tensor.numpy(v, force=True) 381 for k, v in zip(input_names, pytorch_inputs) 382 } 383 return session.run(None, ort_input) 384 385 386def _try_clone_model(model: _ModelType) -> _ModelType: 387 """Used for preserving original model in case forward mutates model states.""" 388 try: 389 return copy.deepcopy(model) 390 except Exception: 391 warnings.warn( 392 "Failed to clone model. Model state might be mutated during verification." 393 ) 394 return model 395 396 397def _try_clone_inputs(input_args, input_kwargs): 398 ref_input_args = copy.deepcopy(input_args) 399 ref_input_kwargs = copy.deepcopy(input_kwargs) 400 return ref_input_args, ref_input_kwargs 401 402 403def _compare_pytorch_onnx_with_ort( 404 onnx_program: torch.onnx.ONNXProgram, 405 model: _ModelType, 406 input_args: Sequence[_InputArgsType], 407 input_kwargs: Mapping[str, _InputArgsType], 408 atol: Optional[float] = None, 409 rtol: Optional[float] = None, 410 has_mutation: bool = False, 411): 412 if has_mutation: 413 ref_model = _try_clone_model(model) 414 ref_input_args, ref_input_kwargs = _try_clone_inputs(input_args, input_kwargs) 415 else: 416 ref_model = model 417 ref_input_args = input_args 418 ref_input_kwargs = input_kwargs 419 420 # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. 421 # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. 422 # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() 423 # NOTE: `model_with_state_dict=ref_model` is specified to cover runs with FakeTensor support 424 ort_outputs = onnx_program(*input_args, **input_kwargs) 425 ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs) 426 ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs) 427 428 if len(ref_outputs) != len(ort_outputs): 429 raise AssertionError( 430 f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}" 431 ) 432 433 for ref_output, ort_output in zip(ref_outputs, ort_outputs): 434 torch.testing.assert_close( 435 ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol 436 ) 437 438 439# The min onnx opset version to test for 440MIN_ONNX_OPSET_VERSION = 9 441# The max onnx opset version to test for 442MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET 443TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1) 444 445# The min onnx opset version to test for 446FX_MIN_ONNX_OPSET_VERSION = 18 447# The max onnx opset version to test for 448FX_MAX_ONNX_OPSET_VERSION = 18 449FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1) 450 451BOOL_TYPES = (torch.bool,) 452 453INT_TYPES = ( 454 # torch.int8, 455 # torch.int16, 456 torch.int32, 457 torch.int64, 458 # torch.uint8, 459) 460 461QINT_TYPES = ( 462 torch.qint8, 463 torch.quint8, 464) 465 466FLOAT_TYPES = ( 467 torch.float16, 468 torch.float32, 469 # torch.float64, ORT doesn't support 470) 471 472COMPLEX_TYPES = ( 473 # torch.complex32, NOTE: torch.complex32 is experimental in torch 474 torch.complex64, 475 # torch.complex128, ORT doesn't support 476) 477 478TESTED_DTYPES = ( 479 # Boolean 480 torch.bool, 481 # Integers 482 *INT_TYPES, 483 # Floating types 484 *FLOAT_TYPES, 485 # Complex types 486 *COMPLEX_TYPES, 487) 488 489 490@dataclasses.dataclass 491class DecorateMeta: 492 """Information about a test case to skip or xfail. 493 494 Adapted from functorch: functorch/test/common_utils.py 495 496 Attributes: 497 op_name: The name of the operator. 498 variant_name: The name of the OpInfo variant. 499 decorator: The decorator to apply to the test case. 500 opsets: The opsets to apply the decorator to. 501 dtypes: The dtypes to apply the decorator to. 502 reason: The reason for skipping. 503 test_behavior: The behavior of the test case. [skip or xfail] 504 matcher: The matcher to apply to the test case. 505 enabled_if: Whether to enable test behavior. Usually used on onnx/ort version control 506 model_type: The type of the torch model. Defaults to None. 507 """ 508 509 op_name: str 510 variant_name: str 511 decorator: Callable 512 opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] 513 dtypes: Optional[Collection[torch.dtype]] 514 reason: str 515 test_behavior: str 516 matcher: Optional[Callable[[Any], bool]] = None 517 enabled_if: bool = True 518 model_type: Optional[pytorch_test_common.TorchModelType] = None 519 520 def contains_opset(self, opset: int) -> bool: 521 if self.opsets is None: 522 return True 523 return any( 524 opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset) 525 for opset_spec in self.opsets 526 ) 527 528 529def xfail( 530 op_name: str, 531 variant_name: str = "", 532 *, 533 reason: str, 534 opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, 535 dtypes: Optional[Collection[torch.dtype]] = None, 536 matcher: Optional[Callable[[Any], bool]] = None, 537 enabled_if: bool = True, 538 model_type: Optional[pytorch_test_common.TorchModelType] = None, 539): 540 """Expects a OpInfo test to fail. 541 542 Args: 543 op_name: The name of the operator. 544 variant_name: The name of the variant. 545 opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)] 546 dtypes: The dtypes to expect the failure. 547 reason: The reason for the failure. 548 matcher: A function that matches the test sample input. It is used only when 549 xfail is in the SKIP_XFAIL_SUBTESTS list. 550 enabled_if: Whether to enable xfail. Usually used on onnx/ort version control 551 model_type: The type of the torch model. Defaults to None. 552 """ 553 return DecorateMeta( 554 op_name=op_name, 555 variant_name=variant_name, 556 decorator=unittest.expectedFailure, 557 opsets=opsets, 558 dtypes=dtypes, 559 enabled_if=enabled_if, 560 matcher=matcher, 561 reason=reason, 562 test_behavior="xfail", 563 model_type=model_type, 564 ) 565 566 567def skip( 568 op_name: str, 569 variant_name: str = "", 570 *, 571 reason: str, 572 opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, 573 dtypes: Optional[Collection[torch.dtype]] = None, 574 matcher: Optional[Callable[[Any], Any]] = None, 575 enabled_if: bool = True, 576 model_type: Optional[pytorch_test_common.TorchModelType] = None, 577): 578 """Skips a test case in OpInfo that we don't care about. 579 580 Likely because ONNX does not support the use case or it is by design. 581 582 Args: 583 op_name: The name of the operator. 584 variant_name: The name of the variant. 585 opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)] 586 dtypes: The dtypes to expect the failure. 587 reason: The reason for the failure. 588 matcher: A function that matches the test sample input. It is used only when 589 skip is in the SKIP_XFAIL_SUBTESTS list. 590 enabled_if: Whether to enable skip. Usually used on onnx/ort version control 591 model_type: The type of the torch model. Defaults to None. 592 """ 593 return DecorateMeta( 594 op_name=op_name, 595 variant_name=variant_name, 596 decorator=unittest.skip(f"Skip: {reason}"), 597 opsets=opsets, 598 dtypes=dtypes, 599 reason=reason, 600 matcher=matcher, 601 enabled_if=enabled_if, 602 test_behavior="skip", 603 model_type=model_type, 604 ) 605 606 607def skip_slow( 608 op_name: str, 609 variant_name: str = "", 610 *, 611 reason: str, 612 opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, 613 dtypes: Optional[Collection[torch.dtype]] = None, 614 matcher: Optional[Callable[[Any], Any]] = None, 615 model_type: Optional[pytorch_test_common.TorchModelType] = None, 616): 617 """Skips a test case in OpInfo that is too slow. 618 619 It needs further investigation to understand why it is slow. 620 621 Args: 622 op_name: The name of the operator. 623 variant_name: The name of the variant. 624 opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)] 625 dtypes: The dtypes to expect the failure. 626 reason: The reason for the failure. 627 matcher: A function that matches the test sample input. It is used only when 628 skip is in the SKIP_XFAIL_SUBTESTS list. 629 model_type: The type of the torch model. Defaults to None. 630 """ 631 return DecorateMeta( 632 op_name=op_name, 633 variant_name=variant_name, 634 decorator=common_utils.slowTest, 635 opsets=opsets, 636 dtypes=dtypes, 637 reason=reason, 638 matcher=matcher, 639 enabled_if=not common_utils.TEST_WITH_SLOW, 640 test_behavior="skip", 641 model_type=model_type, 642 ) 643 644 645def add_decorate_info( 646 all_opinfos: Sequence[opinfo_core.OpInfo], 647 test_class_name: str, 648 base_test_name: str, 649 opset: int, 650 skip_or_xfails: Iterable[DecorateMeta], 651): 652 """Decorates OpInfo tests with decorators based on the skip_or_xfails list. 653 654 Args: 655 all_opinfos: All OpInfos. 656 test_class_name: The name of the test class. 657 base_test_name: The name of the test method. 658 opset: The opset to decorate for. 659 skip_or_xfails: DecorateMeta's. 660 """ 661 ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos} 662 for decorate_meta in skip_or_xfails: 663 if not decorate_meta.contains_opset(opset): 664 # Skip does not apply to this opset 665 continue 666 opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name)) 667 assert ( 668 opinfo is not None 669 ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?" 670 assert decorate_meta.model_type is None, ( 671 f"Tested op: {decorate_meta.op_name} in wrong position! " 672 "If model_type needs to be specified, it should be " 673 "put under SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE." 674 ) 675 decorators = list(opinfo.decorators) 676 new_decorator = opinfo_core.DecorateInfo( 677 decorate_meta.decorator, 678 test_class_name, 679 base_test_name, 680 dtypes=decorate_meta.dtypes, 681 active_if=decorate_meta.enabled_if, 682 ) 683 decorators.append(new_decorator) 684 opinfo.decorators = tuple(decorators) 685 686 # This decorator doesn't modify fn in any way 687 def wrapped(fn): 688 return fn 689 690 return wrapped 691 692 693def opsets_before(opset: int) -> Callable[[int], bool]: 694 """Returns a comparison function that decides if the given opset is before the specified.""" 695 696 def compare(other_opset: int): 697 return other_opset < opset 698 699 return compare 700 701 702def opsets_after(opset: int) -> Callable[[int], bool]: 703 """Returns a comparison function that decides if the given opset is after the specified.""" 704 705 def compare(other_opset: int): 706 return other_opset > opset 707 708 return compare 709 710 711def reason_onnx_script_does_not_support( 712 operator: str, dtypes: Optional[Sequence[str]] = None 713) -> str: 714 """Formats the reason: ONNX script doesn't support the given dtypes.""" 715 return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX script" 716 717 718def reason_onnx_runtime_does_not_support( 719 operator: str, dtypes: Optional[Sequence[str]] = None 720) -> str: 721 """Formats the reason: ONNX Runtime doesn't support the given dtypes.""" 722 return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime" 723 724 725def reason_onnx_does_not_support( 726 operator: str, dtypes: Optional[Sequence[str]] = None 727) -> str: 728 """Formats the reason: ONNX doesn't support the given dtypes.""" 729 return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec" 730 731 732def reason_dynamo_does_not_support( 733 operator: str, dtypes: Optional[Sequence[str]] = None 734) -> str: 735 """Formats the reason: Dynamo doesn't support the given dtypes.""" 736 return ( 737 f"{operator} on {dtypes or 'certain dtypes'} not supported by the Dynamo Spec" 738 ) 739 740 741def reason_jit_tracer_error(info: str) -> str: 742 """Formats the reason: JIT tracer errors.""" 743 return f"JIT tracer error on {info}" 744 745 746def reason_flaky() -> str: 747 """Formats the reason: test is flaky.""" 748 return "flaky test" 749 750 751@contextlib.contextmanager 752def normal_xfail_skip_test_behaviors( 753 test_behavior: Optional[str] = None, reason: Optional[str] = None 754): 755 """This context manager is used to handle the different behaviors of xfail and skip. 756 757 Args: 758 test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None. 759 reason (optional[str]): The reason for the failure or skip. 760 761 Raises: 762 e: Any exception raised by the test case if it's not an expected failure. 763 """ 764 765 # We need to skip as soon as possible, as SegFault might also be a case. 766 if test_behavior == "skip": 767 pytest.skip(reason=reason) 768 769 try: 770 yield 771 # We could use `except (AssertionError, RuntimeError, ...) as e:`, but it needs 772 # to go over all test cases to find the right exception type. 773 except Exception as e: # pylint: disable=broad-exception-caught 774 if test_behavior is None: 775 raise e 776 if test_behavior == "xfail": 777 pytest.xfail(reason=reason) 778 else: 779 if test_behavior == "xfail": 780 pytest.fail("Test unexpectedly passed") 781