# Owner(s): ["module: onnx"]
from __future__ import annotations

import contextlib
import dataclasses
import io
import logging
import typing
from typing import AbstractSet, Protocol, Tuple

import torch
from torch.onnx import errors
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.fx import diagnostics as fx_diagnostics
from torch.testing._internal import common_utils, logging_utils


if typing.TYPE_CHECKING:
    import unittest


class _SarifLogBuilder(Protocol):
    """Structural type for anything that can produce a SARIF log.

    Used so that both the TorchScript diagnostics engine and an infra
    `DiagnosticContext` can be passed to the assertion helpers below.
    """

    def sarif_log(self) -> sarif.SarifLog: ...


def _assert_has_diagnostics(
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Assert that every (rule, level) pair appears in the builder's SARIF log.

    Each pair is compared as `(rule.id, level.name.lower())` against the
    `(rule_id, level)` of every result in every run of the SARIF log.

    Args:
        sarif_log_builder: Object whose `sarif_log()` produces the log to inspect.
        rule_level_pairs: The (rule, level) pairs that must all be present.

    Raises:
        AssertionError: If any expected pair is missing. The message lists the
            missing pairs and all results that were actually found.
    """
    sarif_log = sarif_log_builder.sarif_log()
    unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
    actual_results = []
    for run in sarif_log.runs:
        if run.results is None:
            continue
        for result in run.results:
            id_level_pair = (result.rule_id, result.level)
            unseen_pairs.discard(id_level_pair)
            actual_results.append(id_level_pair)

    if unseen_pairs:
        raise AssertionError(
            f"Expected diagnostic results of rule id and level pair {unseen_pairs} not found. "
            f"Actual diagnostic results: {actual_results}"
        )


@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
    """Minimal rule collection shared by the test cases in this file."""

    rule_without_message_args: infra.Rule = dataclasses.field(
        default=infra.Rule(
            "1",
            "rule-without-message-args",
            message_default_template="rule message",
        )
    )


@contextlib.contextmanager
def assert_all_diagnostics(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Context manager to assert that all diagnostics are emitted.

    Usage:
        with assert_all_diagnostics(
            self,
            diagnostics.engine,
            {(rule, infra.Level.ERROR)},
        ):
            torch.onnx.export(...)

    If the body raises `errors.OnnxExporterError`, the exception is swallowed.
    This is only allowed when `infra.Level.ERROR` is one of the expected
    levels; otherwise `test_suite.assertIn` fails. The SARIF log check runs in
    every case, including when the body raised.

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule_level_pairs: A set of rule and level pairs to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: If not all diagnostics are emitted.
    """

    try:
        yield
    except errors.OnnxExporterError:
        test_suite.assertIn(infra.Level.ERROR, {level for _, level in rule_level_pairs})
    finally:
        _assert_has_diagnostics(sarif_log_builder, rule_level_pairs)


def assert_diagnostic(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule: infra.Rule,
    level: infra.Level,
):
    """Context manager to assert that a diagnostic is emitted.

    Single-pair convenience wrapper around `assert_all_diagnostics`.

    Usage:
        with assert_diagnostic(
            self,
            diagnostics.engine,
            rule,
            infra.Level.ERROR,
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule: The rule to assert.
        level: The level to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: If the diagnostic is not emitted.
    """

    return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)})


class TestDynamoOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the Dynamo ONNX export code."""

    def setUp(self):
        self.diagnostic_context = fx_diagnostics.DiagnosticContext("dynamo_export", "")
        self.rules = _RuleCollectionForTest()
        return super().setUp()

    def test_log_is_recorded_in_sarif_additional_messages_according_to_diagnostic_options_verbosity_level(
        self,
    ):
        logging_levels = [
            logging.DEBUG,
            logging.INFO,
            logging.WARNING,
            logging.ERROR,
        ]
        for verbosity_level in logging_levels:
            self.diagnostic_context.options.verbosity_level = verbosity_level
            with self.diagnostic_context:
                diagnostic = fx_diagnostics.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NONE
                )
                for log_level in logging_levels:
                    # Snapshot the count before *each* log call. A single snapshot
                    # taken before the loop would make every assertGreater after
                    # the first recorded level pass vacuously.
                    additional_messages_count = len(diagnostic.additional_messages)
                    diagnostic.log(level=log_level, message="log message")
                    if log_level >= verbosity_level:
                        self.assertGreater(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should be recorded when log level is {log_level} "
                            f"and verbosity level is {verbosity_level}",
                        )
                    else:
                        self.assertEqual(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should not be recorded when log level is "
                            f"{log_level} and verbosity level is {verbosity_level}",
                        )

    def test_torch_logs_environment_variable_precedes_diagnostic_options_verbosity_level(
        self,
    ):
        self.diagnostic_context.options.verbosity_level = logging.ERROR
        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )
            additional_messages_count = len(diagnostic.additional_messages)
            diagnostic.debug("message")
            self.assertGreater(
                len(diagnostic.additional_messages), additional_messages_count
            )

    def test_log_is_not_emitted_to_terminal_when_log_artifact_is_not_enabled(self):
        self.diagnostic_context.options.verbosity_level = logging.INFO
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(
                diagnostic.logger, level=logging.INFO
            ) as assert_log_context:
                diagnostic.info("message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log we don't want is not emitted.
                diagnostic.logger.log(logging.ERROR, "dummy message")

            self.assertEqual(len(assert_log_context.records), 1)

    def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self):
        self.diagnostic_context.options.verbosity_level = logging.INFO

        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(diagnostic.logger, level=logging.INFO):
                diagnostic.info("message")

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.diagnostic_context.options.verbosity_level = verbosity_level
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.diagnostic_context:
            # Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic
            # instead of base infra.Diagnostic.
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            with self.assertRaises(TypeError):
                self.diagnostic_context.log(diagnostic)


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the TorchScript ONNX export code."""

    def setUp(self):
        engine = diagnostics.engine
        engine.clear()
        self._sample_rule = diagnostics.rules.missing_custom_symbolic_function
        super().setUp()

    def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
        self,
    ) -> diagnostics.TorchScriptOnnxExportDiagnostic:
        """Export a model with a custom op and return the resulting WARNING diagnostic.

        The diagnostic is for the `node_missing_onnx_shape_inference` rule. It is
        looked up in the most recent context of `diagnostics.engine`.

        Raises:
            AssertionError: If no such diagnostic was recorded.
        """

        class CustomAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y

            @staticmethod
            def symbolic(g, x, y):
                return g.op("custom::CustomAdd", x, y)

        class M(torch.nn.Module):
            def forward(self, x):
                return CustomAdd.apply(x, x)

        # trigger warning for missing shape inference.
        rule = diagnostics.rules.node_missing_onnx_shape_inference
        torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())

        context = diagnostics.engine.contexts[-1]
        for diagnostic in context.diagnostics:
            if (
                diagnostic.rule == rule
                and diagnostic.level == diagnostics.levels.WARNING
            ):
                return typing.cast(
                    diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic
                )
        raise AssertionError("No diagnostic found.")

    def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
        with self.assertRaises(AssertionError):
            with assert_diagnostic(
                self,
                diagnostics.engine,
                diagnostics.rules.node_missing_onnx_shape_inference,
                diagnostics.levels.WARNING,
            ):
                pass

    def test_cpp_diagnose_emits_warning(self):
        with assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.node_missing_onnx_shape_inference,
            diagnostics.levels.WARNING,
        ):
            # trigger warning for missing shape inference.
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()

    def test_py_diagnose_emits_error(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x)

        with assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.operator_supported_in_newer_opset_version,
            diagnostics.levels.ERROR,
        ):
            # trigger error for operator unsupported until newer opset version.
            torch.onnx.export(
                M(),
                torch.randn(3, 4),
                io.BytesIO(),
                opset_version=9,
            )

    def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
        self,
    ):
        sample_level = diagnostics.levels.ERROR
        with assert_diagnostic(
            self,
            diagnostics.engine,
            self._sample_rule,
            sample_level,
        ):
            diagnostic = infra.Diagnostic(self._sample_rule, sample_level)
            diagnostics.export_context().log(diagnostic)

    def test_diagnostics_records_python_call_stack(self):
        diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE)  # fmt: skip
        # Do not break the above line, otherwise it will not work with Python-3.8+
        stack = diagnostic.python_call_stack
        assert stack is not None  # for mypy
        self.assertGreater(len(stack.frames), 0)
        frame = stack.frames[0]
        assert frame.location.snippet is not None  # for mypy
        self.assertIn("self._sample_rule", frame.location.snippet)
        assert frame.location.uri is not None  # for mypy
        self.assertIn("test_diagnostics.py", frame.location.uri)

    def test_diagnostics_records_cpp_call_stack(self):
        diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
        stack = diagnostic.cpp_call_stack
        assert stack is not None  # for mypy
        self.assertGreater(len(stack.frames), 0)
        frame_messages = [frame.location.message for frame in stack.frames]
        # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx)
        # after node-level shape type inference and processed symbolic_fn output type
        self.assertTrue(
            any(
                isinstance(message, str) and "torch::jit::NodeToONNX" in message
                for message in frame_messages
            )
        )


