xref: /aosp_15_r20/external/pytorch/test/onnx/internal/test_diagnostics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2from __future__ import annotations
3
4import contextlib
5import dataclasses
6import io
7import logging
8import typing
9from typing import AbstractSet, Protocol, Tuple
10
11import torch
12from torch.onnx import errors
13from torch.onnx._internal import diagnostics
14from torch.onnx._internal.diagnostics import infra
15from torch.onnx._internal.diagnostics.infra import formatter, sarif
16from torch.onnx._internal.fx import diagnostics as fx_diagnostics
17from torch.testing._internal import common_utils, logging_utils
18
19
20if typing.TYPE_CHECKING:
21    import unittest
22
23
class _SarifLogBuilder(Protocol):
    """Structural type for objects that can produce a SARIF log.

    In this file both ``diagnostics.engine`` and ``infra.DiagnosticContext``
    instances are passed wherever this protocol is expected.
    """

    def sarif_log(self) -> sarif.SarifLog:
        """Return the SARIF log whose runs and results are inspected by the tests."""
26
27
def _assert_has_diagnostics(
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Raise ``AssertionError`` unless every expected (rule, level) pair was reported.

    Each expected pair is turned into ``(rule.id, lowercased level name)`` and
    matched against the ``(rule_id, level)`` of every result in every run of the
    SARIF log. Runs whose ``results`` is ``None`` are skipped.

    Args:
        sarif_log_builder: Object whose ``sarif_log()`` produces the log to inspect.
        rule_level_pairs: The (rule, level) pairs that must all be present.

    Raises:
        AssertionError: If any expected pair is missing. The message lists the
            missing pairs and every (rule_id, level) pair actually found.
    """
    log = sarif_log_builder.sarif_log()
    missing = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
    found = [
        (result.rule_id, result.level)
        for run in log.runs
        if run.results is not None
        for result in run.results
    ]
    # Discard every found pair; whatever is left was expected but never reported.
    missing.difference_update(found)
    if not missing:
        return
    raise AssertionError(
        f"Expected diagnostic results of rule id and level pair {missing} not found. "
        f"Actual diagnostic results: {found}"
    )
48
49
@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
    """Rule collection holding the single rule used throughout these tests."""

    # Rule "1". Its default message template ("rule message") has no
    # placeholders, so diagnostics built from it are created without message args.
    rule_without_message_args: infra.Rule = infra.Rule(
        "1",
        "rule-without-message-args",
        message_default_template="rule message",
    )
59
60
@contextlib.contextmanager
def assert_all_diagnostics(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Context manager to assert that all diagnostics are emitted.

    Usage:
        with assert_all_diagnostics(
            self,
            diagnostics.engine,
            {(rule, infra.Level.ERROR)},
        ):
            torch.onnx.export(...)

    An ``errors.OnnxExporterError`` raised inside the ``with`` body is swallowed
    only when ``infra.Level.ERROR`` is one of the expected levels; otherwise the
    test fails. On every exit path (normal or exceptional) the SARIF log produced
    by ``sarif_log_builder`` is checked for all expected (rule, level) pairs.

    Args:
        test_suite: The test suite instance, used for its assertion methods.
        sarif_log_builder: The SARIF log builder whose log is inspected on exit.
        rule_level_pairs: A set of rule and level pairs to assert.

    Yields:
        None. The body of the ``with`` statement runs at the ``yield``.

    Raises:
        AssertionError: If not all diagnostics are emitted, or if the body raised
            ``errors.OnnxExporterError`` while no ERROR-level diagnostic was expected.
    """

    try:
        yield
    except errors.OnnxExporterError:
        # An exporter error is acceptable only when the caller expects an
        # ERROR-level diagnostic to accompany it.
        test_suite.assertIn(infra.Level.ERROR, {level for _, level in rule_level_pairs})
    finally:
        # Runs on every exit path so missing diagnostics are always reported.
        _assert_has_diagnostics(sarif_log_builder, rule_level_pairs)
95
96
def assert_diagnostic(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule: infra.Rule,
    level: infra.Level,
):
    """Context manager to assert that a diagnostic is emitted.

    Shorthand for ``assert_all_diagnostics`` with the single pair ``(rule, level)``.

    Usage:
        with assert_diagnostic(
            self,
            diagnostics.engine,
            rule,
            infra.Level.ERROR,
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule: The rule to assert.
        level: The level to assert.

    Returns:
        The context manager returned by ``assert_all_diagnostics``.

    Raises:
        AssertionError: On exit from the ``with`` block, if the diagnostic is not
            emitted, or if ``errors.OnnxExporterError`` escaped the body while
            ``level`` is not ``infra.Level.ERROR``.
    """

    return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)})
