xref: /aosp_15_r20/external/pytorch/torch/onnx/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""Functions to export models into the ONNX IR format.
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard WorkerThese models can be loaded with the ONNX library and then
5*da0073e9SAndroid Build Coastguard Workerconverted to models which run on other deep learning frameworks.
6*da0073e9SAndroid Build Coastguard Worker"""
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport contextlib
11*da0073e9SAndroid Build Coastguard Workerimport copy
12*da0073e9SAndroid Build Coastguard Workerimport inspect
13*da0073e9SAndroid Build Coastguard Workerimport re
14*da0073e9SAndroid Build Coastguard Workerimport typing
15*da0073e9SAndroid Build Coastguard Workerimport warnings
16*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, cast, Collection, Mapping, Sequence
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Workerimport torch
19*da0073e9SAndroid Build Coastguard Workerimport torch._C._onnx as _C_onnx
20*da0073e9SAndroid Build Coastguard Workerimport torch.jit._trace
21*da0073e9SAndroid Build Coastguard Workerimport torch.serialization
22*da0073e9SAndroid Build Coastguard Workerfrom torch import _C
23*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx import (  # noqa: F401
24*da0073e9SAndroid Build Coastguard Worker    _constants,
25*da0073e9SAndroid Build Coastguard Worker    _deprecation,
26*da0073e9SAndroid Build Coastguard Worker    _exporter_states,
27*da0073e9SAndroid Build Coastguard Worker    errors,
28*da0073e9SAndroid Build Coastguard Worker    symbolic_helper,
29*da0073e9SAndroid Build Coastguard Worker)
30*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx._globals import GLOBALS
31*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker__all__ = [
35*da0073e9SAndroid Build Coastguard Worker    "is_in_onnx_export",
36*da0073e9SAndroid Build Coastguard Worker    "select_model_mode_for_export",
37*da0073e9SAndroid Build Coastguard Worker    "disable_apex_o2_state_dict_hook",
38*da0073e9SAndroid Build Coastguard Worker    "setup_onnx_logging",
39*da0073e9SAndroid Build Coastguard Worker    "exporter_context",
40*da0073e9SAndroid Build Coastguard Worker    "export",
41*da0073e9SAndroid Build Coastguard Worker    "model_signature",
42*da0073e9SAndroid Build Coastguard Worker    "warn_on_static_input_change",
43*da0073e9SAndroid Build Coastguard Worker    "unpack_quantized_tensor",
44*da0073e9SAndroid Build Coastguard Worker    "export_to_pretty_string",
45*da0073e9SAndroid Build Coastguard Worker    "unconvertible_ops",
46*da0073e9SAndroid Build Coastguard Worker    "register_custom_op_symbolic",
47*da0073e9SAndroid Build Coastguard Worker    "unregister_custom_op_symbolic",
48*da0073e9SAndroid Build Coastguard Worker]
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker
def is_in_onnx_export() -> bool:
    """Report whether an ONNX export is currently in progress.

    Returns:
        The current value of ``GLOBALS.in_onnx_export``.
    """
    export_in_progress = GLOBALS.in_onnx_export
    return export_in_progress
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
# The type check is skipped because IValue cannot be imported from torch._C.
58*da0073e9SAndroid Build Coastguard Worker_params_dict = {}  # type: ignore[var-annotated]
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
    r"""Context manager that switches ``model`` into ``mode`` for the with-block.

    On entry, if ``model`` has a ``training`` attribute, the exporter globals
    ``GLOBALS.export_training`` and ``GLOBALS.training_mode`` are updated and
    ``model.train(...)`` is called for the ``TRAINING`` and ``EVAL`` modes.
    On exit the model's original training flag is put back, unless ``mode`` is
    ``PRESERVE`` (in which case the model was never touched). The ``GLOBALS``
    values are not reset on exit.

    Args:
        model: Same type and meaning as ``model`` arg to :func:`export`.
        mode: Same type and meaning as ``training`` arg to :func:`export`.

    Raises:
        TypeError: If ``mode`` is not a ``torch.onnx.TrainingMode`` value.
    """
    if not isinstance(mode, _C_onnx.TrainingMode):
        raise TypeError(
            f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'."
        )

    # Default used when the model exposes no ``training`` attribute.
    originally_training: bool = False

    if hasattr(model, "training"):
        originally_training = model.training

        training_requested = mode == _C_onnx.TrainingMode.TRAINING
        training_preserved = (
            mode == _C_onnx.TrainingMode.PRESERVE and originally_training
        )

        if not (training_requested or training_preserved):
            GLOBALS.export_training = False
        else:
            GLOBALS.export_training = True
            # ONNX opset 12 has better support for training amenable models,
            # with updated versions of the dropout and batch_norm operators.
            if GLOBALS.export_onnx_opset_version < 12:
                warnings.warn(
                    "You are exporting the model in training mode with onnx opset "
                    f"version {GLOBALS.export_onnx_opset_version}. "
                    "Opset versions lower than opset 12 will not be able to export "
                    "nodes such as Dropout and BatchNorm correctly."
                )

        GLOBALS.training_mode = mode

        # PRESERVE falls through both branches: the model is left as it is.
        if training_requested:
            model.train(True)
        elif mode == _C_onnx.TrainingMode.EVAL:
            model.train(False)

    try:
        yield
    finally:
        if hasattr(model, "training"):
            if mode != _C_onnx.TrainingMode.PRESERVE:
                model.train(originally_training)
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction):
    """Context manager that removes Apex O2 state-dict hooks for the with-block.

    Every hook in a submodule's ``_state_dict_hooks`` whose type is named
    ``O2StateDictHook`` is removed on entry and put back on exit. For a
    :class:`torch.jit.ScriptFunction` nothing is done.

    Args:
        model: The module (or script function) about to be exported.
    """
    # A ScriptFunction is not walked for hooks; just run the with-block.
    if isinstance(model, torch.jit.ScriptFunction):
        yield
        return

    # Apex O2 hooks state_dict to return fp16 weights as fp32, and the exporter
    # cannot identify them as the same tensors. The hook is only used by the
    # optimizer, so it is safe to remove it while exporting.
    removed_hooks = {}  # type: ignore[var-annotated]
    for submodule in model.modules():
        o2_hooks = {
            hook_key: hook_fn
            for hook_key, hook_fn in submodule._state_dict_hooks.items()
            if type(hook_fn).__name__ == "O2StateDictHook"
        }
        if o2_hooks:
            removed_hooks[submodule] = o2_hooks
            for hook_key in o2_hooks:
                submodule._state_dict_hooks.pop(hook_key)

    try:
        yield
    finally:
        # Add the hooks back.
        for submodule, o2_hooks in removed_hooks.items():
            for hook_key, hook_fn in o2_hooks.items():
                submodule._state_dict_hooks[hook_key] = hook_fn
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def setup_onnx_logging(verbose: bool):
    """Context manager that enables ONNX logging for the with-block if asked.

    Logging is enabled on entry when ``verbose`` is true or when it was
    already enabled. On exit it is disabled again only if it was not enabled
    before entry.

    Args:
        verbose: Whether ONNX logging should be on inside the with-block.
    """
    was_enabled = torch.onnx.is_onnx_log_enabled()
    if was_enabled or verbose:
        torch.onnx.enable_log()
    try:
        yield
    finally:
        # Restore the disabled state only; logging that was on stays on.
        if not was_enabled:
            torch.onnx.disable_log()
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
    """Enter all export-time context managers as one combined context.

    The contexts are entered in this order and exited in reverse:
    model training mode, Apex O2 hook removal, ONNX logging, and the export
    diagnostic context.

    Args:
        model: The model being exported.
        mode: Training mode passed to :func:`select_model_mode_for_export`.
        verbose: Verbosity flag passed to :func:`setup_onnx_logging`.

    Yields:
        A 4-tuple of the values produced by the four context managers, in the
        order they were entered.
    """
    with select_model_mode_for_export(model, mode) as mode_ctx:
        with disable_apex_o2_state_dict_hook(model) as apex_ctx:
            with setup_onnx_logging(verbose) as log_ctx:
                with diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
                    yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Workerdef _get_torch_export_args(
165*da0073e9SAndroid Build Coastguard Worker    args: tuple[Any, ...],
166*da0073e9SAndroid Build Coastguard Worker    kwargs: dict[str, Any] | None,
167*da0073e9SAndroid Build Coastguard Worker) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
168*da0073e9SAndroid Build Coastguard Worker    """Obtain the arguments for torch.onnx.export from the model and the input arguments."""
169*da0073e9SAndroid Build Coastguard Worker    if not kwargs and args and isinstance(args[-1], dict):
170*da0073e9SAndroid Build Coastguard Worker        kwargs = args[-1]
171*da0073e9SAndroid Build Coastguard Worker        args = args[:-1]
172*da0073e9SAndroid Build Coastguard Worker    return args, kwargs
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Workerdef export(
176*da0073e9SAndroid Build Coastguard Worker    model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction,
177*da0073e9SAndroid Build Coastguard Worker    args: tuple[Any, ...] | torch.Tensor,
178*da0073e9SAndroid Build Coastguard Worker    f: str,
179*da0073e9SAndroid Build Coastguard Worker    *,
180*da0073e9SAndroid Build Coastguard Worker    kwargs: dict[str, Any] | None = None,
181*da0073e9SAndroid Build Coastguard Worker    export_params: bool = True,
182*da0073e9SAndroid Build Coastguard Worker    verbose: bool = False,
183*da0073e9SAndroid Build Coastguard Worker    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
184*da0073e9SAndroid Build Coastguard Worker    input_names: Sequence[str] | None = None,
185*da0073e9SAndroid Build Coastguard Worker    output_names: Sequence[str] | None = None,
186*da0073e9SAndroid Build Coastguard Worker    operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
187*da0073e9SAndroid Build Coastguard Worker    opset_version: int | None = None,
188*da0073e9SAndroid Build Coastguard Worker    do_constant_folding: bool = True,
189*da0073e9SAndroid Build Coastguard Worker    dynamic_axes: Mapping[str, Mapping[int, str]]
190*da0073e9SAndroid Build Coastguard Worker    | Mapping[str, Sequence[int]]
191*da0073e9SAndroid Build Coastguard Worker    | None = None,
192*da0073e9SAndroid Build Coastguard Worker    keep_initializers_as_inputs: bool | None = None,
193*da0073e9SAndroid Build Coastguard Worker    custom_opsets: Mapping[str, int] | None = None,
194*da0073e9SAndroid Build Coastguard Worker    export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
195*da0073e9SAndroid Build Coastguard Worker    autograd_inlining: bool = True,
196*da0073e9SAndroid Build Coastguard Worker) -> None:
197*da0073e9SAndroid Build Coastguard Worker    r"""Exports a model into ONNX format.
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker    If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
200*da0073e9SAndroid Build Coastguard Worker    :class:`torch.jit.ScriptFunction`, this runs
201*da0073e9SAndroid Build Coastguard Worker    ``model`` once in order to convert it to a TorchScript graph to be exported
202*da0073e9SAndroid Build Coastguard Worker    (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support
203*da0073e9SAndroid Build Coastguard Worker    for dynamic control flow as :func:`torch.jit.trace`.
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker    Args:
206*da0073e9SAndroid Build Coastguard Worker        model: The model to be exported.
207*da0073e9SAndroid Build Coastguard Worker        args:
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker            args can be structured either as:
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker            1. ONLY A TUPLE OF ARGUMENTS::
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker                args = (x, y, z)
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker            The tuple should contain model inputs such that ``model(*args)`` is a valid
216*da0073e9SAndroid Build Coastguard Worker            invocation of the model. Any non-Tensor arguments will be hard-coded into the
217*da0073e9SAndroid Build Coastguard Worker            exported model; any Tensor arguments will become inputs of the exported model,
218*da0073e9SAndroid Build Coastguard Worker            in the order they occur in the tuple.
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker            2. A TENSOR::
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker                args = torch.Tensor([1])
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker            This is equivalent to a 1-ary tuple of that Tensor.
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker            3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker                args = (x, {"y": input_y, "z": input_z})
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker            All but the last element of the tuple will be passed as non-keyword arguments,
231*da0073e9SAndroid Build Coastguard Worker            and named arguments will be set from the last element. If a named argument is
232*da0073e9SAndroid Build Coastguard Worker            not present in the dictionary, it is assigned the default value, or None if a
233*da0073e9SAndroid Build Coastguard Worker            default value is not provided.
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker            .. warning::
236*da0073e9SAndroid Build Coastguard Worker                This behavior will be deprecated in a future release. Please use the
237*da0073e9SAndroid Build Coastguard Worker                kwargs argument instead.
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker            .. note::
240*da0073e9SAndroid Build Coastguard Worker                If a dictionary is the last element of the args tuple, it will be
241*da0073e9SAndroid Build Coastguard Worker                interpreted as containing named arguments. In order to pass a dict as the
242*da0073e9SAndroid Build Coastguard Worker                last non-keyword arg, provide an empty dict as the last element of the args
243*da0073e9SAndroid Build Coastguard Worker                tuple. For example, instead of::
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker                    torch.onnx.export(
246*da0073e9SAndroid Build Coastguard Worker                        model,
247*da0073e9SAndroid Build Coastguard Worker                        (
248*da0073e9SAndroid Build Coastguard Worker                            x,
249*da0073e9SAndroid Build Coastguard Worker                            # WRONG: will be interpreted as named arguments
250*da0073e9SAndroid Build Coastguard Worker                            {y: z},
251*da0073e9SAndroid Build Coastguard Worker                        ),
252*da0073e9SAndroid Build Coastguard Worker                        "test.onnx.pb",
253*da0073e9SAndroid Build Coastguard Worker                    )
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker                Write::
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker                    torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb")
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        f: Path to the output ONNX model file. E.g. "model.onnx".
260*da0073e9SAndroid Build Coastguard Worker        kwargs: Named arguments to the model.
261*da0073e9SAndroid Build Coastguard Worker        export_params: If True, all parameters will
262*da0073e9SAndroid Build Coastguard Worker            be exported. Set this to False if you want to export an untrained model.
263*da0073e9SAndroid Build Coastguard Worker            In this case, the exported model will first take all of its parameters
264*da0073e9SAndroid Build Coastguard Worker            as arguments, with the ordering as specified by ``model.state_dict().values()``
265*da0073e9SAndroid Build Coastguard Worker        verbose: if True, prints a description of the
266*da0073e9SAndroid Build Coastguard Worker            model being exported to stdout. In addition, the final ONNX graph will include the
267*da0073e9SAndroid Build Coastguard Worker            field ``doc_string``` from the exported model which mentions the source code locations
268*da0073e9SAndroid Build Coastguard Worker            for ``model``. If True, ONNX exporter logging will be turned on.
269*da0073e9SAndroid Build Coastguard Worker        training:
270*da0073e9SAndroid Build Coastguard Worker            * ``TrainingMode.EVAL``: export the model in inference mode.
271*da0073e9SAndroid Build Coastguard Worker            * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
272*da0073e9SAndroid Build Coastguard Worker                False and in training mode if model.training is True.
273*da0073e9SAndroid Build Coastguard Worker            * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations
274*da0073e9SAndroid Build Coastguard Worker                which might interfere with training.
275*da0073e9SAndroid Build Coastguard Worker        input_names (list of str, default empty list): names to assign to the
276*da0073e9SAndroid Build Coastguard Worker            input nodes of the graph, in order.
277*da0073e9SAndroid Build Coastguard Worker        output_names (list of str, default empty list): names to assign to the
278*da0073e9SAndroid Build Coastguard Worker            output nodes of the graph, in order.
279*da0073e9SAndroid Build Coastguard Worker        operator_export_type (enum, default OperatorExportTypes.ONNX):
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker            .. warning::
282*da0073e9SAndroid Build Coastguard Worker                This option will be deprecated in a future release. Future exported
283*da0073e9SAndroid Build Coastguard Worker                graphs will always use the default opset domain.
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker            * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops
286*da0073e9SAndroid Build Coastguard Worker                (in the default opset domain).
287*da0073e9SAndroid Build Coastguard Worker            * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops
288*da0073e9SAndroid Build Coastguard Worker                to standard ONNX ops in the default opset domain. If unable to do so
289*da0073e9SAndroid Build Coastguard Worker                (e.g. because support has not been added to convert a particular torch op to ONNX),
290*da0073e9SAndroid Build Coastguard Worker                fall back to exporting the op into a custom opset domain without conversion. Applies
291*da0073e9SAndroid Build Coastguard Worker                to `custom ops <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
292*da0073e9SAndroid Build Coastguard Worker                as well as ATen ops. For the exported model to be usable, the runtime must support
293*da0073e9SAndroid Build Coastguard Worker                these non-standard ops.
294*da0073e9SAndroid Build Coastguard Worker            * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten")
295*da0073e9SAndroid Build Coastguard Worker                are exported as ATen ops (in opset domain "org.pytorch.aten").
296*da0073e9SAndroid Build Coastguard Worker                `ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library, so
297*da0073e9SAndroid Build Coastguard Worker                this instructs the runtime to use PyTorch's implementation of these ops.
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker                .. warning::
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker                    Models exported this way are probably runnable only by Caffe2.
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker                    This may be useful if the numeric differences in implementations of operators are
304*da0073e9SAndroid Build Coastguard Worker                    causing large differences in behavior between PyTorch and Caffe2 (which is more
305*da0073e9SAndroid Build Coastguard Worker                    common on untrained models).
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker            * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op
308*da0073e9SAndroid Build Coastguard Worker                (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so
309*da0073e9SAndroid Build Coastguard Worker                (e.g. because support has not been added to convert a particular torch op to ONNX),
310*da0073e9SAndroid Build Coastguard Worker                fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for
311*da0073e9SAndroid Build Coastguard Worker                context.
312*da0073e9SAndroid Build Coastguard Worker                For example::
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker                    graph(%0 : Float):
315*da0073e9SAndroid Build Coastguard Worker                    %3 : int = prim::Constant[value=0]()
316*da0073e9SAndroid Build Coastguard Worker                    # conversion unsupported
317*da0073e9SAndroid Build Coastguard Worker                    %4 : Float = aten::triu(%0, %3)
318*da0073e9SAndroid Build Coastguard Worker                    # conversion supported
319*da0073e9SAndroid Build Coastguard Worker                    %5 : Float = aten::mul(%4, %0)
320*da0073e9SAndroid Build Coastguard Worker                    return (%5)
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker                Assuming ``aten::triu`` is not supported in ONNX, this will be exported as::
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker                    graph(%0 : Float):
325*da0073e9SAndroid Build Coastguard Worker                    %1 : Long() = onnx::Constant[value={0}]()
326*da0073e9SAndroid Build Coastguard Worker                    # not converted
327*da0073e9SAndroid Build Coastguard Worker                    %2 : Float = aten::ATen[operator="triu"](%0, %1)
328*da0073e9SAndroid Build Coastguard Worker                    # converted
329*da0073e9SAndroid Build Coastguard Worker                    %3 : Float = onnx::Mul(%2, %0)
330*da0073e9SAndroid Build Coastguard Worker                    return (%3)
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker                .. warning::
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker                    Models exported this way are probably runnable only by Caffe2.
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker        opset_version (int, default 17): The version of the
337*da0073e9SAndroid Build Coastguard Worker            `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
338*da0073e9SAndroid Build Coastguard Worker            to target. Must be >= 7 and <= 17.
339*da0073e9SAndroid Build Coastguard Worker        do_constant_folding: Apply the constant-folding optimization.
340*da0073e9SAndroid Build Coastguard Worker            Constant-folding will replace some of the ops that have all constant inputs
341*da0073e9SAndroid Build Coastguard Worker            with pre-computed constant nodes.
342*da0073e9SAndroid Build Coastguard Worker        dynamic_axes:
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker            By default the exported model will have the shapes of all input and output tensors
345*da0073e9SAndroid Build Coastguard Worker            set to exactly match those given in ``args``. To specify axes of tensors as
346*da0073e9SAndroid Build Coastguard Worker            dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker            * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
349*da0073e9SAndroid Build Coastguard Worker                ``output_names``.
350*da0073e9SAndroid Build Coastguard Worker            * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
351*da0073e9SAndroid Build Coastguard Worker                list, each element is an axis index.
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker            For example::
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker                class SumModule(torch.nn.Module):
356*da0073e9SAndroid Build Coastguard Worker                    def forward(self, x):
357*da0073e9SAndroid Build Coastguard Worker                        return torch.sum(x, dim=1)
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker                torch.onnx.export(
361*da0073e9SAndroid Build Coastguard Worker                    SumModule(),
362*da0073e9SAndroid Build Coastguard Worker                    (torch.ones(2, 2),),
363*da0073e9SAndroid Build Coastguard Worker                    "onnx.pb",
364*da0073e9SAndroid Build Coastguard Worker                    input_names=["x"],
365*da0073e9SAndroid Build Coastguard Worker                    output_names=["sum"],
366*da0073e9SAndroid Build Coastguard Worker                )
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker            Produces::
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker                input {
371*da0073e9SAndroid Build Coastguard Worker                  name: "x"
372*da0073e9SAndroid Build Coastguard Worker                  ...
373*da0073e9SAndroid Build Coastguard Worker                      shape {
374*da0073e9SAndroid Build Coastguard Worker                        dim {
375*da0073e9SAndroid Build Coastguard Worker                          dim_value: 2  # axis 0
376*da0073e9SAndroid Build Coastguard Worker                        }
377*da0073e9SAndroid Build Coastguard Worker                        dim {
378*da0073e9SAndroid Build Coastguard Worker                          dim_value: 2  # axis 1
379*da0073e9SAndroid Build Coastguard Worker                ...
380*da0073e9SAndroid Build Coastguard Worker                output {
381*da0073e9SAndroid Build Coastguard Worker                  name: "sum"
382*da0073e9SAndroid Build Coastguard Worker                  ...
383*da0073e9SAndroid Build Coastguard Worker                      shape {
384*da0073e9SAndroid Build Coastguard Worker                        dim {
385*da0073e9SAndroid Build Coastguard Worker                          dim_value: 2  # axis 0
386*da0073e9SAndroid Build Coastguard Worker                ...
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker            While::
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker                torch.onnx.export(
391*da0073e9SAndroid Build Coastguard Worker                    SumModule(),
392*da0073e9SAndroid Build Coastguard Worker                    (torch.ones(2, 2),),
393*da0073e9SAndroid Build Coastguard Worker                    "onnx.pb",
394*da0073e9SAndroid Build Coastguard Worker                    input_names=["x"],
395*da0073e9SAndroid Build Coastguard Worker                    output_names=["sum"],
396*da0073e9SAndroid Build Coastguard Worker                    dynamic_axes={
397*da0073e9SAndroid Build Coastguard Worker                        # dict value: manually named axes
398*da0073e9SAndroid Build Coastguard Worker                        "x": {0: "my_custom_axis_name"},
399*da0073e9SAndroid Build Coastguard Worker                        # list value: automatic names
400*da0073e9SAndroid Build Coastguard Worker                        "sum": [0],
401*da0073e9SAndroid Build Coastguard Worker                    },
402*da0073e9SAndroid Build Coastguard Worker                )
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker            Produces::
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker                input {
407*da0073e9SAndroid Build Coastguard Worker                  name: "x"
408*da0073e9SAndroid Build Coastguard Worker                  ...
409*da0073e9SAndroid Build Coastguard Worker                      shape {
410*da0073e9SAndroid Build Coastguard Worker                        dim {
411*da0073e9SAndroid Build Coastguard Worker                          dim_param: "my_custom_axis_name"  # axis 0
412*da0073e9SAndroid Build Coastguard Worker                        }
413*da0073e9SAndroid Build Coastguard Worker                        dim {
414*da0073e9SAndroid Build Coastguard Worker                          dim_value: 2  # axis 1
415*da0073e9SAndroid Build Coastguard Worker                ...
416*da0073e9SAndroid Build Coastguard Worker                output {
417*da0073e9SAndroid Build Coastguard Worker                  name: "sum"
418*da0073e9SAndroid Build Coastguard Worker                  ...
419*da0073e9SAndroid Build Coastguard Worker                      shape {
420*da0073e9SAndroid Build Coastguard Worker                        dim {
421*da0073e9SAndroid Build Coastguard Worker                          dim_param: "sum_dynamic_axes_1"  # axis 0
422*da0073e9SAndroid Build Coastguard Worker                ...
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker        keep_initializers_as_inputs: If True, all the
425*da0073e9SAndroid Build Coastguard Worker            initializers (typically corresponding to parameters) in the
426*da0073e9SAndroid Build Coastguard Worker            exported graph will also be added as inputs to the graph. If False,
427*da0073e9SAndroid Build Coastguard Worker            then initializers are not added as inputs to the graph, and only
428*da0073e9SAndroid Build Coastguard Worker            the non-parameter inputs are added as inputs.
429*da0073e9SAndroid Build Coastguard Worker            This may allow for better optimizations (e.g. constant folding) by
430*da0073e9SAndroid Build Coastguard Worker            backends/runtimes.
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker            If True, `deduplicate_initializers` pass will not be executed. This means
433*da0073e9SAndroid Build Coastguard Worker            initializers with duplicated values will not be deduplicated and
434*da0073e9SAndroid Build Coastguard Worker            will be treated as distinct inputs to the graph. This allows different
435*da0073e9SAndroid Build Coastguard Worker            input initializers to be supplied at the runtime following export.
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker            If ``opset_version < 9``, initializers MUST be part of graph
438*da0073e9SAndroid Build Coastguard Worker            inputs and this argument will be ignored and the behavior will be
439*da0073e9SAndroid Build Coastguard Worker            equivalent to setting this argument to True.
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker        custom_opsets (dict[str, int], default empty dict): A dict with schema:
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker            * KEY (str): opset domain name
444*da0073e9SAndroid Build Coastguard Worker            * VALUE (int): opset version
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker            If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
447*da0073e9SAndroid Build Coastguard Worker            the opset version is set to 1. Only custom opset domain name and version should be
448*da0073e9SAndroid Build Coastguard Worker            indicated through this argument.
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker        export_modules_as_functions: Flag to enable
451*da0073e9SAndroid Build Coastguard Worker            exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
452*da0073e9SAndroid Build Coastguard Worker            particular types of modules to export as local functions in ONNX.
453*da0073e9SAndroid Build Coastguard Worker            This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
454*da0073e9SAndroid Build Coastguard Worker            ``opset_version`` < 15 implies IR version < 8, which means no local function support.
455*da0073e9SAndroid Build Coastguard Worker            Module variables will be exported as function attributes. There are two categories of function
456*da0073e9SAndroid Build Coastguard Worker            attributes.
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker            1. Annotated attributes: class variables that have type annotations via
459*da0073e9SAndroid Build Coastguard Worker            `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
460*da0073e9SAndroid Build Coastguard Worker            will be exported as attributes.
461*da0073e9SAndroid Build Coastguard Worker            Annotated attributes are not used inside the subgraph of ONNX local function because
462*da0073e9SAndroid Build Coastguard Worker            they are not created by PyTorch JIT tracing, but they may be used by consumers
463*da0073e9SAndroid Build Coastguard Worker            to determine whether or not to replace the function with a particular fused kernel.
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker            2. Inferred attributes: variables that are used by operators inside the module. Attribute names
466*da0073e9SAndroid Build Coastguard Worker            will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
467*da0073e9SAndroid Build Coastguard Worker            python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker            * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
470*da0073e9SAndroid Build Coastguard Worker            * ``True``: export all ``nn.Module`` forward calls as local function nodes.
471*da0073e9SAndroid Build Coastguard Worker            * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
472*da0073e9SAndroid Build Coastguard Worker                only if the type of the ``nn.Module`` is found in the set.
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker        autograd_inlining: Flag used to control whether to inline autograd functions.
475*da0073e9SAndroid Build Coastguard Worker            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker    Raises:
478*da0073e9SAndroid Build Coastguard Worker        :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
479*da0073e9SAndroid Build Coastguard Worker        :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
480*da0073e9SAndroid Build Coastguard Worker            uses an operator that is not supported by the exporter.
481*da0073e9SAndroid Build Coastguard Worker        :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export.
482*da0073e9SAndroid Build Coastguard Worker            All errors are subclasses of :class:`errors.OnnxExporterError`.
483*da0073e9SAndroid Build Coastguard Worker    """
484*da0073e9SAndroid Build Coastguard Worker    if operator_export_type != _C_onnx.OperatorExportTypes.ONNX:
485*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
486*da0073e9SAndroid Build Coastguard Worker            "Setting `operator_export_type` to something other than default is deprecated. "
487*da0073e9SAndroid Build Coastguard Worker            "The option will be removed in a future release.",
488*da0073e9SAndroid Build Coastguard Worker            category=FutureWarning,
489*da0073e9SAndroid Build Coastguard Worker        )
490*da0073e9SAndroid Build Coastguard Worker    if training == _C_onnx.TrainingMode.TRAINING:
491*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
492*da0073e9SAndroid Build Coastguard Worker            "Setting `training` to something other than default is deprecated. "
493*da0073e9SAndroid Build Coastguard Worker            "The option will be removed in a future release. Please set the training mode "
494*da0073e9SAndroid Build Coastguard Worker            "before exporting the model.",
495*da0073e9SAndroid Build Coastguard Worker            category=FutureWarning,
496*da0073e9SAndroid Build Coastguard Worker        )
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker    args = (args,) if isinstance(args, torch.Tensor) else args
499*da0073e9SAndroid Build Coastguard Worker    if kwargs is not None:
500*da0073e9SAndroid Build Coastguard Worker        args = args + (kwargs,)
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker    _export(
503*da0073e9SAndroid Build Coastguard Worker        model,
504*da0073e9SAndroid Build Coastguard Worker        args,
505*da0073e9SAndroid Build Coastguard Worker        f,
506*da0073e9SAndroid Build Coastguard Worker        export_params,
507*da0073e9SAndroid Build Coastguard Worker        verbose,
508*da0073e9SAndroid Build Coastguard Worker        training,
509*da0073e9SAndroid Build Coastguard Worker        input_names,
510*da0073e9SAndroid Build Coastguard Worker        output_names,
511*da0073e9SAndroid Build Coastguard Worker        operator_export_type=operator_export_type,
512*da0073e9SAndroid Build Coastguard Worker        opset_version=opset_version,
513*da0073e9SAndroid Build Coastguard Worker        do_constant_folding=do_constant_folding,
514*da0073e9SAndroid Build Coastguard Worker        dynamic_axes=dynamic_axes,
515*da0073e9SAndroid Build Coastguard Worker        keep_initializers_as_inputs=keep_initializers_as_inputs,
516*da0073e9SAndroid Build Coastguard Worker        custom_opsets=custom_opsets,
517*da0073e9SAndroid Build Coastguard Worker        export_modules_as_functions=export_modules_as_functions,
518*da0073e9SAndroid Build Coastguard Worker        autograd_inlining=autograd_inlining,
519*da0073e9SAndroid Build Coastguard Worker    )
520*da0073e9SAndroid Build Coastguard Worker
521*da0073e9SAndroid Build Coastguard Worker    return None
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Workerdef _is_constant_tensor_list(node):
525*da0073e9SAndroid Build Coastguard Worker    if node.kind() != "prim::Constant":
526*da0073e9SAndroid Build Coastguard Worker        return False
527*da0073e9SAndroid Build Coastguard Worker    output_type = node.output().type()
528*da0073e9SAndroid Build Coastguard Worker    if output_type.isSubtypeOf(_C.ListType.ofTensors()):
529*da0073e9SAndroid Build Coastguard Worker        return True
530*da0073e9SAndroid Build Coastguard Worker    if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())):
531*da0073e9SAndroid Build Coastguard Worker        return True
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker
534*da0073e9SAndroid Build Coastguard Worker# ONNX can't handle constants that are lists of tensors, which can
535*da0073e9SAndroid Build Coastguard Worker# get generated in constant prop. So we split them back into prim::ListConstructs
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker
def _split_tensor_list_constants(g, block):
    """Replace constant tensor lists in ``block`` with ``prim::ListConstruct`` nodes.

    ONNX can't handle constants that are lists of tensors, which can get
    generated by constant propagation. Each such constant is split into one
    constant per element plus a ``prim::ListConstruct`` that gathers them, and
    all uses of the original constant's output are redirected to the new list.

    Args:
        g: The graph used to create and insert the replacement nodes.
        block: The block (or graph) whose nodes are scanned; nested blocks are
            processed recursively.
    """
    for node in block.nodes():
        # Recurse into nested blocks before looking at the node itself.
        for inner_block in node.blocks():
            _split_tensor_list_constants(g, inner_block)

        if not _is_constant_tensor_list(node):
            continue

        elements = []
        for item in node.output().toIValue():
            element = g.insertConstant(item)
            element_node = element.node()
            # insertConstant places the node at the graph's insertion point;
            # relocate it so it sits right before the constant being replaced.
            element_node.moveBefore(node)
            element_node.copyMetadata(node)
            elements.append(element)

        list_node = g.create("prim::ListConstruct", elements).insertBefore(node)
        list_value = list_node.output().setType(_C.ListType.ofTensors())
        list_value.node().copyMetadata(node)
        node.output().replaceAllUsesWith(list_value)
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker
def _optimize_graph(
    graph: _C.Graph,
    operator_export_type: _C_onnx.OperatorExportTypes,
    _disable_torch_constant_prop: bool = False,
    fixed_batch_size: bool = False,
    params_dict=None,
    dynamic_axes=None,
    input_names=None,
    module=None,
):
    """Run the sequence of JIT and ONNX passes that converts ``graph`` to ONNX.

    The passes are order-dependent (see the inline comments), and lint passes
    are interleaved between them. As a side effect,
    ``symbolic_helper._quantized_ops`` is cleared before quantized weights are
    unpacked.

    Args:
        graph: The JIT graph to convert. Every pass before ``_C._jit_pass_onnx``
            is invoked on this object.
        operator_export_type: Forwarded to ``_C._jit_pass_onnx``.
        _disable_torch_constant_prop: Constant propagation runs only when this
            is exactly ``False``.
        fixed_batch_size: Forwarded to ``_C._jit_pass_onnx_peephole``.
        params_dict: Forwarded to the quantized-weight unpacking pass and to the
            shape/type inference pass. ``None`` is replaced by a new empty dict.
        dynamic_axes: Only used when ``GLOBALS.onnx_shape_inference`` is truthy,
            to set dynamic input shapes. ``None`` is treated as an empty dict.
        input_names: Used together with ``dynamic_axes``. ``None`` is treated as
            an empty list.
        module: Forwarded to ``_C._jit_pass_onnx_remove_inplace_ops_for_onnx``.

    Returns:
        The graph returned by ``_C._jit_pass_onnx`` after the post-conversion
        passes and ``_C._jit_pass_canonicalize`` have been applied.
    """
    if params_dict is None:
        params_dict = {}

    # Inline everything
    _C._jit_pass_inline(graph)

    # Remove fork/wait nodes
    _C._jit_pass_inline_fork_wait(graph)
    _C._jit_pass_lint(graph)
    if GLOBALS.autograd_inlining:
        _C._jit_pass_onnx_autograd_function_process(graph)
    _C._jit_pass_lower_all_tuples(graph)

    # we now record some ops like ones/zeros
    # into a trace where we previously recorded constants.
    # use constant prop to maintain our current level of onnx support
    # without implementing symbolics for all of them
    if _disable_torch_constant_prop is False:
        _C._jit_pass_constant_propagation(graph)

    # Constant propagation can produce constant tensor lists, which ONNX can't
    # handle; split them back into prim::ListConstruct nodes.
    _split_tensor_list_constants(graph, graph)
    # run dce to eliminate dead parts of the graph that might have been
    # left behind by things like symbolic_override
    _C._jit_pass_dce(graph)
    _C._jit_pass_lint(graph)

    # CSE should improve perf when Autocast is used with disabled cache
    # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092
    # Must run before _C._jit_pass_erase_number_types to prevent type substitution
    if _C._jit_pass_cse(graph):
        _C._jit_pass_onnx_lint(graph)

    _C._jit_pass_canonicalize_graph_fuser_ops(graph)
    _C._jit_pass_lint(graph)
    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_fuse_addmm(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_peephole(graph, True)
    # onnx does not support tuples, so try to remove them
    _C._jit_pass_lower_all_tuples(graph)
    # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
    # However, there are nodes that cannot be converted without additional context.
    # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
    # until the point where it is unpacked by listUnpack node.
    # This pass does a preprocess, and prepares the nodes such that enough context can be received
    # by the symbolic function.
    _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
    _C._jit_pass_onnx_preprocess(graph)

    _C._jit_pass_lint(graph)

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    _C._jit_pass_prepare_division_for_onnx(graph)

    _C._jit_pass_onnx_remove_print(graph)
    _C._jit_pass_onnx_preprocess_caffe2(graph)

    # Reset the module-level record of quantized ops before unpacking weights.
    symbolic_helper._quantized_ops.clear()
    # Unpack quantized weights for conv and linear ops and insert into graph.
    _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict)
    # onnx only supports tensors, so we turn all our number types into tensors
    _C._jit_pass_erase_number_types(graph)
    if GLOBALS.onnx_shape_inference:
        input_names = [] if input_names is None else input_names
        dynamic_axes = {} if dynamic_axes is None else dynamic_axes
        _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    _C._jit_pass_onnx_lint(graph)

    # The conversion itself: from here on ``graph`` refers to the pass's result.
    graph = _C._jit_pass_onnx(graph, operator_export_type)
    _C._jit_pass_onnx_lint(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_scalar_type_analysis(
        graph, True, GLOBALS.export_onnx_opset_version
    )
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_peephole(
        graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
    )
    _C._jit_pass_lint(graph)

    # graph is not a valid jit graph anymore because types have been replaced
    # (e.g. int with Tensor), so it now contains operators that don't actually
    # exist. We can't run normal dead code elimination because it'd fail trying
    # to look up if an operator has side effects, but we can run a dead code
    # elimination variant that doesn't need to look up if an op has side effects.
    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    _C._jit_pass_lint(graph)
    graph = _C._jit_pass_canonicalize(graph)
    _C._jit_pass_lint(graph)
    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )

    return graph
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker
def warn_on_static_input_change(input_states):
    """Warn when a dict or str model input differs from its traced counterpart.

    Dictionaries and strings are accepted as ONNX inputs for configuration use
    only, so modifications to them won't take effect in the traced ONNX graph.
    The inputs are compared pairwise with the traced inputs and a warning is
    issued for each dict whose key list changed and each str whose value changed.

    Args:
        input_states: Indexable whose element 0 holds the inputs and whose
            element 1 holds the corresponding traced inputs.
    """
    originals, traced = input_states[0], input_states[1]
    for before, after in zip(originals, traced):
        if isinstance(before, dict):
            # Lists are compared so that a change in key order is detected too.
            if list(before.keys()) != list(after.keys()):
                warnings.warn(
                    "We detected that you are modifying a dictionary that is an input to your "
                    "model. "
                    "Note that dictionaries are allowed as inputs in ONNX but they should be "
                    "handled with care. "
                    "Usages of dictionaries is not recommended, and should not be used except "
                    "for configuration use. "
                    "Also note that the order and values of the keys must remain the same. "
                )
        elif isinstance(before, str) and before != after:
            warnings.warn(
                "The model seems to have string inputs/outputs. "
                "Note that strings will not appear as inputs/outputs of the ONNX graph. "
            )
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Workerdef _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
700*da0073e9SAndroid Build Coastguard Worker    """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
701*da0073e9SAndroid Build Coastguard Worker    return arg_value
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Workerdef _decide_keep_init_as_input(
705*da0073e9SAndroid Build Coastguard Worker    keep_initializers_as_inputs: bool | None,
706*da0073e9SAndroid Build Coastguard Worker    operator_export_type: _C_onnx.OperatorExportTypes,
707*da0073e9SAndroid Build Coastguard Worker    opset_version: int,
708*da0073e9SAndroid Build Coastguard Worker):
709*da0073e9SAndroid Build Coastguard Worker    """Decides whether the initializers in the graph should be listed as ONNX graph inputs.
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker    This method encapsulates the logic to decide whether the initializers in the graph
712*da0073e9SAndroid Build Coastguard Worker    should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4).
713*da0073e9SAndroid Build Coastguard Worker    If keep_initializers_as_inputs is not specified (None), then we decide whether to keep
714*da0073e9SAndroid Build Coastguard Worker    initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type
715*da0073e9SAndroid Build Coastguard Worker    is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other
716*da0073e9SAndroid Build Coastguard Worker    export types keep initializers as input (val_keep_init_as_ip=True).
717*da0073e9SAndroid Build Coastguard Worker    If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8,
718*da0073e9SAndroid Build Coastguard Worker    in which case it must be ignored because for opset version <= 8, all initializers MUST be
719*da0073e9SAndroid Build Coastguard Worker    part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True.
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker    Special handling is needed for opset version 8 or lower, because irrespective
722*da0073e9SAndroid Build Coastguard Worker    of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3
723*da0073e9SAndroid Build Coastguard Worker    semantics, i.e. all initializers must be listed as ONNX graph input.
724*da0073e9SAndroid Build Coastguard Worker    """
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker    if opset_version < 9:
727*da0073e9SAndroid Build Coastguard Worker        if keep_initializers_as_inputs is False:
728*da0073e9SAndroid Build Coastguard Worker            warnings.warn(
729*da0073e9SAndroid Build Coastguard Worker                "Setting 'keep_initializers_as_inputs=False' for opset version"
730*da0073e9SAndroid Build Coastguard Worker                "8 or lower would lead to an invalid ONNX graph. Therefore, "
731*da0073e9SAndroid Build Coastguard Worker                "'keep_initializers_as_inputs=False' is ignored during export."
732*da0073e9SAndroid Build Coastguard Worker                "Exported model will have initializers as graph inputs (compliant "
733*da0073e9SAndroid Build Coastguard Worker                " to ONNX IR v3)."
734*da0073e9SAndroid Build Coastguard Worker            )
735*da0073e9SAndroid Build Coastguard Worker        return True  # i.e. True == initializers are part of graph input (ONNX IR v3)
736*da0073e9SAndroid Build Coastguard Worker    val_keep_init_as_ip = (
737*da0073e9SAndroid Build Coastguard Worker        True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
738*da0073e9SAndroid Build Coastguard Worker    )
739*da0073e9SAndroid Build Coastguard Worker    if (
740*da0073e9SAndroid Build Coastguard Worker        keep_initializers_as_inputs is None
741*da0073e9SAndroid Build Coastguard Worker        and operator_export_type is _C_onnx.OperatorExportTypes.ONNX
742*da0073e9SAndroid Build Coastguard Worker    ):
743*da0073e9SAndroid Build Coastguard Worker        val_keep_init_as_ip = False
744*da0073e9SAndroid Build Coastguard Worker    return val_keep_init_as_ip
745*da0073e9SAndroid Build Coastguard Worker
746*da0073e9SAndroid Build Coastguard Worker
def _decide_add_node_names(add_node_names, operator_export_type):
    """Resolve the effective ``add_node_names`` setting for the given export type.

    Args:
        add_node_names: The requested setting.
        operator_export_type: The requested operator export type.

    Returns:
        Whatever ``_resolve_args_by_export_type`` returns for this argument.
    """
    resolved = _resolve_args_by_export_type(
        "add_node_names", add_node_names, operator_export_type
    )
    return resolved
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker
def _decide_constant_folding(do_constant_folding, operator_export_type, training):
    """Resolve ``do_constant_folding`` and warn if it is combined with training.

    The requested value is first resolved through
    ``_resolve_args_by_export_type``. If the resolved value is truthy and
    ``training`` is neither ``None`` nor ``TrainingMode.EVAL``, a warning is
    emitted; the resolved value is returned unchanged either way.
    """
    folding_enabled = _resolve_args_by_export_type(
        "do_constant_folding", do_constant_folding, operator_export_type
    )
    # Nothing to warn about when folding is off.
    if not folding_enabled:
        return folding_enabled
    # Folding in eval mode (or with no explicit mode) is the supported combination.
    if training is None or training is _C_onnx.TrainingMode.EVAL:
        return folding_enabled

    warnings.warn(
        "It is recommended that constant folding be turned off ('do_constant_folding=False') "
        "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' "
        "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some "
        "learnable model parameters may not translate correctly in the exported ONNX model "
        "because constant folding mutates model parameters. Please consider "
        "turning off constant folding or setting the training=TrainingMode.EVAL."
    )
    return folding_enabled
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Workerdef _signature(model) -> inspect.Signature:
772*da0073e9SAndroid Build Coastguard Worker    should_be_callable = getattr(model, "forward", model)
773*da0073e9SAndroid Build Coastguard Worker    if callable(should_be_callable):
774*da0073e9SAndroid Build Coastguard Worker        return inspect.signature(should_be_callable)
775*da0073e9SAndroid Build Coastguard Worker    raise ValueError("model has no forward method and is not callable")
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker
def _decide_input_format(model, args):
    """Expand ``args`` into positional order following the model signature.

    ``args`` may be a list, a tuple, or a single value. If its last element
    is a dict, that dict is treated as keyword arguments. Signature
    parameters beyond the supplied positionals are appended in declaration
    order, taken from the keyword dict when present and otherwise from the
    parameter's default (parameters with neither are skipped).

    This is best effort: if the signature cannot be obtained, or anything
    goes wrong while reordering, a warning is emitted and the original
    ``args`` object is returned untouched.

    Returns:
        A list when ``args`` was a list, otherwise a tuple. The caller's
        list is never mutated; a new list is returned.
    """
    try:
        sig = _signature(model)
    except ValueError as e:
        warnings.warn(f"{e}, skipping _decide_input_format")
        return args
    try:
        ordered_list_keys = list(sig.parameters.keys())
        # Indexing [0] raises IndexError for a parameterless signature; that
        # is reported below as "No input args".
        if ordered_list_keys[0] == "self":
            ordered_list_keys = ordered_list_keys[1:]
        args_dict: dict = {}
        if isinstance(args, (list, tuple)):
            # Always work on a copy so the caller's sequence is left intact.
            args_list = list(args)
        else:
            args_list = [args]
        if isinstance(args_list[-1], dict):
            args_dict = args_list[-1]
            args_list = args_list[:-1]
        n_nonkeyword = len(args_list)
        for optional_arg in ordered_list_keys[n_nonkeyword:]:
            if optional_arg in args_dict:
                args_list.append(args_dict[optional_arg])
            else:
                # Fall back to the declared default, if there is one.
                param = sig.parameters[optional_arg]
                # Identity check: `!=` would invoke the default's own __ne__,
                # which is elementwise/ambiguous for array-like defaults.
                if param.default is not param.empty:
                    args_list.append(param.default)
        args = args_list if isinstance(args, list) else tuple(args_list)
    # Cases of models with no input args
    except IndexError:
        warnings.warn("No input args, skipping _decide_input_format")
    except Exception as e:
        # Exceptions raised without arguments have an empty ``args`` tuple.
        detail = e.args[0] if e.args else e
        warnings.warn(f"Skipping _decide_input_format\n {detail}")
    return args
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Workerdef _from_dynamic_axes_to_dynamic_shapes(
817*da0073e9SAndroid Build Coastguard Worker    model,
818*da0073e9SAndroid Build Coastguard Worker    dynamic_axes: Mapping[str, Mapping[int, str]]
819*da0073e9SAndroid Build Coastguard Worker    | Mapping[str, Sequence[int]]
820*da0073e9SAndroid Build Coastguard Worker    | None = None,
821*da0073e9SAndroid Build Coastguard Worker    input_names: Sequence[str] | None = None,
822*da0073e9SAndroid Build Coastguard Worker) -> dict[str, Any] | None:
823*da0073e9SAndroid Build Coastguard Worker    """
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker    dynamic_axes examples:
826*da0073e9SAndroid Build Coastguard Worker    (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
827*da0073e9SAndroid Build Coastguard Worker    (2) dynamic_axes = {"x": [0], "y": [1]}
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker    these will be converted to dynamic_shapes respectively:
830*da0073e9SAndroid Build Coastguard Worker    (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
831*da0073e9SAndroid Build Coastguard Worker    (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}}  # auto-generated dim names
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker    """
834*da0073e9SAndroid Build Coastguard Worker    if dynamic_axes is None:
835*da0073e9SAndroid Build Coastguard Worker        return None
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker    if input_names is None:
838*da0073e9SAndroid Build Coastguard Worker        input_names_set = set()
839*da0073e9SAndroid Build Coastguard Worker    else:
840*da0073e9SAndroid Build Coastguard Worker        input_names_set = set(input_names)
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker    dynamic_shapes: dict[str, Any | None] = {}
843*da0073e9SAndroid Build Coastguard Worker    for input_name, axes in dynamic_axes.items():
844*da0073e9SAndroid Build Coastguard Worker        if input_name in input_names_set:
845*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
846*da0073e9SAndroid Build Coastguard Worker                "Assinging new input names is not supported yet. Please use model forward signature "
847*da0073e9SAndroid Build Coastguard Worker                "to specify input names in dynamix_axes."
848*da0073e9SAndroid Build Coastguard Worker            )
849*da0073e9SAndroid Build Coastguard Worker        if isinstance(axes, dict):
850*da0073e9SAndroid Build Coastguard Worker            dynamic_shapes[input_name] = {
851*da0073e9SAndroid Build Coastguard Worker                k: torch.export.Dim(v) for k, v in axes.items()
852*da0073e9SAndroid Build Coastguard Worker            }
853*da0073e9SAndroid Build Coastguard Worker        elif isinstance(axes, list):
854*da0073e9SAndroid Build Coastguard Worker            dynamic_shapes[input_name] = {
855*da0073e9SAndroid Build Coastguard Worker                k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes
856*da0073e9SAndroid Build Coastguard Worker            }
857*da0073e9SAndroid Build Coastguard Worker        else:
858*da0073e9SAndroid Build Coastguard Worker            raise TypeError(
859*da0073e9SAndroid Build Coastguard Worker                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
860*da0073e9SAndroid Build Coastguard Worker            )
861*da0073e9SAndroid Build Coastguard Worker    # torch.export.export needs static dim to present in dynamic_shapes
862*da0073e9SAndroid Build Coastguard Worker    # for all input tensors, so we need to add them with None
863*da0073e9SAndroid Build Coastguard Worker    try:
864*da0073e9SAndroid Build Coastguard Worker        sig = _signature(model)
865*da0073e9SAndroid Build Coastguard Worker    except ValueError as e:
866*da0073e9SAndroid Build Coastguard Worker        warnings.warn(f"{e}, skipping auto filling None on static axes...")
867*da0073e9SAndroid Build Coastguard Worker        return dynamic_shapes
868*da0073e9SAndroid Build Coastguard Worker    for input_name in sig.parameters.keys():
869*da0073e9SAndroid Build Coastguard Worker        if input_name not in dynamic_shapes:
870*da0073e9SAndroid Build Coastguard Worker            dynamic_shapes[input_name] = None
871*da0073e9SAndroid Build Coastguard Worker    return dynamic_shapes
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker
def _trace(func, args, operator_export_type, return_outs=False):
    """Trace ``func`` on ``args`` and run the result through ``_optimize_graph``.

    A bare ``torch.Tensor`` passed as ``args`` is wrapped into a 1-tuple
    before tracing. The input states reported by the tracer are handed to
    ``warn_on_static_input_change``.

    Returns:
        The optimized graph, or ``(graph, torch_out)`` when ``return_outs``
        is truthy.
    """
    # Special case for common case of passing a single Tensor
    trace_inputs = (args,) if isinstance(args, torch.Tensor) else args

    traced = torch.jit._get_trace_graph(
        func,
        trace_inputs,
        strict=False,
        _force_outplace=False,
        _return_inputs_states=True,
    )
    raw_graph, torch_out, inputs_states = traced
    warn_on_static_input_change(inputs_states)

    optimized_graph = _optimize_graph(raw_graph, operator_export_type, params_dict={})
    return (optimized_graph, torch_out) if return_outs else optimized_graph
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker
def _trace_and_get_graph_from_model(model, args):
    """Trace ``model`` on ``args`` and return ``(trace_graph, torch_out)``.

    The autocast cache is disabled for the duration of the trace and the
    previous setting is restored afterwards, including when tracing raises.
    The input states reported by the tracer are passed to
    ``warn_on_static_input_change``.

    Raises:
        RuntimeError: If the set of unique state_dict keys differs before
            and after tracing.
    """
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()

    # Disable Autocast cache because it replaces kernel's weight and bias
    # by (undesired) constants.
    # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
    prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    try:
        trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
            model,
            args,
            strict=False,
            _force_outplace=False,
            _return_inputs_states=True,
        )
    finally:
        # The flag is process-global state: restore it even if tracing fails,
        # otherwise a failed export leaves the cache disabled for the caller.
        torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)

    warn_on_static_input_change(inputs_states)

    if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
        raise RuntimeError(
            "state_dict changed after running the tracer; "
            "something weird is happening in your model!"
        )

    return trace_graph, torch_out
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Workerdef _get_param_count_list(method_graph, args_params):
925*da0073e9SAndroid Build Coastguard Worker    param_count_list = []
926*da0073e9SAndroid Build Coastguard Worker    for input_, arg_params_ in zip(method_graph.inputs(), args_params):
927*da0073e9SAndroid Build Coastguard Worker        if "PackedParams" in str(input_.type()):
928*da0073e9SAndroid Build Coastguard Worker            in_vars, _ = torch.jit._flatten(arg_params_)
929*da0073e9SAndroid Build Coastguard Worker            param_count_list.append(len(in_vars))
930*da0073e9SAndroid Build Coastguard Worker        else:
931*da0073e9SAndroid Build Coastguard Worker            param_count_list.append(arg_params_ is not None)
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker    return param_count_list
934*da0073e9SAndroid Build Coastguard Worker
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Workerdef _check_flatten_did_not_remove(original, jit_flattened):
937*da0073e9SAndroid Build Coastguard Worker    """torch.jit._flatten removes None. Check if it did so in this case."""
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker    def flatten(x):
940*da0073e9SAndroid Build Coastguard Worker        if isinstance(x, (list, tuple)):
941*da0073e9SAndroid Build Coastguard Worker            for inner in x:
942*da0073e9SAndroid Build Coastguard Worker                yield from flatten(inner)
943*da0073e9SAndroid Build Coastguard Worker        elif isinstance(x, dict):
944*da0073e9SAndroid Build Coastguard Worker            for inner in x.values():
945*da0073e9SAndroid Build Coastguard Worker                yield from flatten(inner)
946*da0073e9SAndroid Build Coastguard Worker        else:
947*da0073e9SAndroid Build Coastguard Worker            yield x
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker    flattened_with_none = list(flatten(original))
950*da0073e9SAndroid Build Coastguard Worker    num_none = len(flattened_with_none) - len(jit_flattened)
951*da0073e9SAndroid Build Coastguard Worker    assert num_none >= 0
952*da0073e9SAndroid Build Coastguard Worker    if num_none:
953*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
954*da0073e9SAndroid Build Coastguard Worker            f"args contained {num_none} None's after flattening. "
955*da0073e9SAndroid Build Coastguard Worker            "When exporting a ScriptModule or ScriptFunction, no args may "
956*da0073e9SAndroid Build Coastguard Worker            "be None because that breaks type propagation."
957*da0073e9SAndroid Build Coastguard Worker        )
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker
def _create_jit_graph(
    model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any]
) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]:
    """Obtain a TorchScript graph for ``model``, from script or by tracing.

    Three paths are taken depending on the type of ``model``:

    * ``torch.jit.ScriptModule``: the module is frozen with
      ``preserveParameters=True``, its parameters are listed, and input
      shapes are propagated onto the ``forward`` method graph.
    * ``torch.jit.ScriptFunction``: input shapes are propagated onto the
      function's own graph; there are no parameters.
    * anything else: the model is traced with
      ``_trace_and_get_graph_from_model`` and the trailing graph inputs are
      renamed after the keys of the model's unique state dict.

    Returns:
        A 4-tuple ``(graph, params, torch_out, module)``. ``torch_out`` is
        ``None`` on both script paths; ``module`` is only non-``None`` on
        the ``ScriptModule`` path.

    Raises:
        RuntimeError: If a ``ScriptModule``'s ``forward`` has no ``graph``
            attribute.
        ValueError: Propagated from ``_check_flatten_did_not_remove`` on the
            script paths.
    """
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        # Script paths: validate that flattening the args did not lose entries.
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        _check_flatten_did_not_remove(args, flattened_args)
        # No tracing happens for scripted models, so there is no traced output.
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            freezed_module = _C._freeze_module(
                cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            # User args first, then the listed parameters, in that order.
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            # NOTE: `graph` is rebound here to the result computed from
            # `method_graph`, replacing the `model.forward.graph` bound above.
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    # Tracing path for everything that is not already scripted.
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    # Graph inputs are treated as user inputs followed by one input per
    # state-dict entry; only the trailing ones are renamed.
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Workerdef _get_named_param_dict(graph, params):
1012*da0073e9SAndroid Build Coastguard Worker    input_and_param_names = [val.debugName() for val in graph.inputs()]
1013*da0073e9SAndroid Build Coastguard Worker    param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
1014*da0073e9SAndroid Build Coastguard Worker    _params_dict = dict(zip(param_names, params))
1015*da0073e9SAndroid Build Coastguard Worker    return _params_dict
1016*da0073e9SAndroid Build Coastguard Worker
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Workerdef _get_example_outputs(model, args):
1019*da0073e9SAndroid Build Coastguard Worker    input_args = copy.deepcopy(args)
1020*da0073e9SAndroid Build Coastguard Worker    input_kwargs = {}
1021*da0073e9SAndroid Build Coastguard Worker    if input_args and isinstance(input_args[-1], dict):
1022*da0073e9SAndroid Build Coastguard Worker        input_kwargs = input_args[-1]
1023*da0073e9SAndroid Build Coastguard Worker        input_args = input_args[:-1]
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker    example_outputs = model(*input_args, **input_kwargs)
1026*da0073e9SAndroid Build Coastguard Worker    if isinstance(example_outputs, list):
1027*da0073e9SAndroid Build Coastguard Worker        example_outputs = [example_outputs]
1028*da0073e9SAndroid Build Coastguard Worker    elif not isinstance(example_outputs, tuple):
1029*da0073e9SAndroid Build Coastguard Worker        example_outputs = (example_outputs,)
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker    return example_outputs
1032*da0073e9SAndroid Build Coastguard Worker
1033*da0073e9SAndroid Build Coastguard Worker
# Quantized dtype -> plain integer dtype used to hold the unpacked values.
_qtype_vtype_map = {
    torch.quint8: torch.uint8,
    torch.qint8: torch.int8,
    torch.qint32: torch.int32,
    torch.quint4x2: torch.int8,
}


def unpack_quantized_tensor(value, cast_onnx_accepted=True):
    """Split a quantized tensor into ``(values, scale, zero_point)``.

    Anything that is not a tensor with a dtype listed in
    ``_qtype_vtype_map`` is returned unchanged as the 1-tuple ``(value,)``.

    Args:
        value: Object to unpack.
        cast_onnx_accepted: When true, scale is a ``torch.double`` tensor and
            zero point a ``torch.int64`` tensor; otherwise scale is
            ``torch.float32`` and zero point uses the mapped integer dtype.
    """
    if not (isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map):
        return (value,)

    storage_dtype = _qtype_vtype_map[value.dtype]
    if cast_onnx_accepted:
        scale_dtype, zero_point_dtype = torch.double, torch.int64
    else:
        scale_dtype, zero_point_dtype = torch.float32, storage_dtype

    dequantized = value.dequantize()
    q_scale = torch.tensor(value.q_scale(), dtype=scale_dtype)
    q_zero_point = torch.tensor(value.q_zero_point(), dtype=zero_point_dtype)
    # Recompute the stored integer values from the dequantized ones.
    q_value = (dequantized / q_scale + q_zero_point).to(dtype=storage_dtype)
    return q_value, q_scale, q_zero_point
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Workerdef _pre_trace_quant_model(model, args):
1063*da0073e9SAndroid Build Coastguard Worker    r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
1064*da0073e9SAndroid Build Coastguard Worker    original model.
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker    This is due to https://github.com/pytorch/pytorch/issues/75761.
1067*da0073e9SAndroid Build Coastguard Worker    """
1068*da0073e9SAndroid Build Coastguard Worker    if any(
1069*da0073e9SAndroid Build Coastguard Worker        hasattr(m, "_packed_params") for m in getattr(model, "modules", list)()
1070*da0073e9SAndroid Build Coastguard Worker    ) or any(getattr(arg, "is_quantized", False) for arg in args):
1071*da0073e9SAndroid Build Coastguard Worker        return torch.jit.trace(model, args)
1072*da0073e9SAndroid Build Coastguard Worker    return model
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Workerdef _model_to_graph(
1076*da0073e9SAndroid Build Coastguard Worker    model,
1077*da0073e9SAndroid Build Coastguard Worker    args,
1078*da0073e9SAndroid Build Coastguard Worker    verbose=False,
1079*da0073e9SAndroid Build Coastguard Worker    input_names=None,
1080*da0073e9SAndroid Build Coastguard Worker    output_names=None,
1081*da0073e9SAndroid Build Coastguard Worker    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
1082*da0073e9SAndroid Build Coastguard Worker    do_constant_folding=True,
1083*da0073e9SAndroid Build Coastguard Worker    _disable_torch_constant_prop=False,
1084*da0073e9SAndroid Build Coastguard Worker    fixed_batch_size=False,
1085*da0073e9SAndroid Build Coastguard Worker    training=_C_onnx.TrainingMode.EVAL,
1086*da0073e9SAndroid Build Coastguard Worker    dynamic_axes=None,
1087*da0073e9SAndroid Build Coastguard Worker) -> tuple[
1088*da0073e9SAndroid Build Coastguard Worker    _C.Graph,
1089*da0073e9SAndroid Build Coastguard Worker    dict[str, torch.Tensor],
1090*da0073e9SAndroid Build Coastguard Worker    torch.Tensor
1091*da0073e9SAndroid Build Coastguard Worker    | tuple[torch.Tensor, ...]
1092*da0073e9SAndroid Build Coastguard Worker    | list[torch.Tensor]
1093*da0073e9SAndroid Build Coastguard Worker    | dict[str, torch.Tensor]
1094*da0073e9SAndroid Build Coastguard Worker    | Any
1095*da0073e9SAndroid Build Coastguard Worker    | None,
1096*da0073e9SAndroid Build Coastguard Worker]:
1097*da0073e9SAndroid Build Coastguard Worker    """Converts model into an ONNX graph.
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker    Returns:
1100*da0073e9SAndroid Build Coastguard Worker        graph: A TorchScript IR Graph with ONNX nodes.
1101*da0073e9SAndroid Build Coastguard Worker        params_dict: Dict from input param name to param value.
1102*da0073e9SAndroid Build Coastguard Worker        torch_out: The output tensors resulting from the trace of ``model``.
1103*da0073e9SAndroid Build Coastguard Worker            If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
1104*da0073e9SAndroid Build Coastguard Worker            this will be None, since we are not doing any tracing.
1105*da0073e9SAndroid Build Coastguard Worker    """
1106*da0073e9SAndroid Build Coastguard Worker    # TODO: can we simplify this to always return a tuple of Tensor or None?
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker    # Special case for common case of passing a single Tensor
1109*da0073e9SAndroid Build Coastguard Worker    if isinstance(args, (torch.Tensor, int, float, bool)):
1110*da0073e9SAndroid Build Coastguard Worker        args = (args,)
1111*da0073e9SAndroid Build Coastguard Worker
1112*da0073e9SAndroid Build Coastguard Worker    model = _pre_trace_quant_model(model, args)
1113*da0073e9SAndroid Build Coastguard Worker    graph, params, torch_out, module = _create_jit_graph(model, args)
1114*da0073e9SAndroid Build Coastguard Worker    params_dict = _get_named_param_dict(graph, params)
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker    try:
1117*da0073e9SAndroid Build Coastguard Worker        graph = _optimize_graph(
1118*da0073e9SAndroid Build Coastguard Worker            graph,
1119*da0073e9SAndroid Build Coastguard Worker            operator_export_type,
1120*da0073e9SAndroid Build Coastguard Worker            _disable_torch_constant_prop=_disable_torch_constant_prop,
1121*da0073e9SAndroid Build Coastguard Worker            fixed_batch_size=fixed_batch_size,
1122*da0073e9SAndroid Build Coastguard Worker            params_dict=params_dict,
1123*da0073e9SAndroid Build Coastguard Worker            dynamic_axes=dynamic_axes,
1124*da0073e9SAndroid Build Coastguard Worker            input_names=input_names,
1125*da0073e9SAndroid Build Coastguard Worker            module=module,
1126*da0073e9SAndroid Build Coastguard Worker        )
1127*da0073e9SAndroid Build Coastguard Worker    except Exception as e:
1128*da0073e9SAndroid Build Coastguard Worker        torch.onnx.log("Torch IR graph at exception: ", graph)
1129*da0073e9SAndroid Build Coastguard Worker        raise
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Worker    is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
1132*da0073e9SAndroid Build Coastguard Worker    if is_script:
1133*da0073e9SAndroid Build Coastguard Worker        example_outputs = _get_example_outputs(model, args)
1134*da0073e9SAndroid Build Coastguard Worker        example_outputs_final = ()
1135*da0073e9SAndroid Build Coastguard Worker        for example_output in example_outputs:
1136*da0073e9SAndroid Build Coastguard Worker            example_outputs_final += unpack_quantized_tensor(example_output)
1137*da0073e9SAndroid Build Coastguard Worker        out_vars, desc = torch.jit._flatten(example_outputs_final)
1138*da0073e9SAndroid Build Coastguard Worker        _C._jit_pass_onnx_assign_output_shape(
1139*da0073e9SAndroid Build Coastguard Worker            graph,
1140*da0073e9SAndroid Build Coastguard Worker            out_vars,
1141*da0073e9SAndroid Build Coastguard Worker            desc,
1142*da0073e9SAndroid Build Coastguard Worker            GLOBALS.onnx_shape_inference,
1143*da0073e9SAndroid Build Coastguard Worker            is_script,
1144*da0073e9SAndroid Build Coastguard Worker            GLOBALS.export_onnx_opset_version,
1145*da0073e9SAndroid Build Coastguard Worker        )
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker    # NB: ONNX requires complete information about output types, which might be
1148*da0073e9SAndroid Build Coastguard Worker    # erased by some optimizations, so we need to set it explicitly again.
1149*da0073e9SAndroid Build Coastguard Worker    else:
1150*da0073e9SAndroid Build Coastguard Worker        if not isinstance(torch_out, (list, tuple)):
1151*da0073e9SAndroid Build Coastguard Worker            output_wrapped = [torch_out]
1152*da0073e9SAndroid Build Coastguard Worker        else:
1153*da0073e9SAndroid Build Coastguard Worker            output_wrapped = torch_out  # type: ignore[assignment]
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker        output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped))
1156*da0073e9SAndroid Build Coastguard Worker        # assign_output_shape pass is not compatible with quantized outputs.
1157*da0073e9SAndroid Build Coastguard Worker        # Quantized outputs are flattened to 3 values in ONNX, while packed as
1158*da0073e9SAndroid Build Coastguard Worker        # single value in PyTorch.
1159*da0073e9SAndroid Build Coastguard Worker        if not any(getattr(out, "is_quantized", False) for out in output_tensors):
1160*da0073e9SAndroid Build Coastguard Worker            _C._jit_pass_onnx_assign_output_shape(
1161*da0073e9SAndroid Build Coastguard Worker                graph,
1162*da0073e9SAndroid Build Coastguard Worker                output_tensors,
1163*da0073e9SAndroid Build Coastguard Worker                out_desc,
1164*da0073e9SAndroid Build Coastguard Worker                GLOBALS.onnx_shape_inference,
1165*da0073e9SAndroid Build Coastguard Worker                is_script,
1166*da0073e9SAndroid Build Coastguard Worker                GLOBALS.export_onnx_opset_version,
1167*da0073e9SAndroid Build Coastguard Worker            )
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker    _set_input_and_output_names(graph, input_names, output_names)
1170*da0073e9SAndroid Build Coastguard Worker    params_dict = _get_named_param_dict(graph, params)
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker    if (
1173*da0073e9SAndroid Build Coastguard Worker        do_constant_folding
1174*da0073e9SAndroid Build Coastguard Worker        and GLOBALS.export_onnx_opset_version
1175*da0073e9SAndroid Build Coastguard Worker        >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
1176*da0073e9SAndroid Build Coastguard Worker    ):
1177*da0073e9SAndroid Build Coastguard Worker        if training is None or training == _C_onnx.TrainingMode.EVAL:
1178*da0073e9SAndroid Build Coastguard Worker            params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker        params_dict = _C._jit_pass_onnx_constant_fold(
1181*da0073e9SAndroid Build Coastguard Worker            graph, params_dict, GLOBALS.export_onnx_opset_version
1182*da0073e9SAndroid Build Coastguard Worker        )
1183*da0073e9SAndroid Build Coastguard Worker        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker    if GLOBALS.onnx_shape_inference:
1186*da0073e9SAndroid Build Coastguard Worker        _C._jit_pass_onnx_graph_shape_type_inference(
1187*da0073e9SAndroid Build Coastguard Worker            graph, params_dict, GLOBALS.export_onnx_opset_version
1188*da0073e9SAndroid Build Coastguard Worker        )
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Worker    # For ONNX opset < 9, constants only have three data types: float16, float, double.
1193*da0073e9SAndroid Build Coastguard Worker    # In this pass transform constants of other data types to float/double + cast operator.
1194*da0073e9SAndroid Build Coastguard Worker    if GLOBALS.export_onnx_opset_version < 9:
1195*da0073e9SAndroid Build Coastguard Worker        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
1196*da0073e9SAndroid Build Coastguard Worker
1197*da0073e9SAndroid Build Coastguard Worker    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
1198*da0073e9SAndroid Build Coastguard Worker    _C._jit_decay_packed_param_input_types(graph)
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker    # If output names lack a proper name and are identified only by their unique
1201*da0073e9SAndroid Build Coastguard Worker    # give them a legible name for debugging purposes
1202*da0073e9SAndroid Build Coastguard Worker    _apply_friendly_debug_names(graph, params_dict)
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker    return graph, params_dict, torch_out
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker
@torch._disable_dynamo
@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
def export_to_pretty_string(
    model,
    args,
    export_params=True,
    verbose=False,
    training=_C_onnx.TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    export_type=None,
    google_printer=False,
    opset_version=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    add_node_names=True,
    do_constant_folding=True,
    dynamic_axes=None,
):
    """Build the ONNX graph for ``model`` and return it rendered as text.

    Behaves like :func:`export` except that nothing is written out; the
    converted graph is pretty-printed and returned instead. Only the arguments
    that differ from :func:`export` are described here.

    Calling this function overwrites ``GLOBALS.export_onnx_opset_version`` and
    ``GLOBALS.operator_export_type`` with the values used for this export.
    ``export_params`` and ``export_type`` are accepted but are not referenced
    anywhere in this function's body.

    Args:
        add_node_names (bool, default True): Whether or not to set
            NodeProto.name. This makes no difference unless
            ``google_printer=True``.
        google_printer (bool, default False): If False, will return a custom,
            compact representation of the model. If True will return the
            protobuf's `Message::DebugString()`, which is more verbose.
        opset_version: Falls back to ``_constants.ONNX_DEFAULT_OPSET`` when
            ``None``.
        custom_opsets: Falls back to an empty dict when ``None``.

    Returns:
        A UTF-8 str containing a human-readable representation of the ONNX model.
    """
    opset_version = (
        _constants.ONNX_DEFAULT_OPSET if opset_version is None else opset_version
    )
    custom_opsets = {} if custom_opsets is None else custom_opsets

    # Publish the export configuration before any helper below reads it.
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with exporter_context(model, training, verbose):
        # Resolve the user-facing switches into the concrete values used for
        # this export. The helpers are called in the same order as before.
        keep_initializers = _decide_keep_init_as_input(
            keep_initializers_as_inputs, operator_export_type, opset_version
        )
        node_names_enabled = _decide_add_node_names(
            add_node_names, operator_export_type
        )
        constant_folding_enabled = _decide_constant_folding(
            do_constant_folding, operator_export_type, training
        )

        args = _decide_input_format(model, args)
        # The third element (the model's outputs) is not needed for printing.
        graph, params_dict, _torch_out = _model_to_graph(
            model,
            args,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            constant_folding_enabled,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        # Rendering happens while the exporter context is still active.
        pretty_text = graph._pretty_print_onnx(  # type: ignore[attr-defined]
            params_dict,
            opset_version,
            False,
            operator_export_type,
            google_printer,
            keep_initializers,
            custom_opsets,
            node_names_enabled,
        )
        return pretty_text
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker
@_deprecation.deprecated("2.5", "the future", "avoid using this function")
def unconvertible_ops(
    model,
    args,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    opset_version: int | None = None,
) -> tuple[_C.Graph, list[str]]:
    """Returns an approximated list of all ops that are not yet supported by :mod:`torch.onnx`.

    The list is approximated because some ops may be removed during the conversion
    process and don't need to be converted. Some other ops may have partial support
    that will fail conversion with particular inputs. Please open a Github Issue
    for op support requests.

    Calling this function overwrites ``GLOBALS.export_onnx_opset_version`` with the
    opset version used for the lookup.

    Args:
        model: Same as the `model` parameter in :func:`torch.onnx.export`.
        args: Same as the `args` parameter in :func:`torch.onnx.export`.
        training: Same as the `training` parameter in :func:`torch.onnx.export`.
        opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`.
            A falsy value selects ``_constants.ONNX_DEFAULT_OPSET``.

    Returns:
        The JIT graph and a list of unconvertible ops in the format of "domain::op".
        An op kind appears once per node that uses it, in graph order.

    Raises:
        errors.OnnxExporterError: If building or cleaning up the JIT graph fails.
            The original exception is attached as the cause.
    """

    opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET
    GLOBALS.export_onnx_opset_version = opset_version

    try:
        with exporter_context(model, training, verbose=False):
            # Create a mostly clean JIT graph that contains the plain aten and
            # other ops we can check with the symbolic registry.
            # NOTE: We don't want to actually convert any ops to ONNX or run any
            # symbolic functions because there is a higher chance that a pass
            # fails or an unconvertible op messes up the graph during ONNX conversion.
            # This way we can always generate a list just by looking at the names
            # of the ops in the graph.
            args = _decide_input_format(model, args)
            model = _pre_trace_quant_model(model, args)
            graph, _, _, module = _create_jit_graph(model, args)
            _C._jit_pass_inline(graph)
            _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
            _C._jit_pass_erase_number_types(graph)
            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    except Exception as e:
        raise errors.OnnxExporterError(
            "Failed to discover unconvertible ops because of errors during the JIT graph "
            "generation process."
        ) from e

    def _is_unconvertible(domain_op: str) -> bool:
        """Return True when ``domain_op`` has no symbolic registered for the opset."""
        if domain_op.startswith(("onnx::", "prim::")):
            # We consider onnx and prim ops as supported ops, even though some "prim"
            # ops are not implemented as symbolic functions, because they may be
            # eliminated in the conversion passes. Users may still see errors caused
            # by prim ops even though they don't show up in the list.
            return False
        # We consider all registered ops supported, even though some of them are
        # only partially supported, because there is not yet a good way to check
        # if an op is fully supported.
        # TODO(justinchuby): Create a way to check if an op is fully supported.
        # Trailing underscores are stripped before the lookup, presumably so that
        # in-place variants resolve to the symbolic of the out-of-place op.
        return not registration.registry.is_registered_op(
            domain_op.rstrip("_"), opset_version
        )

    node_kinds = (node.kind() for node in graph.nodes())
    unsupported_ops = [kind for kind in node_kinds if _is_unconvertible(kind)]
    return graph, unsupported_ops
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Workerdef _setup_trace_module_map(
1355*da0073e9SAndroid Build Coastguard Worker    model: torch.nn.Module | torch.jit.ScriptModule,
1356*da0073e9SAndroid Build Coastguard Worker    export_modules_as_functions: bool | Collection[type[torch.nn.Module]],
1357*da0073e9SAndroid Build Coastguard Worker) -> set[str]:
1358*da0073e9SAndroid Build Coastguard Worker    def __register_attribute_hook():
1359*da0073e9SAndroid Build Coastguard Worker        attr_name = "_onnx_attrs"
1360*da0073e9SAndroid Build Coastguard Worker
1361*da0073e9SAndroid Build Coastguard Worker        def _track_module_attributes_forward_pre_hook(module, input):
1362*da0073e9SAndroid Build Coastguard Worker            setattr(module, attr_name, _get_module_attributes(module))
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker        def _track_module_attributes_forward_hook(module, input, output):
1365*da0073e9SAndroid Build Coastguard Worker            tracing_state = _C._get_tracing_state()
1366*da0073e9SAndroid Build Coastguard Worker            if not tracing_state:
1367*da0073e9SAndroid Build Coastguard Worker                return
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker            graph = tracing_state.graph()
1370*da0073e9SAndroid Build Coastguard Worker            onnx_attrs = {}
1371*da0073e9SAndroid Build Coastguard Worker            if hasattr(module, attr_name):
1372*da0073e9SAndroid Build Coastguard Worker                onnx_attrs = getattr(module, attr_name)
1373*da0073e9SAndroid Build Coastguard Worker                delattr(module, attr_name)
1374*da0073e9SAndroid Build Coastguard Worker
1375*da0073e9SAndroid Build Coastguard Worker            _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker        for m in model.modules():
1378*da0073e9SAndroid Build Coastguard Worker            m.register_forward_hook(_track_module_attributes_forward_hook)
1379*da0073e9SAndroid Build Coastguard Worker            m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
1380*da0073e9SAndroid Build Coastguard Worker
1381*da0073e9SAndroid Build Coastguard Worker    def _unqualified_variable_name(qualified_name: str) -> str:
1382*da0073e9SAndroid Build Coastguard Worker        """
1383*da0073e9SAndroid Build Coastguard Worker        Parse qualified variable name and return the unqualified version.
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        Pure numeric atoms are considered inadequate, so this function will look past them,
1386*da0073e9SAndroid Build Coastguard Worker        and start from the first non-numeric atom.
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Worker        Example:
1389*da0073e9SAndroid Build Coastguard Worker            >>> _unqualified_variable_name("__main__.Foo.bar")
1390*da0073e9SAndroid Build Coastguard Worker            'bar'
1391*da0073e9SAndroid Build Coastguard Worker            >>> _unqualified_variable_name("__main__.Foo.bar.0")
1392*da0073e9SAndroid Build Coastguard Worker            'bar.0'
1393*da0073e9SAndroid Build Coastguard Worker        """
1394*da0073e9SAndroid Build Coastguard Worker        name_atoms = qualified_name.split(".")
1395*da0073e9SAndroid Build Coastguard Worker        for i, atom in reversed(list(enumerate(name_atoms))):
1396*da0073e9SAndroid Build Coastguard Worker            if not atom.isnumeric():
1397*da0073e9SAndroid Build Coastguard Worker                return ".".join(name_atoms[i:])
1398*da0073e9SAndroid Build Coastguard Worker        return qualified_name
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker    trace_module_map = {
1401*da0073e9SAndroid Build Coastguard Worker        _m: torch._C._jit_onnx_create_full_scope_name(
1402*da0073e9SAndroid Build Coastguard Worker            torch.typename(type(_m)), _unqualified_variable_name(_n)
1403*da0073e9SAndroid Build Coastguard Worker        )
1404*da0073e9SAndroid Build Coastguard Worker        for _n, _m in model.named_modules()
1405*da0073e9SAndroid Build Coastguard Worker    }
1406*da0073e9SAndroid Build Coastguard Worker    torch.jit._trace._trace_module_map = trace_module_map
1407*da0073e9SAndroid Build Coastguard Worker    if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
1408*da0073e9SAndroid Build Coastguard Worker        module_typenames = {torch.typename(type(module)) for module in trace_module_map}
1409*da0073e9SAndroid Build Coastguard Worker    elif isinstance(export_modules_as_functions, set) and export_modules_as_functions:
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker        def _find_typename(v):
1412*da0073e9SAndroid Build Coastguard Worker            if isinstance(v, type):
1413*da0073e9SAndroid Build Coastguard Worker                return torch.typename(v)
1414*da0073e9SAndroid Build Coastguard Worker            else:
1415*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1416*da0073e9SAndroid Build Coastguard Worker                    "Only type of the `nn.Module` should be "
1417*da0073e9SAndroid Build Coastguard Worker                    "passed in the set for argument `export_modules_as_functions`. "
1418*da0073e9SAndroid Build Coastguard Worker                    f"Got `{type(v).__name__}`."
1419*da0073e9SAndroid Build Coastguard Worker                )
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker        module_typenames = {_find_typename(v) for v in export_modules_as_functions}
1422*da0073e9SAndroid Build Coastguard Worker    else:
1423*da0073e9SAndroid Build Coastguard Worker        module_typenames = set()
1424*da0073e9SAndroid Build Coastguard Worker
1425*da0073e9SAndroid Build Coastguard Worker    if module_typenames:
1426*da0073e9SAndroid Build Coastguard Worker        __register_attribute_hook()
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker    return module_typenames
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Workerdef _reset_trace_module_map():
1432*da0073e9SAndroid Build Coastguard Worker    torch.jit._trace._trace_module_map = None
1433*da0073e9SAndroid Build Coastguard Worker    _C._jit_pass_onnx_clear_scope_records()
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Workerdef _get_module_attributes(module):
1437*da0073e9SAndroid Build Coastguard Worker    annotations = typing.get_type_hints(type(module))
1438*da0073e9SAndroid Build Coastguard Worker    base_m_annotations = typing.get_type_hints(torch.nn.Module)
1439*da0073e9SAndroid Build Coastguard Worker    [annotations.pop(k, None) for k in base_m_annotations]
1440*da0073e9SAndroid Build Coastguard Worker    # Check whether module attributes can be accessed. Some classes
1441*da0073e9SAndroid Build Coastguard Worker    # define attributes but don't provide access to them in their
1442*da0073e9SAndroid Build Coastguard Worker    # constructor.
1443*da0073e9SAndroid Build Coastguard Worker    #
1444*da0073e9SAndroid Build Coastguard Worker    # For example, torch.nn.Embedding has the `freeze` variable and its
1445*da0073e9SAndroid Build Coastguard Worker    # type specified in the class but the attribute is not created in the
1446*da0073e9SAndroid Build Coastguard Worker    # constructor. In other words, there is no `self.freeze = <True | False>`
1447*da0073e9SAndroid Build Coastguard Worker    # in the constructor.
1448*da0073e9SAndroid Build Coastguard Worker    #
1449*da0073e9SAndroid Build Coastguard Worker    # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120
1450*da0073e9SAndroid Build Coastguard Worker    attrs = {}
1451*da0073e9SAndroid Build Coastguard Worker    for k in annotations:
1452*da0073e9SAndroid Build Coastguard Worker        try:
1453*da0073e9SAndroid Build Coastguard Worker            attrs[k] = getattr(module, k)
1454*da0073e9SAndroid Build Coastguard Worker        except AttributeError:
1455*da0073e9SAndroid Build Coastguard Worker            torch.onnx.log(f"Skipping module attribute '{k}'")
1456*da0073e9SAndroid Build Coastguard Worker            continue
1457*da0073e9SAndroid Build Coastguard Worker    return attrs
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Workerdef _export(
1461*da0073e9SAndroid Build Coastguard Worker    model,
1462*da0073e9SAndroid Build Coastguard Worker    args,
1463*da0073e9SAndroid Build Coastguard Worker    f,
1464*da0073e9SAndroid Build Coastguard Worker    export_params=True,
1465*da0073e9SAndroid Build Coastguard Worker    verbose=False,
1466*da0073e9SAndroid Build Coastguard Worker    training=_C_onnx.TrainingMode.EVAL,
1467*da0073e9SAndroid Build Coastguard Worker    input_names=None,
1468*da0073e9SAndroid Build Coastguard Worker    output_names=None,
1469*da0073e9SAndroid Build Coastguard Worker    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
1470*da0073e9SAndroid Build Coastguard Worker    export_type=None,
1471*da0073e9SAndroid Build Coastguard Worker    opset_version=None,
1472*da0073e9SAndroid Build Coastguard Worker    do_constant_folding=True,
1473*da0073e9SAndroid Build Coastguard Worker    dynamic_axes=None,
1474*da0073e9SAndroid Build Coastguard Worker    keep_initializers_as_inputs=None,
1475*da0073e9SAndroid Build Coastguard Worker    fixed_batch_size=False,
1476*da0073e9SAndroid Build Coastguard Worker    custom_opsets=None,
1477*da0073e9SAndroid Build Coastguard Worker    add_node_names=True,
1478*da0073e9SAndroid Build Coastguard Worker    onnx_shape_inference=True,
1479*da0073e9SAndroid Build Coastguard Worker    export_modules_as_functions: Any = False,
1480*da0073e9SAndroid Build Coastguard Worker    autograd_inlining=True,
1481*da0073e9SAndroid Build Coastguard Worker):
1482*da0073e9SAndroid Build Coastguard Worker    assert GLOBALS.in_onnx_export is False
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker    if export_type is None:
1485*da0073e9SAndroid Build Coastguard Worker        export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Worker    if isinstance(model, torch.nn.DataParallel):
1488*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
1489*da0073e9SAndroid Build Coastguard Worker            "torch.nn.DataParallel is not supported by ONNX "
1490*da0073e9SAndroid Build Coastguard Worker            "exporter, please use 'attribute' module to "
1491*da0073e9SAndroid Build Coastguard Worker            "unwrap model from torch.nn.DataParallel. Try "
1492*da0073e9SAndroid Build Coastguard Worker            "torch.onnx.export(model.module, ...)"
1493*da0073e9SAndroid Build Coastguard Worker        )
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker    GLOBALS.onnx_shape_inference = onnx_shape_inference
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker    if opset_version is None:
1498*da0073e9SAndroid Build Coastguard Worker        opset_version = _constants.ONNX_DEFAULT_OPSET
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker    # torch.onnx.export does not support opset versions >=18
1501*da0073e9SAndroid Build Coastguard Worker    if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
1502*da0073e9SAndroid Build Coastguard Worker        # We do not want to fail because we should still allow users to create
1503*da0073e9SAndroid Build Coastguard Worker        # custom symbolic functions for opset>17
1504*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
1505*da0073e9SAndroid Build Coastguard Worker            f"Exporting to ONNX opset version {opset_version} is not supported. "
1506*da0073e9SAndroid Build Coastguard Worker            f"by 'torch.onnx.export()'. "
1507*da0073e9SAndroid Build Coastguard Worker            f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
1508*da0073e9SAndroid Build Coastguard Worker            f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
1509*da0073e9SAndroid Build Coastguard Worker            category=errors.OnnxExporterWarning,
1510*da0073e9SAndroid Build Coastguard Worker        )
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker    if export_modules_as_functions and opset_version < 15:
1513*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
1514*da0073e9SAndroid Build Coastguard Worker            "`export_modules_as_functions` is not supported for `opset_version` < 15."
1515*da0073e9SAndroid Build Coastguard Worker            "This is because `opset_version` < 15 implies IR version < 8, which means "
1516*da0073e9SAndroid Build Coastguard Worker            "no local function support. "
1517*da0073e9SAndroid Build Coastguard Worker        )
1518*da0073e9SAndroid Build Coastguard Worker    if not operator_export_type:
1519*da0073e9SAndroid Build Coastguard Worker        operator_export_type = _C_onnx.OperatorExportTypes.ONNX
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker    # By default, training=TrainingMode.EVAL,
1522*da0073e9SAndroid Build Coastguard Worker    # which is good because running a model in training mode could result in
1523*da0073e9SAndroid Build Coastguard Worker    # internal buffers getting updated, dropout getting applied, etc.
1524*da0073e9SAndroid Build Coastguard Worker    # If you really know what you're doing, you can turn
1525*da0073e9SAndroid Build Coastguard Worker    # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
1526*da0073e9SAndroid Build Coastguard Worker    # (to preserve whatever the original training mode was.)
1527*da0073e9SAndroid Build Coastguard Worker    GLOBALS.export_onnx_opset_version = opset_version
1528*da0073e9SAndroid Build Coastguard Worker    GLOBALS.operator_export_type = operator_export_type
1529*da0073e9SAndroid Build Coastguard Worker
1530*da0073e9SAndroid Build Coastguard Worker    try:
1531*da0073e9SAndroid Build Coastguard Worker        GLOBALS.in_onnx_export = True
1532*da0073e9SAndroid Build Coastguard Worker        _autograd_inlining_previous = GLOBALS.autograd_inlining
1533*da0073e9SAndroid Build Coastguard Worker        GLOBALS.autograd_inlining = autograd_inlining
1534*da0073e9SAndroid Build Coastguard Worker
1535*da0073e9SAndroid Build Coastguard Worker        module_typenames_to_export_as_functions: set[str] = set()
1536*da0073e9SAndroid Build Coastguard Worker        if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
1537*da0073e9SAndroid Build Coastguard Worker            module_typenames_to_export_as_functions = _setup_trace_module_map(
1538*da0073e9SAndroid Build Coastguard Worker                model, export_modules_as_functions
1539*da0073e9SAndroid Build Coastguard Worker            )
1540*da0073e9SAndroid Build Coastguard Worker
1541*da0073e9SAndroid Build Coastguard Worker        with exporter_context(model, training, verbose):
1542*da0073e9SAndroid Build Coastguard Worker            val_keep_init_as_ip = _decide_keep_init_as_input(
1543*da0073e9SAndroid Build Coastguard Worker                keep_initializers_as_inputs,
1544*da0073e9SAndroid Build Coastguard Worker                operator_export_type,
1545*da0073e9SAndroid Build Coastguard Worker                opset_version,
1546*da0073e9SAndroid Build Coastguard Worker            )
1547*da0073e9SAndroid Build Coastguard Worker            val_add_node_names = _decide_add_node_names(
1548*da0073e9SAndroid Build Coastguard Worker                add_node_names, operator_export_type
1549*da0073e9SAndroid Build Coastguard Worker            )
1550*da0073e9SAndroid Build Coastguard Worker            val_do_constant_folding = _decide_constant_folding(
1551*da0073e9SAndroid Build Coastguard Worker                do_constant_folding, operator_export_type, training
1552*da0073e9SAndroid Build Coastguard Worker            )
1553*da0073e9SAndroid Build Coastguard Worker            # Normally f can be a file-like object, but for large models, the external data format requires a
1554*da0073e9SAndroid Build Coastguard Worker            # valid `model_file_location`. Code in export.cpp will enforce this.
1555*da0073e9SAndroid Build Coastguard Worker            if isinstance(f, str):
1556*da0073e9SAndroid Build Coastguard Worker                model_file_location = f
1557*da0073e9SAndroid Build Coastguard Worker            else:
1558*da0073e9SAndroid Build Coastguard Worker                model_file_location = ""
1559*da0073e9SAndroid Build Coastguard Worker            args = _decide_input_format(model, args)
1560*da0073e9SAndroid Build Coastguard Worker            if dynamic_axes is None:
1561*da0073e9SAndroid Build Coastguard Worker                dynamic_axes = {}
1562*da0073e9SAndroid Build Coastguard Worker            _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker            graph, params_dict, torch_out = _model_to_graph(
1565*da0073e9SAndroid Build Coastguard Worker                model,
1566*da0073e9SAndroid Build Coastguard Worker                args,
1567*da0073e9SAndroid Build Coastguard Worker                verbose,
1568*da0073e9SAndroid Build Coastguard Worker                input_names,
1569*da0073e9SAndroid Build Coastguard Worker                output_names,
1570*da0073e9SAndroid Build Coastguard Worker                operator_export_type,
1571*da0073e9SAndroid Build Coastguard Worker                val_do_constant_folding,
1572*da0073e9SAndroid Build Coastguard Worker                fixed_batch_size=fixed_batch_size,
1573*da0073e9SAndroid Build Coastguard Worker                training=training,
1574*da0073e9SAndroid Build Coastguard Worker                dynamic_axes=dynamic_axes,
1575*da0073e9SAndroid Build Coastguard Worker            )
1576*da0073e9SAndroid Build Coastguard Worker
1577*da0073e9SAndroid Build Coastguard Worker            # TODO: Don't allocate a in-memory string for the protobuf
1578*da0073e9SAndroid Build Coastguard Worker            defer_weight_export = (
1579*da0073e9SAndroid Build Coastguard Worker                export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
1580*da0073e9SAndroid Build Coastguard Worker            )
1581*da0073e9SAndroid Build Coastguard Worker            if custom_opsets is None:
1582*da0073e9SAndroid Build Coastguard Worker                custom_opsets = {}
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
1585*da0073e9SAndroid Build Coastguard Worker            node_attr_to_name = {}  # type: ignore[var-annotated]
1586*da0073e9SAndroid Build Coastguard Worker            if module_typenames_to_export_as_functions:
1587*da0073e9SAndroid Build Coastguard Worker                # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
1588*da0073e9SAndroid Build Coastguard Worker                node_attr_to_name = _C._jit_pass_onnx_function_extraction(
1589*da0073e9SAndroid Build Coastguard Worker                    graph,
1590*da0073e9SAndroid Build Coastguard Worker                    module_typenames_to_export_as_functions,
1591*da0073e9SAndroid Build Coastguard Worker                    list(params_dict.keys()),
1592*da0073e9SAndroid Build Coastguard Worker                )
1593*da0073e9SAndroid Build Coastguard Worker
1594*da0073e9SAndroid Build Coastguard Worker            if keep_initializers_as_inputs is not True:
1595*da0073e9SAndroid Build Coastguard Worker                params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
1596*da0073e9SAndroid Build Coastguard Worker                    graph,
1597*da0073e9SAndroid Build Coastguard Worker                    params_dict,  # type: ignore[arg-type]
1598*da0073e9SAndroid Build Coastguard Worker                    getattr(model, "training", False),  # type: ignore[arg-type]
1599*da0073e9SAndroid Build Coastguard Worker                )
1600*da0073e9SAndroid Build Coastguard Worker            _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
1601*da0073e9SAndroid Build Coastguard Worker            if export_params:
1602*da0073e9SAndroid Build Coastguard Worker                (
1603*da0073e9SAndroid Build Coastguard Worker                    proto,
1604*da0073e9SAndroid Build Coastguard Worker                    export_map,
1605*da0073e9SAndroid Build Coastguard Worker                    val_use_external_data_format,
1606*da0073e9SAndroid Build Coastguard Worker                    node_names,
1607*da0073e9SAndroid Build Coastguard Worker                ) = graph._export_onnx(  # type: ignore[attr-defined]
1608*da0073e9SAndroid Build Coastguard Worker                    params_dict,
1609*da0073e9SAndroid Build Coastguard Worker                    opset_version,
1610*da0073e9SAndroid Build Coastguard Worker                    dynamic_axes,
1611*da0073e9SAndroid Build Coastguard Worker                    defer_weight_export,
1612*da0073e9SAndroid Build Coastguard Worker                    operator_export_type,
1613*da0073e9SAndroid Build Coastguard Worker                    not verbose,
1614*da0073e9SAndroid Build Coastguard Worker                    val_keep_init_as_ip,
1615*da0073e9SAndroid Build Coastguard Worker                    custom_opsets,
1616*da0073e9SAndroid Build Coastguard Worker                    val_add_node_names,
1617*da0073e9SAndroid Build Coastguard Worker                    model_file_location,
1618*da0073e9SAndroid Build Coastguard Worker                    node_attr_to_name,
1619*da0073e9SAndroid Build Coastguard Worker                )
1620*da0073e9SAndroid Build Coastguard Worker            else:
1621*da0073e9SAndroid Build Coastguard Worker                (
1622*da0073e9SAndroid Build Coastguard Worker                    proto,
1623*da0073e9SAndroid Build Coastguard Worker                    export_map,
1624*da0073e9SAndroid Build Coastguard Worker                    val_use_external_data_format,
1625*da0073e9SAndroid Build Coastguard Worker                    node_names,
1626*da0073e9SAndroid Build Coastguard Worker                ) = graph._export_onnx(  # type: ignore[attr-defined]
1627*da0073e9SAndroid Build Coastguard Worker                    {},
1628*da0073e9SAndroid Build Coastguard Worker                    opset_version,
1629*da0073e9SAndroid Build Coastguard Worker                    dynamic_axes,
1630*da0073e9SAndroid Build Coastguard Worker                    False,
1631*da0073e9SAndroid Build Coastguard Worker                    operator_export_type,
1632*da0073e9SAndroid Build Coastguard Worker                    not verbose,
1633*da0073e9SAndroid Build Coastguard Worker                    val_keep_init_as_ip,
1634*da0073e9SAndroid Build Coastguard Worker                    custom_opsets,
1635*da0073e9SAndroid Build Coastguard Worker                    val_add_node_names,
1636*da0073e9SAndroid Build Coastguard Worker                    model_file_location,
1637*da0073e9SAndroid Build Coastguard Worker                    node_attr_to_name,
1638*da0073e9SAndroid Build Coastguard Worker                )
1639*da0073e9SAndroid Build Coastguard Worker            # insert function_proto into model_proto.
1640*da0073e9SAndroid Build Coastguard Worker            proto = onnx_proto_utils._add_onnxscript_fn(
1641*da0073e9SAndroid Build Coastguard Worker                proto,
1642*da0073e9SAndroid Build Coastguard Worker                custom_opsets,
1643*da0073e9SAndroid Build Coastguard Worker            )
1644*da0073e9SAndroid Build Coastguard Worker            if verbose:
1645*da0073e9SAndroid Build Coastguard Worker                torch.onnx.log("Exported graph: ", graph)
1646*da0073e9SAndroid Build Coastguard Worker            onnx_proto_utils._export_file(proto, f, export_type, export_map)
1647*da0073e9SAndroid Build Coastguard Worker    finally:
1648*da0073e9SAndroid Build Coastguard Worker        assert GLOBALS.in_onnx_export
1649*da0073e9SAndroid Build Coastguard Worker        GLOBALS.in_onnx_export = False
1650*da0073e9SAndroid Build Coastguard Worker        GLOBALS.autograd_inlining = _autograd_inlining_previous
1651*da0073e9SAndroid Build Coastguard Worker        _reset_trace_module_map()
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker    return torch_out
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker
def _apply_friendly_debug_names(graph, params):
    """Rename auto-numbered input values after the node that consumes them.

    An input value whose debug name is exactly its unique id rendered as a
    string is renamed to ``"<node kind>_<unique id>"``, using the first node
    (in ``graph.nodes()`` order) that takes it as an input. Values whose
    debug name differs from their unique id are left alone.

    Args:
        graph: Object exposing ``nodes()``; each node exposes ``inputs()``
            and ``kind()``.
        params: Mapping keyed by debug name. It is mutated in place: an entry
            stored under a renamed value's old name is moved to the new name.
    """
    for node in graph.nodes():
        for value in node.inputs():
            current_name = value.debugName()
            unique_id = value.unique()
            if current_name == str(unique_id):
                friendly_name = f"{node.kind()}_{unique_id}"
                value.setDebugName(friendly_name)
                if current_name in params:
                    # Keep the parameter reachable under the value's new name.
                    params[friendly_name] = params.pop(current_name)
1666*da0073e9SAndroid Build Coastguard Worker
1667*da0073e9SAndroid Build Coastguard Worker
def _set_input_and_output_names(graph, input_names, output_names):
    """Apply caller-supplied debug names to the graph's inputs and outputs.

    Names are matched positionally. Values past the end of a name list keep
    their current names, and a ``None`` name list leaves that side untouched.

    Args:
        graph: Graph whose ``inputs()`` / ``outputs()`` values are renamed.
        input_names: Names for the graph inputs, or ``None``.
        output_names: Names for the graph outputs, or ``None``.

    Raises:
        RuntimeError: If more names are given than there are inputs (or
            outputs) to receive them.
    """

    def set_names(node_list, name_list, descriptor):
        if name_list is None:
            return

        provided, available = len(name_list), len(node_list)
        if provided > available:
            raise RuntimeError(
                "number of %s names provided (%d) exceeded number of %ss (%d)"
                % (descriptor, provided, descriptor, available)
            )

        renaming_outputs = descriptor == "output"
        # Output values that have already been visited in this pass.
        seen_outputs = set()
        for position, (wanted_name, value) in enumerate(zip(name_list, node_list)):
            if renaming_outputs:
                if value in seen_outputs:
                    # The same value is returned more than once. Route this
                    # occurrence through an onnx::Identity so it has a value
                    # of its own to carry the additional debug name.
                    identity = graph.create("onnx::Identity")
                    identity.insertAfter(value.node())
                    identity.addInput(value)
                    identity.output().setType(value.type())
                    graph.return_node().replaceInput(position, identity.output())
                    value = identity.output()
                seen_outputs.add(value)

            if value.debugName() != wanted_name:
                value.setDebugName(wanted_name)

    set_names(list(graph.inputs()), input_names, "input")
    set_names(list(graph.outputs()), output_names, "output")
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker
def _run_symbolic_method(g, op_name, symbolic_fn, args):
    r"""Call ``symbolic_fn`` with a fresh graph context.

    This trampoline function gets invoked for every symbolic method
    call from C++.

    Args:
        g: Graph being built; it and ``g.block()`` are handed to the
            ``GraphContext``.
        op_name: Operator name, used only to annotate ``TypeError`` messages.
        symbolic_fn: Callable invoked as ``symbolic_fn(graph_context, *args)``.
        args: Positional arguments forwarded to ``symbolic_fn``.

    Returns:
        Whatever ``symbolic_fn`` returns.

    Raises:
        TypeError: Re-raised with ``op_name`` appended to its message.
    """
    try:
        graph_context = jit_utils.GraphContext(
            graph=g,
            block=g.block(),
            opset=GLOBALS.export_onnx_opset_version,
            original_node=None,  # type: ignore[arg-type]
            params_dict=_params_dict,
            env={},
            values_in_env=set(),
            new_nodes=[],
        )
        return symbolic_fn(graph_context, *args)
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch
        # to symbolic_fn.  Otherwise, the backtrace will have the clues
        # you need.
        # A TypeError raised without arguments has empty ``args``; indexing
        # it would raise IndexError and hide the error we want to report.
        original_message = e.args[0] if e.args else ""
        e.args = (f"{original_message} (occurred when translating {op_name})",)
        raise
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker
def _add_block(node: _C.Node) -> _C.Block:
    """Call ``node.addBlock()`` and return the block it produces."""
    new_block = node.addBlock()
    return new_block
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker
def _add_input_to_block(block: _C.Block):
    """Call ``block.addInputToBlock()`` and return its result."""
    new_input = block.addInputToBlock()  # type: ignore[attr-defined]
    return new_input
1730*da0073e9SAndroid Build Coastguard Worker
1731*da0073e9SAndroid Build Coastguard Worker
def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
    """Register ``value`` as an output of ``block``.

    Returns:
        The integer returned by ``block.registerOutput(value)``.
    """
    registered = block.registerOutput(value)
    return registered
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker
def _should_aten_fallback(
    name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
    """Decide whether the operator ``name`` should be emitted as an ATen op.

    Args:
        name: Namespaced operator name, e.g. ``"aten::add"``.
        opset_version: Opset version used for the registry lookup.
        operator_export_type: Export mode requested by the caller.

    Returns:
        ``True`` when ``name`` starts with ``"aten::"`` and either the export
        mode is ``ONNX_ATEN``, or the mode is ``ONNX_ATEN_FALLBACK`` and the
        registry reports no registered op for ``name`` at ``opset_version``.
        ``False`` otherwise.
    """
    # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
    #   an aten::ATen operator is created regardless of symbolics existence

    # Cheap guards first: only the ONNX_ATEN_FALLBACK branch depends on the
    # registry, so the lookup is skipped for every other combination.
    if not name.startswith("aten::"):
        return False

    if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN:
        return True

    if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return not registration.registry.is_registered_op(name, opset_version)

    return False
1755*da0073e9SAndroid Build Coastguard Worker
1756*da0073e9SAndroid Build Coastguard Worker
def _get_aten_op_overload_name(n: _C.Node) -> str:
    """Return the overload name parsed from the schema of an ATen node.

    Returns:
        The ``overload_name`` of the parsed schema when the node's schema
        string starts with ``"aten::"``; the empty string otherwise.
    """
    # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
    schema_text = n.schema()
    if schema_text.startswith("aten::"):
        return _C.parse_schema(schema_text).overload_name
    return ""
1763*da0073e9SAndroid Build Coastguard Worker
1764*da0073e9SAndroid Build Coastguard Worker
def _run_symbolic_function(
    graph: _C.Graph,
    block: _C.Block,
    node: _C.Node,
    inputs: Any,
    env: dict[_C.Value, _C.Value],
    values_in_env: set[_C.Value],
    new_nodes: list[_C.Node],
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
) -> _C.Value | Sequence[_C.Value | None] | None:
    """Runs a symbolic function.

    The function is used in C++ to export the node to ONNX.

    Args:
        graph: Forwarded to the ``GraphContext`` as ``graph``.
        block: Forwarded to the ``GraphContext`` as ``block``.
        node: The node being exported; its kind selects the symbolic function
            and its attributes become keyword arguments.
        inputs: Unpacked as the positional arguments of the symbolic function
            (or of the emitted op).
        env: Forwarded to the ``GraphContext`` as ``env``.
        values_in_env: Forwarded to the ``GraphContext`` as ``values_in_env``.
        new_nodes: Forwarded to the ``GraphContext`` as ``new_nodes``.
        operator_export_type: Export mode; selects direct ATen export and how
            a ``RuntimeError`` raised during dispatch is handled.

    Returns:
        A single or a tuple of Values.
        None when the node gets cloned as is into the new graph.

    Raises:
        errors.UnsupportedOperatorError: When no symbolic function is
            available for the opset and the namespace is not ``"onnx"``.
            NOTE(review): if this error type derives from ``RuntimeError`` it
            is intercepted by the ``except RuntimeError`` handler below first
            -- confirm against ``torch.onnx.errors``.
        TypeError: Re-raised with the operator name appended to its message.
    """

    def suffixed_attrs():
        # Attribute dict whose keys are "<name>_<first character of
        # node.kindOf(name)>". This is the form handed to graph_context.op()
        # and graph_context.aten_op(); registered symbolic functions instead
        # receive the attributes under their plain names (see below).
        return {
            k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
            for k in node.attributeNames()
        }

    opset_version = GLOBALS.export_onnx_opset_version

    # See Note [Export inplace]
    node_kind = node.kind()
    if node_kind.endswith("_"):
        # Treat relu_ -> relu; add_ -> add etc.
        ns_op_name = node_kind[:-1]
    else:
        ns_op_name = node_kind

    namespace, op_name = jit_utils.parse_node_kind(ns_op_name)

    graph_context = jit_utils.GraphContext(
        graph=graph,
        block=block,
        opset=opset_version,
        original_node=node,
        params_dict=_params_dict,
        env=env,
        values_in_env=values_in_env,
        new_nodes=new_nodes,
    )

    # Direct ATen export requested
    if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
        attrs = suffixed_attrs()
        attrs["outputs"] = node.outputsSize()
        return graph_context.aten_op(
            op_name,
            *inputs,
            overload_name=_get_aten_op_overload_name(node),
            **attrs,
        )

    try:
        symbolic_function_name = f"{namespace}::{op_name}"

        symbolic_function_group = registration.registry.get_function_group(
            symbolic_function_name
        )
        if symbolic_function_group is not None:
            symbolic_fn = symbolic_function_group.get(opset_version)
            if symbolic_fn is not None:
                # Symbolic functions take the node attributes as keyword
                # arguments under their plain, un-suffixed names.
                attrs = {
                    k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
                }
                return symbolic_fn(graph_context, *inputs, **attrs)

        attrs = suffixed_attrs()
        if namespace == "onnx":
            # Clone node to trigger ONNX shape inference
            return graph_context.op(
                op_name, *inputs, **attrs, outputs=node.outputsSize()
            )  # type: ignore[attr-defined]

        raise errors.UnsupportedOperatorError(
            symbolic_function_name,
            opset_version,
            symbolic_function_group.get_min_supported()
            if symbolic_function_group
            else None,
        )

    except RuntimeError:
        if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
            # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
            return graph_context.aten_op(
                op_name,
                *inputs,
                overload_name=_get_aten_op_overload_name(node),
                **suffixed_attrs(),
            )
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        # A TypeError raised without arguments has empty ``args``; indexing
        # it would raise IndexError and hide the error we want to report.
        original_message = e.args[0] if e.args else ""
        e.args = (f"{original_message} \n(Occurred when translating {op_name}).",)
        raise
1877*da0073e9SAndroid Build Coastguard Worker
1878*da0073e9SAndroid Build Coastguard Worker
def _verify_custom_op_name(symbolic_name: str):
    """Validate the name under which a custom symbolic is being registered.

    Args:
        symbolic_name: Operator name expected in ``domain::name`` form.

    Raises:
        errors.OnnxExporterError: If ``symbolic_name`` does not match the
            ``domain::name`` pattern checked below.
        ValueError: If the namespace parsed from ``symbolic_name`` is
            ``"onnx"``.
    """
    # fullmatch rather than match with a trailing "$": "$" also matches just
    # before a final newline, so a name such as "dom::op\n" used to pass.
    if not re.fullmatch(
        r"[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*", symbolic_name
    ):
        raise errors.OnnxExporterError(
            f"Failed to register operator {symbolic_name}. "
            "The symbolic name must match the format domain::name, "
            "and should start with a letter and contain only "
            "alphanumerical characters"
        )

    ns, _ = jit_utils.parse_node_kind(symbolic_name)
    if ns == "onnx":
        raise ValueError(
            f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
        )
1893*da0073e9SAndroid Build Coastguard Worker
1894*da0073e9SAndroid Build Coastguard Worker
def register_custom_op_symbolic(
    symbolic_name: str,
    symbolic_fn: Callable,
    opset_version: int,
):
    """Registers a symbolic function for a custom operator.

    Authors of symbolics for custom/contrib ops are strongly encouraged to
    also provide shape inference for the operator through the setType API;
    without it the exported graph may carry incorrect shape inference in some
    extreme cases. `test_aten_embedding_2` in `test_operators.py` shows an
    example of setType.

    See "Custom Operators" in the module documentation for an example usage.

    Args:
        symbolic_name (str): The name of the custom operator in "<domain>::<op>"
            format. A name that starts with "::" is registered under the
            "aten" domain.
        symbolic_fn (Callable): A function that takes in the ONNX graph and
            the input arguments to the current operator, and returns new
            operator nodes to add to the graph.
        opset_version (int): The ONNX opset version in which to register.
    """
    # An empty domain ("::op") is shorthand for the "aten" domain.
    qualified_name = (
        f"aten{symbolic_name}" if symbolic_name.startswith("::") else symbolic_name
    )

    _verify_custom_op_name(qualified_name)

    register = registration.custom_onnx_symbolic(qualified_name, opset_version)
    register(symbolic_fn)
1923*da0073e9SAndroid Build Coastguard Worker
1924*da0073e9SAndroid Build Coastguard Worker
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
    """Remove the symbolic function registered under ``symbolic_name``.

    See "Custom Operators" in the module documentation for an example usage.

    A name that starts with "::" (empty domain) is prefixed with "aten"
    before it is validated and looked up in the registry.

    Args:
        symbolic_name (str): The name of the custom operator in "<domain>::<op>"
            format.
        opset_version (int): The ONNX opset version in which to unregister.
    """
    has_empty_domain = symbolic_name.startswith("::")
    if has_empty_domain:
        # "::op" is shorthand for "aten::op".
        symbolic_name = "aten" + symbolic_name

    # Same name validation as on registration; raises for unacceptable names.
    _verify_custom_op_name(symbolic_name)

    registry = registration.registry
    registry.unregister(symbolic_name, opset_version)
1941*da0073e9SAndroid Build Coastguard Worker
1942*da0073e9SAndroid Build Coastguard Worker
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
    """Ensures the dynamic axes argument follows the expected format.

    ``dynamic_axes`` is checked and normalized **in place**:

    * Every key that is not one of the known input/output names triggers a
      warning (the entry is still kept).
    * Every value given as a ``list`` of axis indices is replaced by a dict
      mapping each index to an auto-generated name of the form
      ``"<key>_dynamic_axes_<n>"``, where ``n`` is the 1-based position of the
      index in the original list. Repeated indices trigger a warning and only
      the first occurrence is kept.

    Args:
        dynamic_axes: Mapping from input/output name to either a list of axis
            indices or a dict of axis index to axis name. An empty (or
            ``None``) value means there is nothing to validate.
        model: The model being exported. When it has a ``graph`` attribute,
            missing/empty ``input_names`` / ``output_names`` are filled in
            from the debug names of the graph inputs / outputs.
        input_names: Sequence of input names, or ``None``.
        output_names: Sequence of output names, or ``None``.

    Raises:
        ValueError: If a list-style value contains a non-integer axis index.
    """
    if not dynamic_axes:
        return

    if hasattr(model, "graph"):
        # Extracting set of valid input/output names that shall be used for dynamic_axes
        if not input_names:
            input_names = [x.debugName() for x in model.graph.inputs()]
        if not output_names:
            output_names = [y.debugName() for y in model.graph.outputs()]

    # Build the set from each sequence separately instead of concatenating
    # them with ``+``: the two arguments may be different sequence types
    # (e.g. a tuple and a list), which cannot be added together.
    valid_names = {*(input_names or ()), *(output_names or ())}

    # If dynamic axes are provided as a list rather than a dictionary, they are
    # first converted to a dictionary in the expected format. Since no axis
    # names were supplied in that case, names are generated automatically for
    # the dynamic axes of the specified input/output.
    # Only values of existing keys are reassigned below, so the dict does not
    # change size while it is being iterated.
    for key, value in dynamic_axes.items():
        if key not in valid_names:
            warnings.warn(
                f"Provided key {key} for dynamic axes is not a valid input/output name"
            )
        if isinstance(value, list):
            warnings.warn(
                "No names were found for specified dynamic axes of provided input. "
                f"Automatically generated names will be applied to each dynamic axes of input {key}"
            )

            value_dict = {}
            for i, x in enumerate(value):
                if not isinstance(x, int):
                    raise ValueError(
                        "The type of axis index is expected to be an integer"
                    )
                if x in value_dict:
                    warnings.warn(
                        f"Duplicate dynamic axis index {x} was provided for input {key}."
                    )
                else:
                    # The suffix is the 1-based position in the original list,
                    # so it is not renumbered when duplicates are skipped.
                    value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
            dynamic_axes[key] = value_dict
1985*da0073e9SAndroid Build Coastguard Worker
1986*da0073e9SAndroid Build Coastguard Worker
def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature:
    """Return the call signature of ``model``.

    For a ``torch.nn.Module`` instance the signature of its ``forward``
    method is inspected; any other callable is inspected directly.
    """
    if isinstance(model, torch.nn.Module):
        target = model.forward
    else:
        target = model
    return inspect.signature(target)
1991