@common_utils.instantiate_parametrized_tests
class TestDiagnosticsInfra(common_utils.TestCase):
    """Test cases for diagnostics infra."""

    def setUp(self):
        self.rules = _RuleCollectionForTest()
        with contextlib.ExitStack() as stack:
            self.context: infra.DiagnosticContext[infra.Diagnostic] = (
                stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
            )
            # Transfer ownership of the entered context to the test cleanup so it
            # stays active for the duration of the test and is exited afterwards.
            self.addCleanup(stack.pop_all().close)
        return super().setUp()

    def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
        custom_rules = infra.RuleCollection.custom_collection_from_list(
            "CustomRuleCollection",
            [
                infra.Rule(
                    "1",
                    "custom-rule",
                    message_default_template="custom rule message",
                ),
                infra.Rule(
                    "2",
                    "custom-rule-2",
                    message_default_template="custom rule message 2",
                ),
            ],
        )

        with assert_all_diagnostics(
            self,
            self.context,
            {
                (custom_rules.custom_rule, infra.Level.WARNING),  # type: ignore[attr-defined]
                (custom_rules.custom_rule_2, infra.Level.ERROR),  # type: ignore[attr-defined]
            },
        ):
            diagnostic1 = infra.Diagnostic(
                custom_rules.custom_rule,  # type: ignore[attr-defined]
                infra.Level.WARNING,
            )
            self.context.log(diagnostic1)

            diagnostic2 = infra.Diagnostic(
                custom_rules.custom_rule_2,  # type: ignore[attr-defined]
                infra.Level.ERROR,
            )
            self.context.log(diagnostic2)

    def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                diagnostic.log(logging.DEBUG, "debug message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.INFO, "info message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)
            self.assertFalse(
                any(
                    "debug message" in recorded
                    for recorded in diagnostic.additional_messages
                )
            )

    def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            level_message_pairs = [
                (logging.INFO, "info message"),
                (logging.WARNING, "warning message"),
                (logging.ERROR, "error message"),
            ]

            for level, message in level_message_pairs:
                with self.assertLogs(diagnostic.logger, level=verbosity_level):
                    diagnostic.log(level, message)

                # Use a distinct name for the generator variable: reusing `message`
                # would shadow the expected text and make the check vacuous.
                self.assertTrue(
                    any(
                        message in recorded
                        for recorded in diagnostic.additional_messages
                    ),
                    f"Expected {message!r} to be recorded in additional messages.",
                )

    @common_utils.parametrize(
        "log_api, log_level",
        [
            ("debug", logging.DEBUG),
            ("info", logging.INFO),
            ("warning", logging.WARNING),
            ("error", logging.ERROR),
        ],
    )
    def test_diagnostic_log_is_emitted_according_to_api_level_and_diagnostic_options_verbosity_level(
        self, log_api: str, log_level: int
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            message = "log message"
            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                getattr(diagnostic, log_api)(message)
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.ERROR, "dummy message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)

            if log_level >= verbosity_level:
                self.assertIn(message, diagnostic.additional_messages)
            else:
                self.assertNotIn(message, diagnostic.additional_messages)

    def test_diagnostic_log_lazy_string_is_not_evaluated_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should NOT be evaluated.
            diagnostic.debug("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                0,
                "expensive_formatting_function should not be evaluated after being wrapped under LazyString",
            )

    def test_diagnostic_log_lazy_string_is_evaluated_once_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should be evaluated exactly once,
            # since INFO is not below the configured verbosity level.
            diagnostic.info("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                1,
                "expensive_formatting_function should only be evaluated once after being wrapped under LazyString",
            )

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_diagnostic_nested_log_section_emits_messages_with_correct_section_title_indentation(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with diagnostic.log_section(logging.INFO, "My Section"):
                diagnostic.log(logging.INFO, "My Message")
                with diagnostic.log_section(logging.INFO, "My Subsection"):
                    diagnostic.log(logging.INFO, "My Submessage")

            with diagnostic.log_section(logging.INFO, "My Section 2"):
                diagnostic.log(logging.INFO, "My Message 2")

            self.assertIn("## My Section", diagnostic.additional_messages)
            self.assertIn("### My Subsection", diagnostic.additional_messages)
            self.assertIn("## My Section 2", diagnostic.additional_messages)

    def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_message(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NOTE
                )
                diagnostic.log_source_exception(logging.ERROR, e)

            diagnostic_message = "\n".join(diagnostic.additional_messages)

            self.assertIn("ValueError: original exception", diagnostic_message)
            self.assertIn("Traceback (most recent call last):", diagnostic_message)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.context:
            with self.assertRaises(TypeError):
                # The method expects 'Diagnostic' or its subclasses as arguments.
                # Passing any other type will trigger a TypeError.
                self.context.log("This is a str message.")

    def test_diagnostic_context_raises_if_diagnostic_is_error(self):
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
            )

    def test_diagnostic_context_raises_original_exception_from_diagnostic_created_from_it(
        self,
    ):
        with self.assertRaises(ValueError):
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
                diagnostic.log_source_exception(logging.ERROR, e)
                self.context.log_and_raise_if_error(diagnostic)

    def test_diagnostic_context_raises_if_diagnostic_is_warning_and_warnings_as_errors_is_true(
        self,
    ):
        # Configure the option outside assertRaises so only the call under test
        # is inside the expected-exception region.
        self.context.options.warnings_as_errors = True
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.WARNING
                )
            )


if __name__ == "__main__":
    common_utils.run_tests()