128
129
class TestDynamoOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the Dynamo ONNX export code."""

    def setUp(self):
        self.diagnostic_context = fx_diagnostics.DiagnosticContext("dynamo_export", "")
        self.rules = _RuleCollectionForTest()
        return super().setUp()

    def test_log_is_recorded_in_sarif_additional_messages_according_to_diagnostic_options_verbosity_level(
        self,
    ):
        logging_levels = [
            logging.DEBUG,
            logging.INFO,
            logging.WARNING,
            logging.ERROR,
        ]
        for verbosity_level in logging_levels:
            self.diagnostic_context.options.verbosity_level = verbosity_level
            with self.diagnostic_context:
                diagnostic = fx_diagnostics.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NONE
                )
                for log_level in logging_levels:
                    # Re-take the baseline before every log call. A baseline captured
                    # once before this loop would let every assertGreater after the
                    # first recorded message pass even if nothing new was recorded.
                    additional_messages_count = len(diagnostic.additional_messages)
                    diagnostic.log(level=log_level, message="log message")
                    if log_level >= verbosity_level:
                        self.assertGreater(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should be recorded when log level is {log_level} "
                            f"and verbosity level is {verbosity_level}",
                        )
                    else:
                        self.assertEqual(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should not be recorded when log level is "
                            f"{log_level} and verbosity level is {verbosity_level}",
                        )

    def test_torch_logs_environment_variable_precedes_diagnostic_options_verbosity_level(
        self,
    ):
        # With verbosity ERROR alone a DEBUG message would not be recorded; enabling
        # the "onnx_diagnostics" log artifact must make it recorded anyway.
        self.diagnostic_context.options.verbosity_level = logging.ERROR
        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )
            additional_messages_count = len(diagnostic.additional_messages)
            diagnostic.debug("message")
            self.assertGreater(
                len(diagnostic.additional_messages), additional_messages_count
            )

    def test_log_is_not_emitted_to_terminal_when_log_artifact_is_not_enabled(self):
        self.diagnostic_context.options.verbosity_level = logging.INFO
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(
                diagnostic.logger, level=logging.INFO
            ) as assert_log_context:
                diagnostic.info("message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log we don't want is not emitted.
                diagnostic.logger.log(logging.ERROR, "dummy message")

            # Only the dummy record may reach the logger.
            self.assertEqual(len(assert_log_context.records), 1)

    def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self):
        self.diagnostic_context.options.verbosity_level = logging.INFO

        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(diagnostic.logger, level=logging.INFO):
                diagnostic.info("message")

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.diagnostic_context.options.verbosity_level = verbosity_level
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            # The LazyString argument is rendered through the "%s" format string.
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.diagnostic_context:
            # Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic
            # instead of base infra.Diagnostic.
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            with self.assertRaises(TypeError):
                self.diagnostic_context.log(diagnostic)
239
240
class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the TorchScript ONNX export code."""

    def setUp(self):
        # Clear the shared diagnostics engine before each test.
        diagnostics.engine.clear()
        self._sample_rule = diagnostics.rules.missing_custom_symbolic_function
        super().setUp()

    def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
        self,
    ) -> diagnostics.TorchScriptOnnxExportDiagnostic:
        """Export a model using a custom op and return the resulting warning.

        The model calls an autograd function whose symbolic emits the custom
        ``custom::CustomAdd`` op. The export should then record a
        ``node_missing_onnx_shape_inference`` WARNING, which is looked up in the
        most recent context of the diagnostics engine.

        Raises:
            AssertionError: If no such warning was recorded.
        """

        class CustomAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y

            @staticmethod
            def symbolic(g, x, y):
                return g.op("custom::CustomAdd", x, y)

        class M(torch.nn.Module):
            def forward(self, x):
                return CustomAdd.apply(x, x)

        wanted_rule = diagnostics.rules.node_missing_onnx_shape_inference
        # trigger warning for missing shape inference.
        torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())

        latest_context = diagnostics.engine.contexts[-1]
        found = next(
            (
                candidate
                for candidate in latest_context.diagnostics
                if candidate.rule == wanted_rule
                and candidate.level == diagnostics.levels.WARNING
            ),
            None,
        )
        if found is None:
            raise AssertionError("No diagnostic found.")
        return typing.cast(diagnostics.TorchScriptOnnxExportDiagnostic, found)

    def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
        # The body does nothing, so the expected warning can never be emitted.
        with self.assertRaises(AssertionError), assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.node_missing_onnx_shape_inference,
            diagnostics.levels.WARNING,
        ):
            pass

    def test_cpp_diagnose_emits_warning(self):
        expected_rule = diagnostics.rules.node_missing_onnx_shape_inference
        with assert_diagnostic(
            self, diagnostics.engine, expected_rule, diagnostics.levels.WARNING
        ):
            # trigger warning for missing shape inference.
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()

    def test_py_diagnose_emits_error(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x)

        expected_rule = diagnostics.rules.operator_supported_in_newer_opset_version
        with assert_diagnostic(
            self, diagnostics.engine, expected_rule, diagnostics.levels.ERROR
        ):
            # trigger error for operator unsupported until newer opset version.
            torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO(), opset_version=9)

    def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
        self,
    ):
        level = diagnostics.levels.ERROR
        with assert_diagnostic(self, diagnostics.engine, self._sample_rule, level):
            # Log straight into the export context without running an export.
            reported = infra.Diagnostic(self._sample_rule, level)
            diagnostics.export_context().log(reported)

    def test_diagnostics_records_python_call_stack(self):
        diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE)  # fmt: skip
        # Do not break the above line, otherwise it will not work with Python-3.8+
        call_stack = diagnostic.python_call_stack
        assert call_stack is not None  # for mypy
        self.assertGreater(len(call_stack.frames), 0)
        location = call_stack.frames[0].location
        assert location.snippet is not None  # for mypy
        self.assertIn("self._sample_rule", location.snippet)
        assert location.uri is not None  # for mypy
        self.assertIn("test_diagnostics.py", location.uri)

    def test_diagnostics_records_cpp_call_stack(self):
        diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
        call_stack = diagnostic.cpp_call_stack
        assert call_stack is not None  # for mypy
        self.assertGreater(len(call_stack.frames), 0)
        messages = [frame.location.message for frame in call_stack.frames]
        # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx)
        # after node-level shape type inference and processed symbolic_fn output type
        came_from_node_to_onnx = any(
            isinstance(text, str) and "torch::jit::NodeToONNX" in text
            for text in messages
        )
        self.assertTrue(came_from_node_to_onnx)
