xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/exporter/_testing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Test utilities for ONNX export."""
2
3from __future__ import annotations
4
5
6__all__ = ["assert_onnx_program"]
7
8from typing import Any, TYPE_CHECKING
9
10import torch
11from torch.utils import _pytree
12
13
14if TYPE_CHECKING:
15    from torch.onnx._internal.exporter import _onnx_program
16
17
def assert_onnx_program(
    program: _onnx_program.ONNXProgram,
    *,
    rtol: float | None = None,
    atol: float | None = None,
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> None:
    """Check that an ONNX program agrees numerically with its ExportedProgram.

    The module obtained from ``program.exported_program`` and the ONNX program
    itself are both called with the same inputs. The flattened PyTorch outputs
    and the ONNX outputs are then compared with ``torch.testing.assert_close``,
    treating NaNs as equal and ignoring device placement.

    Args:
        program: The ``ONNXProgram`` to verify.
        rtol: Relative tolerance forwarded to ``torch.testing.assert_close``.
        atol: Absolute tolerance forwarded to ``torch.testing.assert_close``.
        args: Positional inputs for both programs. When ``args`` and ``kwargs``
            are both None, the example inputs stored on the ExportedProgram
            are used instead.
        kwargs: Keyword inputs for both programs. When ``args`` and ``kwargs``
            are both None, the example inputs stored on the ExportedProgram
            are used instead.

    Raises:
        ValueError: If ``program.exported_program`` is None, or if no inputs
            were passed and the ExportedProgram has no example inputs.
        AssertionError: Raised by ``torch.testing.assert_close`` when the
            outputs do not match.
    """
    reference = program.exported_program
    if reference is None:
        raise ValueError(
            "The ONNXProgram does not contain an ExportedProgram. "
            "To verify the ONNX program, initialize ONNXProgram with an ExportedProgram, "
            "or assign the ExportedProgram to the ONNXProgram.exported_program attribute."
        )

    caller_gave_inputs = args is not None or kwargs is not None
    if not caller_gave_inputs:
        # Nothing was passed in: fall back to the inputs recorded at export time.
        recorded_inputs = reference.example_inputs
        if recorded_inputs is None:
            raise ValueError(
                "No example inputs provided and the exported_program does not contain example inputs. "
                "Please provide arguments to verify the ONNX program."
            )
        args, kwargs = recorded_inputs

    # Either half of the inputs may still be missing; normalize to empty containers.
    positional = () if args is None else args
    keyword = {} if kwargs is None else kwargs

    # Run the PyTorch reference first, then the ONNX program, on identical inputs.
    reference_result = reference.module()(*positional, **keyword)
    expected, _ = _pytree.tree_flatten(reference_result)
    actual = program(*positional, **keyword)

    # TODO(justinchuby): Include output names in the error message
    actual_outputs = tuple(actual)
    expected_outputs = tuple(expected)
    torch.testing.assert_close(
        actual_outputs,
        expected_outputs,
        rtol=rtol,
        atol=atol,
        equal_nan=True,
        check_device=False,
    )
67