# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

    To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

    Read more on Running and writing tests:
    https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS
    and TESTED_OPS lists. See "Modify this section".
"""

from __future__ import annotations

import copy
from typing import Optional, Tuple

import onnx_test_common
import parameterized

# For readability, these two are allowed to be imported as functions
from onnx_test_common import skip, xfail

import torch
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to TESTED_OPS, then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
# TODO: https://github.com/pytorch/pytorch/issues/102211
TESTED_OPS: frozenset[str] = frozenset(
    [
        "atan",
        "atan2",
        # "atleast_1d",  # How to support list input?
        # "atleast_2d",
        # "atleast_3d",
        "broadcast_to",
        "ceil",
        "expand",
        "flatten",
        "hstack",
        "logical_not",
        # "logit",
        "nn.functional.scaled_dot_product_attention",
        "repeat",
        "round",
        # "scatter_add",
        # "scatter_reduce",
        "sqrt",
        "stft",
        "t",
        "tile",
        "unflatten",
        "vstack",
    ]
)

# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible; prefix the reason with "fixme:" when
#    the failure is on our side and should be revisited later.
#    a. If a test is now failing because of xpass, because some previous errors
#       are now fixed, remove the corresponding xfail.
#    b. If a test is not failing consistently, use skip.
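# As an illustration, entries are built with the skip()/xfail() helpers imported
# from onnx_test_common; the op name "foo" below is hypothetical, and real
# entries follow in the tuple:
#
#   xfail("foo", dtypes=[torch.float64], reason="...")  # known, deterministic failure
#   skip("foo", opsets=[onnx_test_common.opsets_before(11)], reason="...")  # e.g. op unavailable before opset 11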
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
    skip(
        "atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail(
        "atan2", dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])
    ),
    xfail(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
    ),
    skip("hstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
    xfail(
        "logit",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Log", "bool, int"),
    ),
    skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
    skip("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
    xfail("round", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_0", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_neg_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail("scatter_reduce", variant_name="mean",
          reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
    skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
    ),
    skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
    skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
    xfail("stft",
          reason=onnx_test_common.reason_onnx_runtime_does_not_support("STFT", "regression in ORT 1.15: 4 percent difference")),
    skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
    xfail("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
    skip("vstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
)
# fmt: on

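# SKIP_XFAIL_SUBTESTS below skips or xfails individual test samples rather than
# whole op/dtype/opset combinations. Each entry carries a matcher that receives
# a SampleInput and returns True when that sample should be skipped or xfailed.
# A sketch with a hypothetical op "foo" and kwarg "bar" (real entries follow):
#
#   skip("foo", matcher=lambda sample: sample.kwargs.get("bar") is None,
#        reason="..."),
#
# Matchers are evaluated by _should_skip_xfail_test_sample during the test run.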
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "nn.functional.scaled_dot_product_attention",
        matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
        reason="dropout is random so the results do not match",
    ),
    skip(
        "repeat",
        reason="Empty repeats value leads to an invalid graph",
        matcher=lambda sample: not sample.args[0],
    ),
    skip(
        "scatter_reduce",
        # ONNX has no include_self parameter; its default behavior matches include_self=True
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support include_self=False option",
    ),
    skip(
        "stft",
        reason="ONNX STFT does not support complex results",
        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
    ),
    skip(
        "tile",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
        or not sample.input.shape,
        reason="Logic not implemented for size 0 inputs in op.Reshape",
    ),
    skip(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
)


# END OF SECTION TO MODIFY #####################################################

OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


class SingleOpModel(torch.nn.Module):
    """Test model that wraps a single op for export."""

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        return self.operator(*args, **self.kwargs)
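# As a hypothetical usage sketch (the real wiring happens in test_output_match
# below): wrapping torch.ceil with no kwargs yields a module whose forward
# applies the op to its positional inputs.
#
#   model = SingleOpModel(torch.ceil, {})
#   model(torch.tensor([0.2]))  # -> tensor([1.])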
def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Returns the test behavior ("skip" or "xfail") and the reason if a test sample should be skipped or xfailed."""
    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None


def _get_test_class_name(cls, num, params_dict) -> str:
    del cls  # unused
    del num  # unused
    return params_dict["name"]


@parameterized.parameterized_class(
    [
        {
            "name": f"TestOnnxModelOutputConsistency_opset{opset}",
            "opset_version": opset,
        }
        for opset in onnx_test_common.TESTED_OPSETS
    ],
    class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
    """

    opset_version = -1

    @common_device_type.ops(
        [op for op in OPS_DB if op.name in TESTED_OPS],
        allowed_dtypes=onnx_test_common.INT_TYPES
        + onnx_test_common.FLOAT_TYPES
        + onnx_test_common.BOOL_TYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter."""
        # device is provided by instantiate_device_type_tests, but we only want to run on CPU.
        assert device == "cpu"

        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=False,
        )

        for i, cpu_sample in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
            # Provide the repr to subtest because tensors are not serializable in parallel test runs
            with self.subTest(
                opset=self.opset_version,
                sample_num=i,
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
                test_behavior, reason = _should_skip_xfail_test_sample(
                    op.name, cpu_sample
                )
                with onnx_test_common.normal_xfail_skip_test_behaviors(
                    test_behavior, reason
                ):
                    model = SingleOpModel(op, cpu_sample.kwargs)
                    model.eval()

                    if dtype in (torch.float32, torch.float64):
                        # Relax atol and rtol based on empirical results.
                        # The current most relaxed values are for aten::stft.
                        rtol = 1e-5
                        atol = 2e-5
                    else:
                        rtol = None
                        atol = None
                    # Run the test
                    self.run_test(model, inputs, rtol=rtol, atol=atol)


for opset in onnx_test_common.TESTED_OPSETS:
    # The name needs to match the parameterized_class name.
    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
    onnx_test_common.add_decorate_info(
        OPS_DB,
        test_class_name,
        "test_output_match",
        opset=opset,
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
    common_device_type.instantiate_device_type_tests(
        globals()[test_class_name], globals(), only_for="cpu"
    )


if __name__ == "__main__":
    common_utils.run_tests()