358        )
359
360
@common_utils.instantiate_parametrized_tests
class TestDiagnosticsInfra(common_utils.TestCase):
    """Test cases for diagnostics infra."""

    def setUp(self):
        self.rules = _RuleCollectionForTest()
        # Enter the context for the whole test. ExitStack.pop_all() hands the exit
        # callback to addCleanup, so the context is closed after the test instead
        # of at the end of this `with` statement.
        with contextlib.ExitStack() as stack:
            self.context: infra.DiagnosticContext[infra.Diagnostic] = (
                stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
            )
            self.addCleanup(stack.pop_all().close)
        return super().setUp()

    def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
        custom_rules = infra.RuleCollection.custom_collection_from_list(
            "CustomRuleCollection",
            [
                infra.Rule(
                    "1",
                    "custom-rule",
                    message_default_template="custom rule message",
                ),
                infra.Rule(
                    "2",
                    "custom-rule-2",
                    message_default_template="custom rule message 2",
                ),
            ],
        )

        with assert_all_diagnostics(
            self,
            self.context,
            {
                (custom_rules.custom_rule, infra.Level.WARNING),  # type: ignore[attr-defined]
                (custom_rules.custom_rule_2, infra.Level.ERROR),  # type: ignore[attr-defined]
            },
        ):
            diagnostic1 = infra.Diagnostic(
                custom_rules.custom_rule,  # type: ignore[attr-defined]
                infra.Level.WARNING,
            )
            self.context.log(diagnostic1)

            diagnostic2 = infra.Diagnostic(
                custom_rules.custom_rule_2,  # type: ignore[attr-defined]
                infra.Level.ERROR,
            )
            self.context.log(diagnostic2)

    def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                diagnostic.log(logging.DEBUG, "debug message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.INFO, "info message")

        for record in assert_log_context.records:
            self.assertGreaterEqual(record.levelno, logging.INFO)
        self.assertFalse(
            any(
                "debug message" in message
                for message in diagnostic.additional_messages
            )
        )

    def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            level_message_pairs = [
                (logging.INFO, "info message"),
                (logging.WARNING, "warning message"),
                (logging.ERROR, "error message"),
            ]

            for level, message in level_message_pairs:
                with self.assertLogs(diagnostic.logger, level=verbosity_level):
                    diagnostic.log(level, message)

            # Every message logged at or above the verbosity level must be recorded.
            # The recorded entry gets its own name: reusing `message` for it would
            # compare each recorded string with itself and always pass.
            for _, message in level_message_pairs:
                self.assertTrue(
                    any(
                        message in recorded
                        for recorded in diagnostic.additional_messages
                    ),
                    f"{message!r} not found in additional messages: "
                    f"{diagnostic.additional_messages}",
                )

    @common_utils.parametrize(
        "log_api, log_level",
        [
            ("debug", logging.DEBUG),
            ("info", logging.INFO),
            ("warning", logging.WARNING),
            ("error", logging.ERROR),
        ],
    )
    def test_diagnostic_log_is_emitted_according_to_api_level_and_diagnostic_options_verbosity_level(
        self, log_api: str, log_level: int
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            message = "log message"
            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                getattr(diagnostic, log_api)(message)
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.ERROR, "dummy message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)

            if log_level >= verbosity_level:
                self.assertIn(message, diagnostic.additional_messages)
            else:
                self.assertNotIn(message, diagnostic.additional_messages)

    def test_diagnostic_log_lazy_string_is_not_evaluated_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should NOT be evaluated.
            diagnostic.debug("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                0,
                "expensive_formatting_function should not be evaluated after being wrapped under LazyString",
            )

    def test_diagnostic_log_lazy_string_is_evaluated_once_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should be evaluated exactly once.
            diagnostic.info("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                1,
                "expensive_formatting_function should only be evaluated once after being wrapped under LazyString",
            )

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_diagnostic_nested_log_section_emits_messages_with_correct_section_title_indentation(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with diagnostic.log_section(logging.INFO, "My Section"):
                diagnostic.log(logging.INFO, "My Message")
                with diagnostic.log_section(logging.INFO, "My Subsection"):
                    diagnostic.log(logging.INFO, "My Submessage")

            with diagnostic.log_section(logging.INFO, "My Section 2"):
                diagnostic.log(logging.INFO, "My Message 2")

            # Nesting depth is reflected in the number of leading '#' characters.
            self.assertIn("## My Section", diagnostic.additional_messages)
            self.assertIn("### My Subsection", diagnostic.additional_messages)
            self.assertIn("## My Section 2", diagnostic.additional_messages)

    def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_message(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NOTE
                )
                diagnostic.log_source_exception(logging.ERROR, e)

            diagnostic_message = "\n".join(diagnostic.additional_messages)

            self.assertIn("ValueError: original exception", diagnostic_message)
            self.assertIn("Traceback (most recent call last):", diagnostic_message)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.context:
            with self.assertRaises(TypeError):
                # The method expects 'Diagnostic' or its subclasses as arguments.
                # Passing any other type will trigger a TypeError.
                self.context.log("This is a str message.")

    def test_diagnostic_context_raises_if_diagnostic_is_error(self):
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
            )

    def test_diagnostic_context_raises_original_exception_from_diagnostic_created_from_it(
        self,
    ):
        with self.assertRaises(ValueError):
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
                diagnostic.log_source_exception(logging.ERROR, e)
                self.context.log_and_raise_if_error(diagnostic)

    def test_diagnostic_context_raises_if_diagnostic_is_warning_and_warnings_as_errors_is_true(
        self,
    ):
        # Configure the option before entering assertRaises so only the call under
        # test is inside the expected-exception scope.
        self.context.options.warnings_as_errors = True
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.WARNING
                )
            )
650
651
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test utilities.
    common_utils.run_tests()
654