"""Test utilities for ONNX export."""

from __future__ import annotations


__all__ = ["assert_onnx_program"]

from typing import Any, TYPE_CHECKING

import torch
from torch.utils import _pytree


if TYPE_CHECKING:
    from torch.onnx._internal.exporter import _onnx_program


def assert_onnx_program(
    program: _onnx_program.ONNXProgram,
    *,
    rtol: float | None = None,
    atol: float | None = None,
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> None:
    """Check that an ONNX program agrees numerically with its ExportedProgram.

    The module obtained from ``program.exported_program`` and the ONNX program
    itself are both called with the same inputs; the flattened PyTorch outputs
    and the ONNX outputs are then compared with
    ``torch.testing.assert_close``. NaNs in matching positions are treated as
    equal and the devices of the outputs are not compared.

    Args:
        program: The ``ONNXProgram`` to verify. Its ``exported_program``
            attribute must be set.
        rtol: Relative tolerance, forwarded to ``torch.testing.assert_close``.
        atol: Absolute tolerance, forwarded to ``torch.testing.assert_close``.
        args: Positional inputs for both programs. When ``args`` and ``kwargs``
            are both None, the example inputs stored on the ExportedProgram
            are used instead; when only ``args`` is None, no positional
            inputs are passed.
        kwargs: Keyword inputs for both programs. When ``args`` and ``kwargs``
            are both None, the example inputs stored on the ExportedProgram
            are used instead; when only ``kwargs`` is None, no keyword inputs
            are passed.

    Raises:
        ValueError: If ``program.exported_program`` is None, or if no inputs
            were given and the ExportedProgram carries no example inputs.
    """
    reference_program = program.exported_program
    if reference_program is None:
        raise ValueError(
            "The ONNXProgram does not contain an ExportedProgram. "
            "To verify the ONNX program, initialize ONNXProgram with an ExportedProgram, "
            "or assign the ExportedProgram to the ONNXProgram.exported_program attribute."
        )

    caller_gave_no_inputs = args is None and kwargs is None
    if caller_gave_no_inputs:
        # Nothing was supplied by the caller: fall back to the inputs that
        # were recorded on the ExportedProgram.
        recorded_inputs = reference_program.example_inputs
        if recorded_inputs is None:
            raise ValueError(
                "No example inputs provided and the exported_program does not contain example inputs. "
                "Please provide arguments to verify the ONNX program."
            )
        args, kwargs = recorded_inputs

    # Either half may still be None (only one was supplied, or the recorded
    # inputs left one unset); substitute an empty container for it.
    positional = () if args is None else args
    keyword = {} if kwargs is None else kwargs

    # Run the PyTorch reference first, then the ONNX program, on identical inputs.
    reference_module = reference_program.module()
    reference_result = reference_module(*positional, **keyword)
    expected_leaves = _pytree.tree_flatten(reference_result)[0]
    actual_outputs = program(*positional, **keyword)

    # TODO(justinchuby): Include output names in the error message
    torch.testing.assert_close(
        tuple(actual_outputs),
        tuple(expected_leaves),
        rtol=rtol,
        atol=atol,
        equal_nan=True,
        check_device=False,
    )