# Owner(s): ["module: onnx"] from __future__ import annotations import contextlib import dataclasses import io import logging import typing from typing import AbstractSet, Protocol, Tuple import torch from torch.onnx import errors from torch.onnx._internal import diagnostics from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics.infra import formatter, sarif from torch.onnx._internal.fx import diagnostics as fx_diagnostics from torch.testing._internal import common_utils, logging_utils if typing.TYPE_CHECKING: import unittest class _SarifLogBuilder(Protocol): def sarif_log(self) -> sarif.SarifLog: ... def _assert_has_diagnostics( sarif_log_builder: _SarifLogBuilder, rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]], ): sarif_log = sarif_log_builder.sarif_log() unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs} actual_results = [] for run in sarif_log.runs: if run.results is None: continue for result in run.results: id_level_pair = (result.rule_id, result.level) unseen_pairs.discard(id_level_pair) actual_results.append(id_level_pair) if unseen_pairs: raise AssertionError( f"Expected diagnostic results of rule id and level pair {unseen_pairs} not found. " f"Actual diagnostic results: {actual_results}" ) @dataclasses.dataclass class _RuleCollectionForTest(infra.RuleCollection): rule_without_message_args: infra.Rule = dataclasses.field( default=infra.Rule( "1", "rule-without-message-args", message_default_template="rule message", ) ) @contextlib.contextmanager def assert_all_diagnostics( test_suite: unittest.TestCase, sarif_log_builder: _SarifLogBuilder, rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]], ): """Context manager to assert that all diagnostics are emitted. Usage: with assert_all_diagnostics( self, diagnostics.engine, {(rule, infra.Level.Error)}, ): torch.onnx.export(...) Args: test_suite: The test suite instance. sarif_log_builder: The SARIF log builder. 
rule_level_pairs: A set of rule and level pairs to assert. Returns: A context manager. Raises: AssertionError: If not all diagnostics are emitted. """ try: yield except errors.OnnxExporterError: test_suite.assertIn(infra.Level.ERROR, {level for _, level in rule_level_pairs}) finally: _assert_has_diagnostics(sarif_log_builder, rule_level_pairs) def assert_diagnostic( test_suite: unittest.TestCase, sarif_log_builder: _SarifLogBuilder, rule: infra.Rule, level: infra.Level, ): """Context manager to assert that a diagnostic is emitted. Usage: with assert_diagnostic( self, diagnostics.engine, rule, infra.Level.Error, ): torch.onnx.export(...) Args: test_suite: The test suite instance. sarif_log_builder: The SARIF log builder. rule: The rule to assert. level: The level to assert. Returns: A context manager. Raises: AssertionError: If the diagnostic is not emitted. """ return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)}) class TestDynamoOnnxDiagnostics(common_utils.TestCase): """Test cases for diagnostics emitted by the Dynamo ONNX export code.""" def setUp(self): self.diagnostic_context = fx_diagnostics.DiagnosticContext("dynamo_export", "") self.rules = _RuleCollectionForTest() return super().setUp() def test_log_is_recorded_in_sarif_additional_messages_according_to_diagnostic_options_verbosity_level( self, ): logging_levels = [ logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, ] for verbosity_level in logging_levels: self.diagnostic_context.options.verbosity_level = verbosity_level with self.diagnostic_context: diagnostic = fx_diagnostics.Diagnostic( self.rules.rule_without_message_args, infra.Level.NONE ) additional_messages_count = len(diagnostic.additional_messages) for log_level in logging_levels: diagnostic.log(level=log_level, message="log message") if log_level >= verbosity_level: self.assertGreater( len(diagnostic.additional_messages), additional_messages_count, f"Additional message should be recorded when log level is 
{log_level} " f"and verbosity level is {verbosity_level}", ) else: self.assertEqual( len(diagnostic.additional_messages), additional_messages_count, f"Additional message should not be recorded when log level is " f"{log_level} and verbosity level is {verbosity_level}", ) def test_torch_logs_environment_variable_precedes_diagnostic_options_verbosity_level( self, ): self.diagnostic_context.options.verbosity_level = logging.ERROR with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context: diagnostic = fx_diagnostics.Diagnostic( self.rules.rule_without_message_args, infra.Level.NONE ) additional_messages_count = len(diagnostic.additional_messages) diagnostic.debug("message") self.assertGreater( len(diagnostic.additional_messages), additional_messages_count ) def test_log_is_not_emitted_to_terminal_when_log_artifact_is_not_enabled(self): self.diagnostic_context.options.verbosity_level = logging.INFO with self.diagnostic_context: diagnostic = fx_diagnostics.Diagnostic( self.rules.rule_without_message_args, infra.Level.NONE ) with self.assertLogs( diagnostic.logger, level=logging.INFO ) as assert_log_context: diagnostic.info("message") # NOTE: self.assertNoLogs only exist >= Python 3.10 # Add this dummy log such that we can pass self.assertLogs, and inspect # assert_log_context.records to check if the log we don't want is not emitted. 
diagnostic.logger.log(logging.ERROR, "dummy message") self.assertEqual(len(assert_log_context.records), 1) def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self): self.diagnostic_context.options.verbosity_level = logging.INFO with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context: diagnostic = fx_diagnostics.Diagnostic( self.rules.rule_without_message_args, infra.Level.NONE ) with self.assertLogs(diagnostic.logger, level=logging.INFO): diagnostic.info("message") def test_diagnostic_log_emit_correctly_formatted_string(self): verbosity_level = logging.INFO self.diagnostic_context.options.verbosity_level = verbosity_level with self.diagnostic_context: diagnostic = fx_diagnostics.Diagnostic( self.rules.rule_without_message_args, infra.Level.NOTE ) diagnostic.log( logging.INFO, "%s", formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"), ) self.assertIn("hello world", diagnostic.additional_messages) def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong( self, ): with self.diagnostic_context: # Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic # instead of base infra.Diagnostic. 
diagnostic = infra.Diagnostic( self.rules.rule_without_message_args, infra.Level.NOTE ) with self.assertRaises(TypeError): self.diagnostic_context.log(diagnostic) class TestTorchScriptOnnxDiagnostics(common_utils.TestCase): """Test cases for diagnostics emitted by the TorchScript ONNX export code.""" def setUp(self): engine = diagnostics.engine engine.clear() self._sample_rule = diagnostics.rules.missing_custom_symbolic_function super().setUp() def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp( self, ) -> diagnostics.TorchScriptOnnxExportDiagnostic: class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, y): return x + y @staticmethod def symbolic(g, x, y): return g.op("custom::CustomAdd", x, y) class M(torch.nn.Module): def forward(self, x): return CustomAdd.apply(x, x) # trigger warning for missing shape inference. rule = diagnostics.rules.node_missing_onnx_shape_inference torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) context = diagnostics.engine.contexts[-1] for diagnostic in context.diagnostics: if ( diagnostic.rule == rule and diagnostic.level == diagnostics.levels.WARNING ): return typing.cast( diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic ) raise AssertionError("No diagnostic found.") def test_assert_diagnostic_raises_when_diagnostic_not_found(self): with self.assertRaises(AssertionError): with assert_diagnostic( self, diagnostics.engine, diagnostics.rules.node_missing_onnx_shape_inference, diagnostics.levels.WARNING, ): pass def test_cpp_diagnose_emits_warning(self): with assert_diagnostic( self, diagnostics.engine, diagnostics.rules.node_missing_onnx_shape_inference, diagnostics.levels.WARNING, ): # trigger warning for missing shape inference. 
self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() def test_py_diagnose_emits_error(self): class M(torch.nn.Module): def forward(self, x): return torch.diagonal(x) with assert_diagnostic( self, diagnostics.engine, diagnostics.rules.operator_supported_in_newer_opset_version, diagnostics.levels.ERROR, ): # trigger error for operator unsupported until newer opset version. torch.onnx.export( M(), torch.randn(3, 4), io.BytesIO(), opset_version=9, ) def test_diagnostics_engine_records_diagnosis_reported_outside_of_export( self, ): sample_level = diagnostics.levels.ERROR with assert_diagnostic( self, diagnostics.engine, self._sample_rule, sample_level, ): diagnostic = infra.Diagnostic(self._sample_rule, sample_level) diagnostics.export_context().log(diagnostic) def test_diagnostics_records_python_call_stack(self): diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip # Do not break the above line, otherwise it will not work with Python-3.8+ stack = diagnostic.python_call_stack assert stack is not None # for mypy self.assertGreater(len(stack.frames), 0) frame = stack.frames[0] assert frame.location.snippet is not None # for mypy self.assertIn("self._sample_rule", frame.location.snippet) assert frame.location.uri is not None # for mypy self.assertIn("test_diagnostics.py", frame.location.uri) def test_diagnostics_records_cpp_call_stack(self): diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() stack = diagnostic.cpp_call_stack assert stack is not None # for mypy self.assertGreater(len(stack.frames), 0) frame_messages = [frame.location.message for frame in stack.frames] # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx) # after node-level shape type inference and processed symbolic_fn output type self.assertTrue( any( isinstance(message, str) and "torch::jit::NodeToONNX" in message for message in frame_messages ) ) 
@common_utils.instantiate_parametrized_tests
class TestDiagnosticsInfra(common_utils.TestCase):
    """Test cases for diagnostics infra."""

    def setUp(self):
        self.rules = _RuleCollectionForTest()
        with contextlib.ExitStack() as stack:
            self.context: infra.DiagnosticContext[infra.Diagnostic] = (
                stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
            )
            # Move the entered context out of the `with` block so it stays open
            # for the test and is closed by the test-case cleanup instead.
            self.addCleanup(stack.pop_all().close)
        return super().setUp()

    def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
        custom_rules = infra.RuleCollection.custom_collection_from_list(
            "CustomRuleCollection",
            [
                infra.Rule(
                    "1",
                    "custom-rule",
                    message_default_template="custom rule message",
                ),
                infra.Rule(
                    "2",
                    "custom-rule-2",
                    message_default_template="custom rule message 2",
                ),
            ],
        )

        with assert_all_diagnostics(
            self,
            self.context,
            {
                (custom_rules.custom_rule, infra.Level.WARNING),  # type: ignore[attr-defined]
                (custom_rules.custom_rule_2, infra.Level.ERROR),  # type: ignore[attr-defined]
            },
        ):
            diagnostic1 = infra.Diagnostic(
                custom_rules.custom_rule,  # type: ignore[attr-defined]
                infra.Level.WARNING,
            )
            self.context.log(diagnostic1)

            diagnostic2 = infra.Diagnostic(
                custom_rules.custom_rule_2,  # type: ignore[attr-defined]
                infra.Level.ERROR,
            )
            self.context.log(diagnostic2)

    def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                diagnostic.log(logging.DEBUG, "debug message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.INFO, "info message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)
            self.assertFalse(
                any(
                    "debug message" in recorded_message
                    for recorded_message in diagnostic.additional_messages
                )
            )

    def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            level_message_pairs = [
                (logging.INFO, "info message"),
                (logging.WARNING, "warning message"),
                (logging.ERROR, "error message"),
            ]

            for level, message in level_message_pairs:
                with self.assertLogs(diagnostic.logger, level=verbosity_level):
                    diagnostic.log(level, message)

                # The generator variable must not reuse the name `message`:
                # doing so compares each recorded entry with itself, which is
                # always true and hides a missing message.
                self.assertTrue(
                    any(
                        message in recorded_message
                        for recorded_message in diagnostic.additional_messages
                    )
                )

    @common_utils.parametrize(
        "log_api, log_level",
        [
            ("debug", logging.DEBUG),
            ("info", logging.INFO),
            ("warning", logging.WARNING),
            ("error", logging.ERROR),
        ],
    )
    def test_diagnostic_log_is_emitted_according_to_api_level_and_diagnostic_options_verbosity_level(
        self, log_api: str, log_level: int
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            message = "log message"
            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                getattr(diagnostic, log_api)(message)
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.ERROR, "dummy message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)

            if log_level >= verbosity_level:
                self.assertIn(message, diagnostic.additional_messages)
            else:
                self.assertNotIn(message, diagnostic.additional_messages)

    def test_diagnostic_log_lazy_string_is_not_evaluated_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should NOT be evaluated.
            diagnostic.debug("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                0,
                "expensive_formatting_function should not be evaluated after being wrapped under LazyString",
            )

    def test_diagnostic_log_lazy_string_is_evaluated_once_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should be evaluated exactly once.
            diagnostic.info("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                1,
                "expensive_formatting_function should only be evaluated once after being wrapped under LazyString",
            )

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_diagnostic_nested_log_section_emits_messages_with_correct_section_title_indentation(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with diagnostic.log_section(logging.INFO, "My Section"):
                diagnostic.log(logging.INFO, "My Message")
                with diagnostic.log_section(logging.INFO, "My Subsection"):
                    diagnostic.log(logging.INFO, "My Submessage")

            with diagnostic.log_section(logging.INFO, "My Section 2"):
                diagnostic.log(logging.INFO, "My Message 2")

            # The nested section title is expected one heading level deeper.
            self.assertIn("## My Section", diagnostic.additional_messages)
            self.assertIn("### My Subsection", diagnostic.additional_messages)
            self.assertIn("## My Section 2", diagnostic.additional_messages)

    def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_message(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NOTE
                )
                diagnostic.log_source_exception(logging.ERROR, e)

            diagnostic_message = "\n".join(diagnostic.additional_messages)

            self.assertIn("ValueError: original exception", diagnostic_message)
            self.assertIn("Traceback (most recent call last):", diagnostic_message)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.context:
            with self.assertRaises(TypeError):
                # The method expects 'Diagnostic' or its subclasses as arguments.
                # Passing any other type will trigger a TypeError.
                self.context.log("This is a str message.")

    def test_diagnostic_context_raises_if_diagnostic_is_error(self):
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
            )

    def test_diagnostic_context_raises_original_exception_from_diagnostic_created_from_it(
        self,
    ):
        with self.assertRaises(ValueError):
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
                diagnostic.log_source_exception(logging.ERROR, e)
                self.context.log_and_raise_if_error(diagnostic)

    def test_diagnostic_context_raises_if_diagnostic_is_warning_and_warnings_as_errors_is_true(
        self,
    ):
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.options.warnings_as_errors = True
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.WARNING
                )
            )


if __name__ == "__main__":
    common_utils.run_tests()