1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Worker"""Functions to export models into the ONNX IR format. 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard WorkerThese models can be loaded with the ONNX library and then 5*da0073e9SAndroid Build Coastguard Workerconverted to models which run on other deep learning frameworks. 6*da0073e9SAndroid Build Coastguard Worker""" 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerimport contextlib 11*da0073e9SAndroid Build Coastguard Workerimport copy 12*da0073e9SAndroid Build Coastguard Workerimport inspect 13*da0073e9SAndroid Build Coastguard Workerimport re 14*da0073e9SAndroid Build Coastguard Workerimport typing 15*da0073e9SAndroid Build Coastguard Workerimport warnings 16*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, cast, Collection, Mapping, Sequence 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerimport torch 19*da0073e9SAndroid Build Coastguard Workerimport torch._C._onnx as _C_onnx 20*da0073e9SAndroid Build Coastguard Workerimport torch.jit._trace 21*da0073e9SAndroid Build Coastguard Workerimport torch.serialization 22*da0073e9SAndroid Build Coastguard Workerfrom torch import _C 23*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx import ( # noqa: F401 24*da0073e9SAndroid Build Coastguard Worker _constants, 25*da0073e9SAndroid Build Coastguard Worker _deprecation, 26*da0073e9SAndroid Build Coastguard Worker _exporter_states, 27*da0073e9SAndroid Build Coastguard Worker errors, 28*da0073e9SAndroid Build Coastguard Worker symbolic_helper, 29*da0073e9SAndroid Build Coastguard Worker) 30*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx._globals import GLOBALS 31*da0073e9SAndroid Build Coastguard 
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration


# Public API of this module; anything not listed here is internal.
__all__ = [
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "disable_apex_o2_state_dict_hook",
    "setup_onnx_logging",
    "exporter_context",
    "export",
    "model_signature",
    "warn_on_static_input_change",
    "unpack_quantized_tensor",
    "export_to_pretty_string",
    "unconvertible_ops",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
]


def is_in_onnx_export() -> bool:
    """Returns whether it is in the middle of ONNX export.

    Reads the ``GLOBALS.in_onnx_export`` flag, which the exporter sets while
    an export is in progress; symbolic functions use this to special-case
    behavior during tracing.
    """
    return GLOBALS.in_onnx_export


# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
# Skip check due to cannot import IValue from torch._C
# Populated during export with the model's parameters; read from C++
# (constant_fold.cpp) — do not remove without updating that code.
_params_dict = {}  # type: ignore[var-annotated]
@contextlib.contextmanager
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
    r"""A context manager to temporarily set the training mode of ``model``
    to ``mode``, resetting it when we exit the with-block.

    Args:
        model: Same type and meaning as ``model`` arg to :func:`export`.
        mode: Same type and meaning as ``training`` arg to :func:`export`.

    Raises:
        TypeError: If ``mode`` is not a :class:`torch.onnx.TrainingMode` enum.
    """
    if not isinstance(mode, _C_onnx.TrainingMode):
        raise TypeError(
            f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'."
        )
    originally_training: bool = False

    # ScriptFunctions (and other bare callables) have no ``training`` flag.
    if hasattr(model, "training"):
        originally_training = model.training

    # ONNX opset 12 has better support for training amenable models, with updated
    # versions of the dropout and batch_norm operators
    if mode == _C_onnx.TrainingMode.TRAINING or (
        mode == _C_onnx.TrainingMode.PRESERVE and originally_training
    ):
        GLOBALS.export_training = True
        if GLOBALS.export_onnx_opset_version < 12:
            warnings.warn(
                "You are exporting the model in training mode with onnx opset "
                f"version {GLOBALS.export_onnx_opset_version}. "
                "Opset versions lower than opset 12 will not be able to export "
                "nodes such as Dropout and BatchNorm correctly."
            )
    else:
        GLOBALS.export_training = False

    GLOBALS.training_mode = mode
    if mode == _C_onnx.TrainingMode.TRAINING:
        model.train(True)
    elif mode == _C_onnx.TrainingMode.EVAL:
        model.train(False)
    # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing

    try:
        yield
    finally:
        # Restore the original training state, unless the caller asked us to
        # PRESERVE whatever mode the model was already in.
        # (Was ``not mode == ...``; ``!=`` is the idiomatic form.)
        if hasattr(model, "training") and mode != _C_onnx.TrainingMode.PRESERVE:
            model.train(originally_training)


@contextlib.contextmanager
def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction):
    """Temporarily remove Apex O2 ``state_dict`` hooks from ``model``.

    Apex O2 hooks ``state_dict`` to return fp16 weights as fp32, so the
    exporter cannot identify them as the same tensors. Since the hook is only
    used by the optimizer, it is safe to remove it while exporting; the hooks
    are restored on exit.
    """
    if isinstance(model, torch.jit.ScriptFunction):
        # ScriptFunctions have no submodules or state_dict hooks — nothing to
        # do. (Replaces a no-op ``try: yield / finally: pass`` wrapper.)
        yield
        return

    model_hooks = {}  # type: ignore[var-annotated]
    for module in model.modules():
        for key, hook in module._state_dict_hooks.items():
            # Detected by class name so we don't need to import apex.
            if type(hook).__name__ == "O2StateDictHook":
                if module not in model_hooks:
                    model_hooks[module] = {}
                model_hooks[module][key] = hook
        if module in model_hooks:
            for key in model_hooks[module]:
                module._state_dict_hooks.pop(key)
    try:
        yield
    finally:
        # Add the hooks back
        for module, m_map in model_hooks.items():
            for key, hook in m_map.items():
                module._state_dict_hooks[key] = hook


@contextlib.contextmanager
def setup_onnx_logging(verbose: bool):
    """Enable ONNX exporter logging when ``verbose`` is True.

    Logging is left enabled if it was already on before entering; otherwise it
    is disabled again on exit.
    """
    is_originally_enabled = torch.onnx.is_onnx_log_enabled()
    if is_originally_enabled or verbose:
        torch.onnx.enable_log()
    try:
        yield
    finally:
        if not is_originally_enabled:
            torch.onnx.disable_log()


@contextlib.contextmanager
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
    """Compose all export-time context managers into one.

    Yields a tuple of the values produced by the mode, apex-hook, logging, and
    diagnostic contexts (all currently ``None`` except the diagnostic context).
    """
    with select_model_mode_for_export(
        model, mode
    ) as mode_ctx, disable_apex_o2_state_dict_hook(
        model
    ) as apex_ctx, setup_onnx_logging(
        verbose
    ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
        yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)


def _get_torch_export_args(
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None,
) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
    """Obtain the arguments for torch.onnx.export from the model and the input arguments."""
    # Legacy calling convention: when no explicit kwargs are given and the
    # last positional arg is a dict, that dict holds the named arguments.
    if not kwargs and args and isinstance(args[-1], dict):
        kwargs = args[-1]
        args = args[:-1]
    return args, kwargs
Coastguard Worker kwargs = args[-1] 171*da0073e9SAndroid Build Coastguard Worker args = args[:-1] 172*da0073e9SAndroid Build Coastguard Worker return args, kwargs 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Workerdef export( 176*da0073e9SAndroid Build Coastguard Worker model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, 177*da0073e9SAndroid Build Coastguard Worker args: tuple[Any, ...] | torch.Tensor, 178*da0073e9SAndroid Build Coastguard Worker f: str, 179*da0073e9SAndroid Build Coastguard Worker *, 180*da0073e9SAndroid Build Coastguard Worker kwargs: dict[str, Any] | None = None, 181*da0073e9SAndroid Build Coastguard Worker export_params: bool = True, 182*da0073e9SAndroid Build Coastguard Worker verbose: bool = False, 183*da0073e9SAndroid Build Coastguard Worker training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, 184*da0073e9SAndroid Build Coastguard Worker input_names: Sequence[str] | None = None, 185*da0073e9SAndroid Build Coastguard Worker output_names: Sequence[str] | None = None, 186*da0073e9SAndroid Build Coastguard Worker operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, 187*da0073e9SAndroid Build Coastguard Worker opset_version: int | None = None, 188*da0073e9SAndroid Build Coastguard Worker do_constant_folding: bool = True, 189*da0073e9SAndroid Build Coastguard Worker dynamic_axes: Mapping[str, Mapping[int, str]] 190*da0073e9SAndroid Build Coastguard Worker | Mapping[str, Sequence[int]] 191*da0073e9SAndroid Build Coastguard Worker | None = None, 192*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs: bool | None = None, 193*da0073e9SAndroid Build Coastguard Worker custom_opsets: Mapping[str, int] | None = None, 194*da0073e9SAndroid Build Coastguard Worker export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, 195*da0073e9SAndroid Build Coastguard Worker 
autograd_inlining: bool = True, 196*da0073e9SAndroid Build Coastguard Worker) -> None: 197*da0073e9SAndroid Build Coastguard Worker r"""Exports a model into ONNX format. 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker If ``model`` is not a :class:`torch.jit.ScriptModule` nor a 200*da0073e9SAndroid Build Coastguard Worker :class:`torch.jit.ScriptFunction`, this runs 201*da0073e9SAndroid Build Coastguard Worker ``model`` once in order to convert it to a TorchScript graph to be exported 202*da0073e9SAndroid Build Coastguard Worker (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support 203*da0073e9SAndroid Build Coastguard Worker for dynamic control flow as :func:`torch.jit.trace`. 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker Args: 206*da0073e9SAndroid Build Coastguard Worker model: The model to be exported. 207*da0073e9SAndroid Build Coastguard Worker args: 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker args can be structured either as: 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker 1. ONLY A TUPLE OF ARGUMENTS:: 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker args = (x, y, z) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker The tuple should contain model inputs such that ``model(*args)`` is a valid 216*da0073e9SAndroid Build Coastguard Worker invocation of the model. Any non-Tensor arguments will be hard-coded into the 217*da0073e9SAndroid Build Coastguard Worker exported model; any Tensor arguments will become inputs of the exported model, 218*da0073e9SAndroid Build Coastguard Worker in the order they occur in the tuple. 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker 2. 
A TENSOR:: 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker args = torch.Tensor([1]) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker This is equivalent to a 1-ary tuple of that Tensor. 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker args = (x, {"y": input_y, "z": input_z}) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker All but the last element of the tuple will be passed as non-keyword arguments, 231*da0073e9SAndroid Build Coastguard Worker and named arguments will be set from the last element. If a named argument is 232*da0073e9SAndroid Build Coastguard Worker not present in the dictionary, it is assigned the default value, or None if a 233*da0073e9SAndroid Build Coastguard Worker default value is not provided. 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker .. warning:: 236*da0073e9SAndroid Build Coastguard Worker This behavior will be deprecated in a future release. Please use the 237*da0073e9SAndroid Build Coastguard Worker kwargs argument instead. 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker .. note:: 240*da0073e9SAndroid Build Coastguard Worker If a dictionary is the last element of the args tuple, it will be 241*da0073e9SAndroid Build Coastguard Worker interpreted as containing named arguments. In order to pass a dict as the 242*da0073e9SAndroid Build Coastguard Worker last non-keyword arg, provide an empty dict as the last element of the args 243*da0073e9SAndroid Build Coastguard Worker tuple. 
For example, instead of:: 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker torch.onnx.export( 246*da0073e9SAndroid Build Coastguard Worker model, 247*da0073e9SAndroid Build Coastguard Worker ( 248*da0073e9SAndroid Build Coastguard Worker x, 249*da0073e9SAndroid Build Coastguard Worker # WRONG: will be interpreted as named arguments 250*da0073e9SAndroid Build Coastguard Worker {y: z}, 251*da0073e9SAndroid Build Coastguard Worker ), 252*da0073e9SAndroid Build Coastguard Worker "test.onnx.pb", 253*da0073e9SAndroid Build Coastguard Worker ) 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker Write:: 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker f: Path to the output ONNX model file. E.g. "model.onnx". 260*da0073e9SAndroid Build Coastguard Worker kwargs: Named arguments to the model. 261*da0073e9SAndroid Build Coastguard Worker export_params: If True, all parameters will 262*da0073e9SAndroid Build Coastguard Worker be exported. Set this to False if you want to export an untrained model. 263*da0073e9SAndroid Build Coastguard Worker In this case, the exported model will first take all of its parameters 264*da0073e9SAndroid Build Coastguard Worker as arguments, with the ordering as specified by ``model.state_dict().values()`` 265*da0073e9SAndroid Build Coastguard Worker verbose: if True, prints a description of the 266*da0073e9SAndroid Build Coastguard Worker model being exported to stdout. In addition, the final ONNX graph will include the 267*da0073e9SAndroid Build Coastguard Worker field ``doc_string``` from the exported model which mentions the source code locations 268*da0073e9SAndroid Build Coastguard Worker for ``model``. If True, ONNX exporter logging will be turned on. 
269*da0073e9SAndroid Build Coastguard Worker training: 270*da0073e9SAndroid Build Coastguard Worker * ``TrainingMode.EVAL``: export the model in inference mode. 271*da0073e9SAndroid Build Coastguard Worker * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is 272*da0073e9SAndroid Build Coastguard Worker False and in training mode if model.training is True. 273*da0073e9SAndroid Build Coastguard Worker * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations 274*da0073e9SAndroid Build Coastguard Worker which might interfere with training. 275*da0073e9SAndroid Build Coastguard Worker input_names (list of str, default empty list): names to assign to the 276*da0073e9SAndroid Build Coastguard Worker input nodes of the graph, in order. 277*da0073e9SAndroid Build Coastguard Worker output_names (list of str, default empty list): names to assign to the 278*da0073e9SAndroid Build Coastguard Worker output nodes of the graph, in order. 279*da0073e9SAndroid Build Coastguard Worker operator_export_type (enum, default OperatorExportTypes.ONNX): 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker .. warning:: 282*da0073e9SAndroid Build Coastguard Worker This option will be deprecated in a future release. Future exported 283*da0073e9SAndroid Build Coastguard Worker graphs will always use the default opset domain. 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops 286*da0073e9SAndroid Build Coastguard Worker (in the default opset domain). 287*da0073e9SAndroid Build Coastguard Worker * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops 288*da0073e9SAndroid Build Coastguard Worker to standard ONNX ops in the default opset domain. If unable to do so 289*da0073e9SAndroid Build Coastguard Worker (e.g. 
because support has not been added to convert a particular torch op to ONNX), 290*da0073e9SAndroid Build Coastguard Worker fall back to exporting the op into a custom opset domain without conversion. Applies 291*da0073e9SAndroid Build Coastguard Worker to `custom ops <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_ 292*da0073e9SAndroid Build Coastguard Worker as well as ATen ops. For the exported model to be usable, the runtime must support 293*da0073e9SAndroid Build Coastguard Worker these non-standard ops. 294*da0073e9SAndroid Build Coastguard Worker * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") 295*da0073e9SAndroid Build Coastguard Worker are exported as ATen ops (in opset domain "org.pytorch.aten"). 296*da0073e9SAndroid Build Coastguard Worker `ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library, so 297*da0073e9SAndroid Build Coastguard Worker this instructs the runtime to use PyTorch's implementation of these ops. 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker .. warning:: 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker Models exported this way are probably runnable only by Caffe2. 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker This may be useful if the numeric differences in implementations of operators are 304*da0073e9SAndroid Build Coastguard Worker causing large differences in behavior between PyTorch and Caffe2 (which is more 305*da0073e9SAndroid Build Coastguard Worker common on untrained models). 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op 308*da0073e9SAndroid Build Coastguard Worker (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so 309*da0073e9SAndroid Build Coastguard Worker (e.g. 
because support has not been added to convert a particular torch op to ONNX), 310*da0073e9SAndroid Build Coastguard Worker fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for 311*da0073e9SAndroid Build Coastguard Worker context. 312*da0073e9SAndroid Build Coastguard Worker For example:: 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker graph(%0 : Float): 315*da0073e9SAndroid Build Coastguard Worker %3 : int = prim::Constant[value=0]() 316*da0073e9SAndroid Build Coastguard Worker # conversion unsupported 317*da0073e9SAndroid Build Coastguard Worker %4 : Float = aten::triu(%0, %3) 318*da0073e9SAndroid Build Coastguard Worker # conversion supported 319*da0073e9SAndroid Build Coastguard Worker %5 : Float = aten::mul(%4, %0) 320*da0073e9SAndroid Build Coastguard Worker return (%5) 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker graph(%0 : Float): 325*da0073e9SAndroid Build Coastguard Worker %1 : Long() = onnx::Constant[value={0}]() 326*da0073e9SAndroid Build Coastguard Worker # not converted 327*da0073e9SAndroid Build Coastguard Worker %2 : Float = aten::ATen[operator="triu"](%0, %1) 328*da0073e9SAndroid Build Coastguard Worker # converted 329*da0073e9SAndroid Build Coastguard Worker %3 : Float = onnx::Mul(%2, %0) 330*da0073e9SAndroid Build Coastguard Worker return (%3) 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker .. warning:: 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker Models exported this way are probably runnable only by Caffe2. 
335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker opset_version (int, default 17): The version of the 337*da0073e9SAndroid Build Coastguard Worker `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_ 338*da0073e9SAndroid Build Coastguard Worker to target. Must be >= 7 and <= 17. 339*da0073e9SAndroid Build Coastguard Worker do_constant_folding: Apply the constant-folding optimization. 340*da0073e9SAndroid Build Coastguard Worker Constant-folding will replace some of the ops that have all constant inputs 341*da0073e9SAndroid Build Coastguard Worker with pre-computed constant nodes. 342*da0073e9SAndroid Build Coastguard Worker dynamic_axes: 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker By default the exported model will have the shapes of all input and output tensors 345*da0073e9SAndroid Build Coastguard Worker set to exactly match those given in ``args``. To specify axes of tensors as 346*da0073e9SAndroid Build Coastguard Worker dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or 349*da0073e9SAndroid Build Coastguard Worker ``output_names``. 350*da0073e9SAndroid Build Coastguard Worker * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a 351*da0073e9SAndroid Build Coastguard Worker list, each element is an axis index. 
352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker For example:: 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker class SumModule(torch.nn.Module): 356*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 357*da0073e9SAndroid Build Coastguard Worker return torch.sum(x, dim=1) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker torch.onnx.export( 361*da0073e9SAndroid Build Coastguard Worker SumModule(), 362*da0073e9SAndroid Build Coastguard Worker (torch.ones(2, 2),), 363*da0073e9SAndroid Build Coastguard Worker "onnx.pb", 364*da0073e9SAndroid Build Coastguard Worker input_names=["x"], 365*da0073e9SAndroid Build Coastguard Worker output_names=["sum"], 366*da0073e9SAndroid Build Coastguard Worker ) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker Produces:: 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker input { 371*da0073e9SAndroid Build Coastguard Worker name: "x" 372*da0073e9SAndroid Build Coastguard Worker ... 373*da0073e9SAndroid Build Coastguard Worker shape { 374*da0073e9SAndroid Build Coastguard Worker dim { 375*da0073e9SAndroid Build Coastguard Worker dim_value: 2 # axis 0 376*da0073e9SAndroid Build Coastguard Worker } 377*da0073e9SAndroid Build Coastguard Worker dim { 378*da0073e9SAndroid Build Coastguard Worker dim_value: 2 # axis 1 379*da0073e9SAndroid Build Coastguard Worker ... 380*da0073e9SAndroid Build Coastguard Worker output { 381*da0073e9SAndroid Build Coastguard Worker name: "sum" 382*da0073e9SAndroid Build Coastguard Worker ... 383*da0073e9SAndroid Build Coastguard Worker shape { 384*da0073e9SAndroid Build Coastguard Worker dim { 385*da0073e9SAndroid Build Coastguard Worker dim_value: 2 # axis 0 386*da0073e9SAndroid Build Coastguard Worker ... 
387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker While:: 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker torch.onnx.export( 391*da0073e9SAndroid Build Coastguard Worker SumModule(), 392*da0073e9SAndroid Build Coastguard Worker (torch.ones(2, 2),), 393*da0073e9SAndroid Build Coastguard Worker "onnx.pb", 394*da0073e9SAndroid Build Coastguard Worker input_names=["x"], 395*da0073e9SAndroid Build Coastguard Worker output_names=["sum"], 396*da0073e9SAndroid Build Coastguard Worker dynamic_axes={ 397*da0073e9SAndroid Build Coastguard Worker # dict value: manually named axes 398*da0073e9SAndroid Build Coastguard Worker "x": {0: "my_custom_axis_name"}, 399*da0073e9SAndroid Build Coastguard Worker # list value: automatic names 400*da0073e9SAndroid Build Coastguard Worker "sum": [0], 401*da0073e9SAndroid Build Coastguard Worker }, 402*da0073e9SAndroid Build Coastguard Worker ) 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker Produces:: 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker input { 407*da0073e9SAndroid Build Coastguard Worker name: "x" 408*da0073e9SAndroid Build Coastguard Worker ... 409*da0073e9SAndroid Build Coastguard Worker shape { 410*da0073e9SAndroid Build Coastguard Worker dim { 411*da0073e9SAndroid Build Coastguard Worker dim_param: "my_custom_axis_name" # axis 0 412*da0073e9SAndroid Build Coastguard Worker } 413*da0073e9SAndroid Build Coastguard Worker dim { 414*da0073e9SAndroid Build Coastguard Worker dim_value: 2 # axis 1 415*da0073e9SAndroid Build Coastguard Worker ... 416*da0073e9SAndroid Build Coastguard Worker output { 417*da0073e9SAndroid Build Coastguard Worker name: "sum" 418*da0073e9SAndroid Build Coastguard Worker ... 
419*da0073e9SAndroid Build Coastguard Worker shape { 420*da0073e9SAndroid Build Coastguard Worker dim { 421*da0073e9SAndroid Build Coastguard Worker dim_param: "sum_dynamic_axes_1" # axis 0 422*da0073e9SAndroid Build Coastguard Worker ... 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs: If True, all the 425*da0073e9SAndroid Build Coastguard Worker initializers (typically corresponding to parameters) in the 426*da0073e9SAndroid Build Coastguard Worker exported graph will also be added as inputs to the graph. If False, 427*da0073e9SAndroid Build Coastguard Worker then initializers are not added as inputs to the graph, and only 428*da0073e9SAndroid Build Coastguard Worker the non-parameter inputs are added as inputs. 429*da0073e9SAndroid Build Coastguard Worker This may allow for better optimizations (e.g. constant folding) by 430*da0073e9SAndroid Build Coastguard Worker backends/runtimes. 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker If True, `deduplicate_initializers` pass will not be executed. This means 433*da0073e9SAndroid Build Coastguard Worker initializers with duplicated values will not be deduplicated and 434*da0073e9SAndroid Build Coastguard Worker will be treated as distinct inputs to the graph. This allows different 435*da0073e9SAndroid Build Coastguard Worker input initializers to be supplied at the runtime following export. 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker If ``opset_version < 9``, initializers MUST be part of graph 438*da0073e9SAndroid Build Coastguard Worker inputs and this argument will be ignored and the behavior will be 439*da0073e9SAndroid Build Coastguard Worker equivalent to setting this argument to True. 
440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker custom_opsets (dict[str, int], default empty dict): A dict with schema: 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker * KEY (str): opset domain name 444*da0073e9SAndroid Build Coastguard Worker * VALUE (int): opset version 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker If a custom opset is referenced by ``model`` but not mentioned in this dictionary, 447*da0073e9SAndroid Build Coastguard Worker the opset version is set to 1. Only custom opset domain name and version should be 448*da0073e9SAndroid Build Coastguard Worker indicated through this argument. 449*da0073e9SAndroid Build Coastguard Worker 450*da0073e9SAndroid Build Coastguard Worker export_modules_as_functions: Flag to enable 451*da0073e9SAndroid Build Coastguard Worker exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the 452*da0073e9SAndroid Build Coastguard Worker particular types of modules to export as local functions in ONNX. 453*da0073e9SAndroid Build Coastguard Worker This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because 454*da0073e9SAndroid Build Coastguard Worker ``opset_version`` < 15 implies IR version < 8, which means no local function support. 455*da0073e9SAndroid Build Coastguard Worker Module variables will be exported as function attributes. There are two categories of function 456*da0073e9SAndroid Build Coastguard Worker attributes. 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker 1. Annotated attributes: class variables that have type annotations via 459*da0073e9SAndroid Build Coastguard Worker `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_ 460*da0073e9SAndroid Build Coastguard Worker will be exported as attributes. 
461*da0073e9SAndroid Build Coastguard Worker Annotated attributes are not used inside the subgraph of ONNX local function because 462*da0073e9SAndroid Build Coastguard Worker they are not created by PyTorch JIT tracing, but they may be used by consumers 463*da0073e9SAndroid Build Coastguard Worker to determine whether or not to replace the function with a particular fused kernel. 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker 2. Inferred attributes: variables that are used by operators inside the module. Attribute names 466*da0073e9SAndroid Build Coastguard Worker will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from 467*da0073e9SAndroid Build Coastguard Worker python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. 470*da0073e9SAndroid Build Coastguard Worker * ``True``: export all ``nn.Module`` forward calls as local function nodes. 471*da0073e9SAndroid Build Coastguard Worker * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, 472*da0073e9SAndroid Build Coastguard Worker only if the type of the ``nn.Module`` is found in the set. 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker autograd_inlining: Flag used to control whether to inline autograd functions. 475*da0073e9SAndroid Build Coastguard Worker Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker Raises: 478*da0073e9SAndroid Build Coastguard Worker :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. 
479*da0073e9SAndroid Build Coastguard Worker :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it 480*da0073e9SAndroid Build Coastguard Worker uses an operator that is not supported by the exporter. 481*da0073e9SAndroid Build Coastguard Worker :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. 482*da0073e9SAndroid Build Coastguard Worker All errors are subclasses of :class:`errors.OnnxExporterError`. 483*da0073e9SAndroid Build Coastguard Worker """ 484*da0073e9SAndroid Build Coastguard Worker if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: 485*da0073e9SAndroid Build Coastguard Worker warnings.warn( 486*da0073e9SAndroid Build Coastguard Worker "Setting `operator_export_type` to something other than default is deprecated. " 487*da0073e9SAndroid Build Coastguard Worker "The option will be removed in a future release.", 488*da0073e9SAndroid Build Coastguard Worker category=FutureWarning, 489*da0073e9SAndroid Build Coastguard Worker ) 490*da0073e9SAndroid Build Coastguard Worker if training == _C_onnx.TrainingMode.TRAINING: 491*da0073e9SAndroid Build Coastguard Worker warnings.warn( 492*da0073e9SAndroid Build Coastguard Worker "Setting `training` to something other than default is deprecated. " 493*da0073e9SAndroid Build Coastguard Worker "The option will be removed in a future release. 
Please set the training mode " 494*da0073e9SAndroid Build Coastguard Worker "before exporting the model.", 495*da0073e9SAndroid Build Coastguard Worker category=FutureWarning, 496*da0073e9SAndroid Build Coastguard Worker ) 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker args = (args,) if isinstance(args, torch.Tensor) else args 499*da0073e9SAndroid Build Coastguard Worker if kwargs is not None: 500*da0073e9SAndroid Build Coastguard Worker args = args + (kwargs,) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker _export( 503*da0073e9SAndroid Build Coastguard Worker model, 504*da0073e9SAndroid Build Coastguard Worker args, 505*da0073e9SAndroid Build Coastguard Worker f, 506*da0073e9SAndroid Build Coastguard Worker export_params, 507*da0073e9SAndroid Build Coastguard Worker verbose, 508*da0073e9SAndroid Build Coastguard Worker training, 509*da0073e9SAndroid Build Coastguard Worker input_names, 510*da0073e9SAndroid Build Coastguard Worker output_names, 511*da0073e9SAndroid Build Coastguard Worker operator_export_type=operator_export_type, 512*da0073e9SAndroid Build Coastguard Worker opset_version=opset_version, 513*da0073e9SAndroid Build Coastguard Worker do_constant_folding=do_constant_folding, 514*da0073e9SAndroid Build Coastguard Worker dynamic_axes=dynamic_axes, 515*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs=keep_initializers_as_inputs, 516*da0073e9SAndroid Build Coastguard Worker custom_opsets=custom_opsets, 517*da0073e9SAndroid Build Coastguard Worker export_modules_as_functions=export_modules_as_functions, 518*da0073e9SAndroid Build Coastguard Worker autograd_inlining=autograd_inlining, 519*da0073e9SAndroid Build Coastguard Worker ) 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker return None 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build 
Coastguard Workerdef _is_constant_tensor_list(node): 525*da0073e9SAndroid Build Coastguard Worker if node.kind() != "prim::Constant": 526*da0073e9SAndroid Build Coastguard Worker return False 527*da0073e9SAndroid Build Coastguard Worker output_type = node.output().type() 528*da0073e9SAndroid Build Coastguard Worker if output_type.isSubtypeOf(_C.ListType.ofTensors()): 529*da0073e9SAndroid Build Coastguard Worker return True 530*da0073e9SAndroid Build Coastguard Worker if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): 531*da0073e9SAndroid Build Coastguard Worker return True 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker# ONNX can't handle constants that are lists of tensors, which can 535*da0073e9SAndroid Build Coastguard Worker# get generated in constant prop. So we split them back into prim::ListConstructs 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Workerdef _split_tensor_list_constants(g, block): 539*da0073e9SAndroid Build Coastguard Worker for node in block.nodes(): 540*da0073e9SAndroid Build Coastguard Worker for subblock in node.blocks(): 541*da0073e9SAndroid Build Coastguard Worker _split_tensor_list_constants(g, subblock) 542*da0073e9SAndroid Build Coastguard Worker if _is_constant_tensor_list(node): 543*da0073e9SAndroid Build Coastguard Worker inputs = [] 544*da0073e9SAndroid Build Coastguard Worker for val in node.output().toIValue(): 545*da0073e9SAndroid Build Coastguard Worker input = g.insertConstant(val) 546*da0073e9SAndroid Build Coastguard Worker input.node().moveBefore(node) 547*da0073e9SAndroid Build Coastguard Worker input.node().copyMetadata(node) 548*da0073e9SAndroid Build Coastguard Worker inputs.append(input) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker lc = ( 551*da0073e9SAndroid Build Coastguard Worker 
g.create("prim::ListConstruct", inputs) 552*da0073e9SAndroid Build Coastguard Worker .insertBefore(node) 553*da0073e9SAndroid Build Coastguard Worker .output() 554*da0073e9SAndroid Build Coastguard Worker .setType(_C.ListType.ofTensors()) 555*da0073e9SAndroid Build Coastguard Worker ) 556*da0073e9SAndroid Build Coastguard Worker lc.node().copyMetadata(node) 557*da0073e9SAndroid Build Coastguard Worker node.output().replaceAllUsesWith(lc) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Workerdef _optimize_graph( 561*da0073e9SAndroid Build Coastguard Worker graph: _C.Graph, 562*da0073e9SAndroid Build Coastguard Worker operator_export_type: _C_onnx.OperatorExportTypes, 563*da0073e9SAndroid Build Coastguard Worker _disable_torch_constant_prop: bool = False, 564*da0073e9SAndroid Build Coastguard Worker fixed_batch_size: bool = False, 565*da0073e9SAndroid Build Coastguard Worker params_dict=None, 566*da0073e9SAndroid Build Coastguard Worker dynamic_axes=None, 567*da0073e9SAndroid Build Coastguard Worker input_names=None, 568*da0073e9SAndroid Build Coastguard Worker module=None, 569*da0073e9SAndroid Build Coastguard Worker): 570*da0073e9SAndroid Build Coastguard Worker if params_dict is None: 571*da0073e9SAndroid Build Coastguard Worker params_dict = {} 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker # Inline everything 574*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_inline(graph) 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker # Remove fork/wait nodes 577*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_inline_fork_wait(graph) 578*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 579*da0073e9SAndroid Build Coastguard Worker if GLOBALS.autograd_inlining: 580*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_autograd_function_process(graph) 581*da0073e9SAndroid Build 
Coastguard Worker _C._jit_pass_lower_all_tuples(graph) 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker # we now record some ops like ones/zeros 584*da0073e9SAndroid Build Coastguard Worker # into a trace where we previously recorded constants. 585*da0073e9SAndroid Build Coastguard Worker # use constant prop to maintain our current level of onnx support 586*da0073e9SAndroid Build Coastguard Worker # without implementing symbolics for all of them 587*da0073e9SAndroid Build Coastguard Worker if _disable_torch_constant_prop is False: 588*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_constant_propagation(graph) 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker _split_tensor_list_constants(graph, graph) 591*da0073e9SAndroid Build Coastguard Worker # run dce to eliminate dead parts of the graph that might have been 592*da0073e9SAndroid Build Coastguard Worker # left behind by things like symbolic_override 593*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_dce(graph) 594*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker # CSE should improve perf when Autocast is used with disabled cache 597*da0073e9SAndroid Build Coastguard Worker # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 598*da0073e9SAndroid Build Coastguard Worker # Must run before _C._jit_pass_erase_number_types to prevent type substitution 599*da0073e9SAndroid Build Coastguard Worker if _C._jit_pass_cse(graph): 600*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_lint(graph) 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_canonicalize_graph_fuser_ops(graph) 603*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 604*da0073e9SAndroid Build Coastguard Worker 
_C._jit_pass_peephole(graph, True) 605*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_fuse_addmm(graph) 606*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_peephole(graph, True) 609*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lower_all_tuples(graph) 610*da0073e9SAndroid Build Coastguard Worker # in _jit_pass_onnx, symbolic functions are called for each node for conversion. 611*da0073e9SAndroid Build Coastguard Worker # However, there are nodes that cannot be converted without additional context. 612*da0073e9SAndroid Build Coastguard Worker # For example, the number of outputs from split (and whether it is static or dynamic) is unknown 613*da0073e9SAndroid Build Coastguard Worker # until the point where it is unpacked by listUnpack node. 614*da0073e9SAndroid Build Coastguard Worker # This pass does a preprocess, and prepares the nodes such that enough context can be received 615*da0073e9SAndroid Build Coastguard Worker # by the symbolic function. 
616*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) 617*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_preprocess(graph) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker # onnx does not support tuples, so try to remove them 620*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 623*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_prepare_division_for_onnx(graph) 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_remove_print(graph) 626*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_preprocess_caffe2(graph) 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker symbolic_helper._quantized_ops.clear() 629*da0073e9SAndroid Build Coastguard Worker # Unpack quantized weights for conv and linear ops and insert into graph. 
630*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) 631*da0073e9SAndroid Build Coastguard Worker # onnx only supports tensors, so we turn all out number types into tensors 632*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_erase_number_types(graph) 633*da0073e9SAndroid Build Coastguard Worker if GLOBALS.onnx_shape_inference: 634*da0073e9SAndroid Build Coastguard Worker input_names = [] if input_names is None else input_names 635*da0073e9SAndroid Build Coastguard Worker dynamic_axes = {} if dynamic_axes is None else dynamic_axes 636*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) 637*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_lint(graph) 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker graph = _C._jit_pass_onnx(graph, operator_export_type) 640*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_lint(graph) 641*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_scalar_type_analysis( 644*da0073e9SAndroid Build Coastguard Worker graph, True, GLOBALS.export_onnx_opset_version 645*da0073e9SAndroid Build Coastguard Worker ) 646*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 647*da0073e9SAndroid Build Coastguard Worker 648*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_peephole( 649*da0073e9SAndroid Build Coastguard Worker graph, GLOBALS.export_onnx_opset_version, fixed_batch_size 650*da0073e9SAndroid Build Coastguard Worker ) 651*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker # graph is not a valid jit graph anymore because types have been replaced 654*da0073e9SAndroid Build Coastguard Worker # (e.g. 
int with Tensor), so it now contains operators that don't actually 655*da0073e9SAndroid Build Coastguard Worker # exist. We can't run normal dead code elimination because it'd fail trying 656*da0073e9SAndroid Build Coastguard Worker # to look up if an operator has side effects, but we can run a dead code 657*da0073e9SAndroid Build Coastguard Worker # elimination variant that doesn't need to look up if an op has side effects. 658*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 659*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 660*da0073e9SAndroid Build Coastguard Worker graph = _C._jit_pass_canonicalize(graph) 661*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_lint(graph) 662*da0073e9SAndroid Build Coastguard Worker if GLOBALS.onnx_shape_inference: 663*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_graph_shape_type_inference( 664*da0073e9SAndroid Build Coastguard Worker graph, params_dict, GLOBALS.export_onnx_opset_version 665*da0073e9SAndroid Build Coastguard Worker ) 666*da0073e9SAndroid Build Coastguard Worker 667*da0073e9SAndroid Build Coastguard Worker return graph 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Workerdef warn_on_static_input_change(input_states): 671*da0073e9SAndroid Build Coastguard Worker """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. 672*da0073e9SAndroid Build Coastguard Worker 673*da0073e9SAndroid Build Coastguard Worker We accept dictionaries and strings as ONNX inputs, but they should be only for 674*da0073e9SAndroid Build Coastguard Worker configuration use. we detect here if these inputs are modified, and if so we warn 675*da0073e9SAndroid Build Coastguard Worker the user that the changes won't take effect in the traced ONNX graph. 
676*da0073e9SAndroid Build Coastguard Worker """ 677*da0073e9SAndroid Build Coastguard Worker for input, traced_input in zip(input_states[0], input_states[1]): 678*da0073e9SAndroid Build Coastguard Worker if isinstance(input, dict): 679*da0073e9SAndroid Build Coastguard Worker if list(input.keys()) != list(traced_input.keys()): 680*da0073e9SAndroid Build Coastguard Worker warning = ( 681*da0073e9SAndroid Build Coastguard Worker "We detected that you are modifying a dictionary that is an input to your " 682*da0073e9SAndroid Build Coastguard Worker "model. " 683*da0073e9SAndroid Build Coastguard Worker "Note that dictionaries are allowed as inputs in ONNX but they should be " 684*da0073e9SAndroid Build Coastguard Worker "handled with care. " 685*da0073e9SAndroid Build Coastguard Worker "Usages of dictionaries is not recommended, and should not be used except " 686*da0073e9SAndroid Build Coastguard Worker "for configuration use. " 687*da0073e9SAndroid Build Coastguard Worker "Also note that the order and values of the keys must remain the same. " 688*da0073e9SAndroid Build Coastguard Worker ) 689*da0073e9SAndroid Build Coastguard Worker warnings.warn(warning) 690*da0073e9SAndroid Build Coastguard Worker elif isinstance(input, str): 691*da0073e9SAndroid Build Coastguard Worker if input != traced_input: 692*da0073e9SAndroid Build Coastguard Worker warning = ( 693*da0073e9SAndroid Build Coastguard Worker "The model seems to have string inputs/outputs. " 694*da0073e9SAndroid Build Coastguard Worker "Note that strings will not appear as inputs/outputs of the ONNX graph. 
" 695*da0073e9SAndroid Build Coastguard Worker ) 696*da0073e9SAndroid Build Coastguard Worker warnings.warn(warning) 697*da0073e9SAndroid Build Coastguard Worker 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Workerdef _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): 700*da0073e9SAndroid Build Coastguard Worker """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" 701*da0073e9SAndroid Build Coastguard Worker return arg_value 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Workerdef _decide_keep_init_as_input( 705*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs: bool | None, 706*da0073e9SAndroid Build Coastguard Worker operator_export_type: _C_onnx.OperatorExportTypes, 707*da0073e9SAndroid Build Coastguard Worker opset_version: int, 708*da0073e9SAndroid Build Coastguard Worker): 709*da0073e9SAndroid Build Coastguard Worker """Decides whether the initializers in the graph should be listed as ONNX graph inputs. 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker This method encapsulates the logic to decide whether the initializers in the graph 712*da0073e9SAndroid Build Coastguard Worker should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). 713*da0073e9SAndroid Build Coastguard Worker If keep_initializers_as_inputs is not specified (None), then we decide whether to keep 714*da0073e9SAndroid Build Coastguard Worker initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type 715*da0073e9SAndroid Build Coastguard Worker is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other 716*da0073e9SAndroid Build Coastguard Worker export types keep initializers as input (val_keep_init_as_ip=True). 
717*da0073e9SAndroid Build Coastguard Worker If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, 718*da0073e9SAndroid Build Coastguard Worker in which case it must be ignored because for opset version <= 8, all initializers MUST be 719*da0073e9SAndroid Build Coastguard Worker part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker Special handling is needed for opset version 8 or lower, because irrespective 722*da0073e9SAndroid Build Coastguard Worker of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 723*da0073e9SAndroid Build Coastguard Worker semantics, i.e. all initializers must be listed as ONNX graph input. 724*da0073e9SAndroid Build Coastguard Worker """ 725*da0073e9SAndroid Build Coastguard Worker 726*da0073e9SAndroid Build Coastguard Worker if opset_version < 9: 727*da0073e9SAndroid Build Coastguard Worker if keep_initializers_as_inputs is False: 728*da0073e9SAndroid Build Coastguard Worker warnings.warn( 729*da0073e9SAndroid Build Coastguard Worker "Setting 'keep_initializers_as_inputs=False' for opset version" 730*da0073e9SAndroid Build Coastguard Worker "8 or lower would lead to an invalid ONNX graph. Therefore, " 731*da0073e9SAndroid Build Coastguard Worker "'keep_initializers_as_inputs=False' is ignored during export." 732*da0073e9SAndroid Build Coastguard Worker "Exported model will have initializers as graph inputs (compliant " 733*da0073e9SAndroid Build Coastguard Worker " to ONNX IR v3)." 734*da0073e9SAndroid Build Coastguard Worker ) 735*da0073e9SAndroid Build Coastguard Worker return True # i.e. 
True == initializers are part of graph input (ONNX IR v3) 736*da0073e9SAndroid Build Coastguard Worker val_keep_init_as_ip = ( 737*da0073e9SAndroid Build Coastguard Worker True if keep_initializers_as_inputs is None else keep_initializers_as_inputs 738*da0073e9SAndroid Build Coastguard Worker ) 739*da0073e9SAndroid Build Coastguard Worker if ( 740*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs is None 741*da0073e9SAndroid Build Coastguard Worker and operator_export_type is _C_onnx.OperatorExportTypes.ONNX 742*da0073e9SAndroid Build Coastguard Worker ): 743*da0073e9SAndroid Build Coastguard Worker val_keep_init_as_ip = False 744*da0073e9SAndroid Build Coastguard Worker return val_keep_init_as_ip 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Workerdef _decide_add_node_names(add_node_names, operator_export_type): 748*da0073e9SAndroid Build Coastguard Worker return _resolve_args_by_export_type( 749*da0073e9SAndroid Build Coastguard Worker "add_node_names", add_node_names, operator_export_type 750*da0073e9SAndroid Build Coastguard Worker ) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Workerdef _decide_constant_folding(do_constant_folding, operator_export_type, training): 754*da0073e9SAndroid Build Coastguard Worker do_constant_folding = _resolve_args_by_export_type( 755*da0073e9SAndroid Build Coastguard Worker "do_constant_folding", do_constant_folding, operator_export_type 756*da0073e9SAndroid Build Coastguard Worker ) 757*da0073e9SAndroid Build Coastguard Worker if do_constant_folding and ( 758*da0073e9SAndroid Build Coastguard Worker training is not None and training is not _C_onnx.TrainingMode.EVAL 759*da0073e9SAndroid Build Coastguard Worker ): 760*da0073e9SAndroid Build Coastguard Worker warnings.warn( 761*da0073e9SAndroid Build Coastguard Worker "It is recommended that constant 
folding be turned off ('do_constant_folding=False') " 762*da0073e9SAndroid Build Coastguard Worker "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " 763*da0073e9SAndroid Build Coastguard Worker "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " 764*da0073e9SAndroid Build Coastguard Worker "learnable model parameters may not translate correctly in the exported ONNX model " 765*da0073e9SAndroid Build Coastguard Worker "because constant folding mutates model parameters. Please consider " 766*da0073e9SAndroid Build Coastguard Worker "turning off constant folding or setting the training=TrainingMode.EVAL." 767*da0073e9SAndroid Build Coastguard Worker ) 768*da0073e9SAndroid Build Coastguard Worker return do_constant_folding 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Workerdef _signature(model) -> inspect.Signature: 772*da0073e9SAndroid Build Coastguard Worker should_be_callable = getattr(model, "forward", model) 773*da0073e9SAndroid Build Coastguard Worker if callable(should_be_callable): 774*da0073e9SAndroid Build Coastguard Worker return inspect.signature(should_be_callable) 775*da0073e9SAndroid Build Coastguard Worker raise ValueError("model has no forward method and is not callable") 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Workerdef _decide_input_format(model, args): 779*da0073e9SAndroid Build Coastguard Worker try: 780*da0073e9SAndroid Build Coastguard Worker sig = _signature(model) 781*da0073e9SAndroid Build Coastguard Worker except ValueError as e: 782*da0073e9SAndroid Build Coastguard Worker warnings.warn(f"{e}, skipping _decide_input_format") 783*da0073e9SAndroid Build Coastguard Worker return args 784*da0073e9SAndroid Build Coastguard Worker try: 785*da0073e9SAndroid Build Coastguard Worker ordered_list_keys = 
list(sig.parameters.keys()) 786*da0073e9SAndroid Build Coastguard Worker if ordered_list_keys[0] == "self": 787*da0073e9SAndroid Build Coastguard Worker ordered_list_keys = ordered_list_keys[1:] 788*da0073e9SAndroid Build Coastguard Worker args_dict: dict = {} 789*da0073e9SAndroid Build Coastguard Worker if isinstance(args, list): 790*da0073e9SAndroid Build Coastguard Worker args_list = args 791*da0073e9SAndroid Build Coastguard Worker elif isinstance(args, tuple): 792*da0073e9SAndroid Build Coastguard Worker args_list = list(args) 793*da0073e9SAndroid Build Coastguard Worker else: 794*da0073e9SAndroid Build Coastguard Worker args_list = [args] 795*da0073e9SAndroid Build Coastguard Worker if isinstance(args_list[-1], dict): 796*da0073e9SAndroid Build Coastguard Worker args_dict = args_list[-1] 797*da0073e9SAndroid Build Coastguard Worker args_list = args_list[:-1] 798*da0073e9SAndroid Build Coastguard Worker n_nonkeyword = len(args_list) 799*da0073e9SAndroid Build Coastguard Worker for optional_arg in ordered_list_keys[n_nonkeyword:]: 800*da0073e9SAndroid Build Coastguard Worker if optional_arg in args_dict: 801*da0073e9SAndroid Build Coastguard Worker args_list.append(args_dict[optional_arg]) 802*da0073e9SAndroid Build Coastguard Worker # Check if this arg has a default value 803*da0073e9SAndroid Build Coastguard Worker else: 804*da0073e9SAndroid Build Coastguard Worker param = sig.parameters[optional_arg] 805*da0073e9SAndroid Build Coastguard Worker if param.default != param.empty: 806*da0073e9SAndroid Build Coastguard Worker args_list.append(param.default) 807*da0073e9SAndroid Build Coastguard Worker args = args_list if isinstance(args, list) else tuple(args_list) 808*da0073e9SAndroid Build Coastguard Worker # Cases of models with no input args 809*da0073e9SAndroid Build Coastguard Worker except IndexError: 810*da0073e9SAndroid Build Coastguard Worker warnings.warn("No input args, skipping _decide_input_format") 811*da0073e9SAndroid Build Coastguard Worker 
except Exception as e: 812*da0073e9SAndroid Build Coastguard Worker warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") 813*da0073e9SAndroid Build Coastguard Worker return args 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Workerdef _from_dynamic_axes_to_dynamic_shapes( 817*da0073e9SAndroid Build Coastguard Worker model, 818*da0073e9SAndroid Build Coastguard Worker dynamic_axes: Mapping[str, Mapping[int, str]] 819*da0073e9SAndroid Build Coastguard Worker | Mapping[str, Sequence[int]] 820*da0073e9SAndroid Build Coastguard Worker | None = None, 821*da0073e9SAndroid Build Coastguard Worker input_names: Sequence[str] | None = None, 822*da0073e9SAndroid Build Coastguard Worker) -> dict[str, Any] | None: 823*da0073e9SAndroid Build Coastguard Worker """ 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker dynamic_axes examples: 826*da0073e9SAndroid Build Coastguard Worker (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} 827*da0073e9SAndroid Build Coastguard Worker (2) dynamic_axes = {"x": [0], "y": [1]} 828*da0073e9SAndroid Build Coastguard Worker 829*da0073e9SAndroid Build Coastguard Worker these will be converted to dynamic_shapes respectively: 830*da0073e9SAndroid Build Coastguard Worker (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} 831*da0073e9SAndroid Build Coastguard Worker (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker """ 834*da0073e9SAndroid Build Coastguard Worker if dynamic_axes is None: 835*da0073e9SAndroid Build Coastguard Worker return None 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker if input_names is None: 838*da0073e9SAndroid Build Coastguard 
Worker input_names_set = set() 839*da0073e9SAndroid Build Coastguard Worker else: 840*da0073e9SAndroid Build Coastguard Worker input_names_set = set(input_names) 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker dynamic_shapes: dict[str, Any | None] = {} 843*da0073e9SAndroid Build Coastguard Worker for input_name, axes in dynamic_axes.items(): 844*da0073e9SAndroid Build Coastguard Worker if input_name in input_names_set: 845*da0073e9SAndroid Build Coastguard Worker raise ValueError( 846*da0073e9SAndroid Build Coastguard Worker "Assinging new input names is not supported yet. Please use model forward signature " 847*da0073e9SAndroid Build Coastguard Worker "to specify input names in dynamix_axes." 848*da0073e9SAndroid Build Coastguard Worker ) 849*da0073e9SAndroid Build Coastguard Worker if isinstance(axes, dict): 850*da0073e9SAndroid Build Coastguard Worker dynamic_shapes[input_name] = { 851*da0073e9SAndroid Build Coastguard Worker k: torch.export.Dim(v) for k, v in axes.items() 852*da0073e9SAndroid Build Coastguard Worker } 853*da0073e9SAndroid Build Coastguard Worker elif isinstance(axes, list): 854*da0073e9SAndroid Build Coastguard Worker dynamic_shapes[input_name] = { 855*da0073e9SAndroid Build Coastguard Worker k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes 856*da0073e9SAndroid Build Coastguard Worker } 857*da0073e9SAndroid Build Coastguard Worker else: 858*da0073e9SAndroid Build Coastguard Worker raise TypeError( 859*da0073e9SAndroid Build Coastguard Worker f"dynamic_axes value must be either a dict or a list, but got {type(axes)}" 860*da0073e9SAndroid Build Coastguard Worker ) 861*da0073e9SAndroid Build Coastguard Worker # torch.export.export needs static dim to present in dynamic_shapes 862*da0073e9SAndroid Build Coastguard Worker # for all input tensors, so we need to add them with None 863*da0073e9SAndroid Build Coastguard Worker try: 864*da0073e9SAndroid Build Coastguard Worker sig = _signature(model) 
865*da0073e9SAndroid Build Coastguard Worker except ValueError as e: 866*da0073e9SAndroid Build Coastguard Worker warnings.warn(f"{e}, skipping auto filling None on static axes...") 867*da0073e9SAndroid Build Coastguard Worker return dynamic_shapes 868*da0073e9SAndroid Build Coastguard Worker for input_name in sig.parameters.keys(): 869*da0073e9SAndroid Build Coastguard Worker if input_name not in dynamic_shapes: 870*da0073e9SAndroid Build Coastguard Worker dynamic_shapes[input_name] = None 871*da0073e9SAndroid Build Coastguard Worker return dynamic_shapes 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Workerdef _trace(func, args, operator_export_type, return_outs=False): 875*da0073e9SAndroid Build Coastguard Worker # Special case for common case of passing a single Tensor 876*da0073e9SAndroid Build Coastguard Worker if isinstance(args, torch.Tensor): 877*da0073e9SAndroid Build Coastguard Worker args = (args,) 878*da0073e9SAndroid Build Coastguard Worker 879*da0073e9SAndroid Build Coastguard Worker trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( 880*da0073e9SAndroid Build Coastguard Worker func, 881*da0073e9SAndroid Build Coastguard Worker args, 882*da0073e9SAndroid Build Coastguard Worker strict=False, 883*da0073e9SAndroid Build Coastguard Worker _force_outplace=False, 884*da0073e9SAndroid Build Coastguard Worker _return_inputs_states=True, 885*da0073e9SAndroid Build Coastguard Worker ) 886*da0073e9SAndroid Build Coastguard Worker warn_on_static_input_change(inputs_states) 887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) 889*da0073e9SAndroid Build Coastguard Worker if return_outs: 890*da0073e9SAndroid Build Coastguard Worker return trace_graph, torch_out 891*da0073e9SAndroid Build Coastguard Worker return trace_graph 892*da0073e9SAndroid Build 
Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Workerdef _trace_and_get_graph_from_model(model, args): 895*da0073e9SAndroid Build Coastguard Worker # A basic sanity check: make sure the state_dict keys are the same 896*da0073e9SAndroid Build Coastguard Worker # before and after running the model. Fail fast! 897*da0073e9SAndroid Build Coastguard Worker orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker # Disable Autocast cache because it replaces kernel's weight and bias 900*da0073e9SAndroid Build Coastguard Worker # by (undesired) constants. 901*da0073e9SAndroid Build Coastguard Worker # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 902*da0073e9SAndroid Build Coastguard Worker prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() 903*da0073e9SAndroid Build Coastguard Worker torch.set_autocast_cache_enabled(False) 904*da0073e9SAndroid Build Coastguard Worker trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( 905*da0073e9SAndroid Build Coastguard Worker model, 906*da0073e9SAndroid Build Coastguard Worker args, 907*da0073e9SAndroid Build Coastguard Worker strict=False, 908*da0073e9SAndroid Build Coastguard Worker _force_outplace=False, 909*da0073e9SAndroid Build Coastguard Worker _return_inputs_states=True, 910*da0073e9SAndroid Build Coastguard Worker ) 911*da0073e9SAndroid Build Coastguard Worker torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker warn_on_static_input_change(inputs_states) 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): 916*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 917*da0073e9SAndroid 
Build Coastguard Worker "state_dict changed after running the tracer; " 918*da0073e9SAndroid Build Coastguard Worker "something weird is happening in your model!" 919*da0073e9SAndroid Build Coastguard Worker ) 920*da0073e9SAndroid Build Coastguard Worker 921*da0073e9SAndroid Build Coastguard Worker return trace_graph, torch_out 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Workerdef _get_param_count_list(method_graph, args_params): 925*da0073e9SAndroid Build Coastguard Worker param_count_list = [] 926*da0073e9SAndroid Build Coastguard Worker for input_, arg_params_ in zip(method_graph.inputs(), args_params): 927*da0073e9SAndroid Build Coastguard Worker if "PackedParams" in str(input_.type()): 928*da0073e9SAndroid Build Coastguard Worker in_vars, _ = torch.jit._flatten(arg_params_) 929*da0073e9SAndroid Build Coastguard Worker param_count_list.append(len(in_vars)) 930*da0073e9SAndroid Build Coastguard Worker else: 931*da0073e9SAndroid Build Coastguard Worker param_count_list.append(arg_params_ is not None) 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker return param_count_list 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Workerdef _check_flatten_did_not_remove(original, jit_flattened): 937*da0073e9SAndroid Build Coastguard Worker """torch.jit._flatten removes None. 
Check if it did so in this case.""" 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker def flatten(x): 940*da0073e9SAndroid Build Coastguard Worker if isinstance(x, (list, tuple)): 941*da0073e9SAndroid Build Coastguard Worker for inner in x: 942*da0073e9SAndroid Build Coastguard Worker yield from flatten(inner) 943*da0073e9SAndroid Build Coastguard Worker elif isinstance(x, dict): 944*da0073e9SAndroid Build Coastguard Worker for inner in x.values(): 945*da0073e9SAndroid Build Coastguard Worker yield from flatten(inner) 946*da0073e9SAndroid Build Coastguard Worker else: 947*da0073e9SAndroid Build Coastguard Worker yield x 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker flattened_with_none = list(flatten(original)) 950*da0073e9SAndroid Build Coastguard Worker num_none = len(flattened_with_none) - len(jit_flattened) 951*da0073e9SAndroid Build Coastguard Worker assert num_none >= 0 952*da0073e9SAndroid Build Coastguard Worker if num_none: 953*da0073e9SAndroid Build Coastguard Worker raise ValueError( 954*da0073e9SAndroid Build Coastguard Worker f"args contained {num_none} None's after flattening. " 955*da0073e9SAndroid Build Coastguard Worker "When exporting a ScriptModule or ScriptFunction, no args may " 956*da0073e9SAndroid Build Coastguard Worker "be None because that breaks type propagation." 
def _create_jit_graph(
    model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any]
) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]:
    """Build a TorchScript IR graph for ``model``.

    For ScriptModule/ScriptFunction the pre-compiled graph is used directly
    (no tracing runs, so ``torch_out`` is None); any other model is traced
    with ``args``.

    Returns:
        ``(graph, params, torch_out, module)`` where ``module`` is the frozen
        ScriptModule for ScriptModule inputs and None otherwise.
    """
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        # Script paths forbid None args: flattening drops them and that breaks
        # type propagation (see _check_flatten_did_not_remove).
        _check_flatten_did_not_remove(args, flattened_args)
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            # Freeze the module to inline attributes while keeping parameters
            # listable, then collect them alongside the user args.
            freezed_module = _C._freeze_module(
                cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    # Eager module path: trace, then rename the trailing graph inputs after the
    # state_dict entries they correspond to (parameters come after user inputs).
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None


def _get_named_param_dict(graph, params):
    """Map the graph's trailing parameter inputs (by debug name) to ``params``.

    Parameters occupy the last ``len(params)`` graph inputs.
    """
    input_and_param_names = [val.debugName() for val in graph.inputs()]
    param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
    _params_dict = dict(zip(param_names, params))
    return _params_dict


def _get_example_outputs(model, args):
    """Run ``model`` on a deep copy of ``args`` and normalize outputs.

    A trailing dict in ``args`` is interpreted as keyword arguments. The deep
    copy protects the caller's args from in-place mutation by the model.
    """
    input_args = copy.deepcopy(args)
    input_kwargs = {}
    if input_args and isinstance(input_args[-1], dict):
        input_kwargs = input_args[-1]
        input_args = input_args[:-1]

    example_outputs = model(*input_args, **input_kwargs)
    # A list output is wrapped as a single element (one output value); any
    # other non-tuple output is normalized to a 1-tuple.
    if isinstance(example_outputs, list):
        example_outputs = [example_outputs]
    elif not isinstance(example_outputs, tuple):
        example_outputs = (example_outputs,)

    return example_outputs
_qtype_vtype_map = {
    torch.quint8: torch.uint8,
    torch.qint8: torch.int8,
    torch.qint32: torch.int32,
    torch.quint4x2: torch.int8,
}


def unpack_quantized_tensor(value, cast_onnx_accepted=True):
    """Unpack a quantized tensor into ``(int_repr, scale, zero_point)``.

    Non-quantized values are returned unchanged as a 1-tuple. When
    ``cast_onnx_accepted`` is True the scale/zero_point are cast to the dtypes
    ONNX accepts (double / int64); otherwise scale is float32 and zero_point
    uses the tensor's underlying integer dtype.
    """
    # Guard clause: anything that is not a recognized quantized tensor passes
    # through untouched.
    if not (isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map):
        return (value,)

    vtype = _qtype_vtype_map[value.dtype]
    scale_dtype = torch.double if cast_onnx_accepted else torch.float32
    zero_point_dtype = torch.int64 if cast_onnx_accepted else vtype

    q_scale = torch.tensor(value.q_scale(), dtype=scale_dtype)
    q_zero_point = torch.tensor(value.q_zero_point(), dtype=zero_point_dtype)
    # Recover the integer representation: dequantized / scale + zero_point,
    # then cast down to the underlying storage dtype.
    q_value = (value.dequantize() / q_scale + q_zero_point).to(dtype=vtype)
    return q_value, q_scale, q_zero_point
def _pre_trace_quant_model(model, args):
    r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
    original model.

    This is due to https://github.com/pytorch/pytorch/issues/75761.
    """
    # getattr(model, "modules", list) falls back to an empty iterable for
    # objects that are not nn.Modules, so the any() is simply False for them.
    if any(
        hasattr(m, "_packed_params") for m in getattr(model, "modules", list)()
    ) or any(getattr(arg, "is_quantized", False) for arg in args):
        return torch.jit.trace(model, args)
    return model


def _model_to_graph(
    model,
    args,
    verbose=False,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    do_constant_folding=True,
    _disable_torch_constant_prop=False,
    fixed_batch_size=False,
    training=_C_onnx.TrainingMode.EVAL,
    dynamic_axes=None,
) -> tuple[
    _C.Graph,
    dict[str, torch.Tensor],
    torch.Tensor
    | tuple[torch.Tensor, ...]
    | list[torch.Tensor]
    | dict[str, torch.Tensor]
    | Any
    | None,
]:
    """Converts model into an ONNX graph.

    Returns:
        graph: A TorchScript IR Graph with ONNX nodes.
        params_dict: Dict from input param name to param value.
        torch_out: The output tensors resulting from the trace of ``model``.
            If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
            this will be None, since we are not doing any tracing.
    """
    # TODO: can we simplify this to always return a tuple of Tensor or None?

    # Special case for common case of passing a single Tensor
    if isinstance(args, (torch.Tensor, int, float, bool)):
        args = (args,)

    # Quantized models must be traced up-front (see _pre_trace_quant_model).
    model = _pre_trace_quant_model(model, args)
    graph, params, torch_out, module = _create_jit_graph(model, args)
    params_dict = _get_named_param_dict(graph, params)

    try:
        graph = _optimize_graph(
            graph,
            operator_export_type,
            _disable_torch_constant_prop=_disable_torch_constant_prop,
            fixed_batch_size=fixed_batch_size,
            params_dict=params_dict,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            module=module,
        )
    except Exception as e:
        # Dump the IR graph to the ONNX log before propagating the failure,
        # so users can see how far lowering got.
        torch.onnx.log("Torch IR graph at exception: ", graph)
        raise

    is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
    if is_script:
        # Script path produced no torch_out during graph construction; run the
        # model on (a copy of) args to recover example outputs for shape
        # assignment, unpacking any quantized outputs into (value, scale, zp).
        example_outputs = _get_example_outputs(model, args)
        example_outputs_final = ()
        for example_output in example_outputs:
            example_outputs_final += unpack_quantized_tensor(example_output)
        out_vars, desc = torch.jit._flatten(example_outputs_final)
        _C._jit_pass_onnx_assign_output_shape(
            graph,
            out_vars,
            desc,
            GLOBALS.onnx_shape_inference,
            is_script,
            GLOBALS.export_onnx_opset_version,
        )

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    else:
        if not isinstance(torch_out, (list, tuple)):
            output_wrapped = [torch_out]
        else:
            output_wrapped = torch_out  # type: ignore[assignment]

        output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped))
        # assign_output_shape pass is not compatible with quantized outputs.
        # Quantized outputs are flattened to 3 values in ONNX, while packed as
        # single value in PyTorch.
        if not any(getattr(out, "is_quantized", False) for out in output_tensors):
            _C._jit_pass_onnx_assign_output_shape(
                graph,
                output_tensors,
                out_desc,
                GLOBALS.onnx_shape_inference,
                is_script,
                GLOBALS.export_onnx_opset_version,
            )

    _set_input_and_output_names(graph, input_names, output_names)
    # Recompute after renaming: debug names may have changed above.
    params_dict = _get_named_param_dict(graph, params)

    if (
        do_constant_folding
        and GLOBALS.export_onnx_opset_version
        >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
    ):
        if training is None or training == _C_onnx.TrainingMode.EVAL:
            params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)

        params_dict = _C._jit_pass_onnx_constant_fold(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )

    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if GLOBALS.export_onnx_opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)

    # If output names lack a proper name and are identified only by their unique
    # give them a legible name for debugging purposes
    _apply_friendly_debug_names(graph, params_dict)

    return graph, params_dict, torch_out
@torch._disable_dynamo
@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
def export_to_pretty_string(
    model,
    args,
    export_params=True,
    verbose=False,
    training=_C_onnx.TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    export_type=None,
    google_printer=False,
    opset_version=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    add_node_names=True,
    do_constant_folding=True,
    dynamic_axes=None,
):
    """Similar to :func:`export`, but returns a text representation of the ONNX model.

    Only differences in args listed below. All other args are the same
    as :func:`export`.

    Args:
        add_node_names (bool, default True): Whether or not to set
            NodeProto.name. This makes no difference unless
            ``google_printer=True``.
        google_printer (bool, default False): If False, will return a custom,
            compact representation of the model. If True will return the
            protobuf's `Message::DebugString()`, which is more verbose.

    Returns:
        A UTF-8 str containing a human-readable representation of the ONNX model.
    """
    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET
    if custom_opsets is None:
        custom_opsets = {}
    # NOTE(review): these GLOBALS mutations are not restored afterwards; callers
    # see the last opset/export type set here — matches export()'s behavior.
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with exporter_context(model, training, verbose):
        # Resolve tri-state options (None means "decide from export type").
        val_keep_init_as_ip = _decide_keep_init_as_input(
            keep_initializers_as_inputs, operator_export_type, opset_version
        )
        val_add_node_names = _decide_add_node_names(
            add_node_names, operator_export_type
        )
        val_do_constant_folding = _decide_constant_folding(
            do_constant_folding, operator_export_type, training
        )
        args = _decide_input_format(model, args)
        graph, params_dict, torch_out = _model_to_graph(
            model,
            args,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            val_do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return graph._pretty_print_onnx(  # type: ignore[attr-defined]
            params_dict,
            opset_version,
            False,
            operator_export_type,
            google_printer,
            val_keep_init_as_ip,
            custom_opsets,
            val_add_node_names,
        )
@_deprecation.deprecated("2.5", "the future", "avoid using this function")
def unconvertible_ops(
    model,
    args,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    opset_version: int | None = None,
) -> tuple[_C.Graph, list[str]]:
    """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`.

    The list is approximated because some ops may be removed during the conversion
    process and don't need to be converted. Some other ops may have partial support
    that will fail conversion with particular inputs. Please open a Github Issue
    for op support requests.

    Args:
        model: Same as the `model` parameter in :func:`torch.onnx.export`.
        args: Same as the `args` parameter in :func:`torch.onnx.export`.
        training: Same as the `training` parameter in :func:`torch.onnx.export`.
        opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`.

    Returns:
        The JIT graph and a list of unconvertible ops in the format of "domain::op".
    """
    opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET
    GLOBALS.export_onnx_opset_version = opset_version

    try:
        with exporter_context(model, training, verbose=False):
            # Build a mostly clean JIT graph containing the plain aten and
            # other ops that can be checked against the symbolic registry by
            # name alone.  No ops are actually converted and no symbolic
            # functions run: a failing pass or an unconvertible op could
            # otherwise corrupt the graph mid-conversion.
            args = _decide_input_format(model, args)
            model = _pre_trace_quant_model(model, args)
            graph, _, _, module = _create_jit_graph(model, args)
            _C._jit_pass_inline(graph)
            _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
            _C._jit_pass_erase_number_types(graph)
            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    except Exception as e:
        raise errors.OnnxExporterError(
            "Failed to discover unconvertible ops because of errors during the JIT graph "
            "generation process."
        ) from e

    # "onnx::"/"prim::" kinds count as supported even without a symbolic
    # function, since conversion passes may eliminate them.  Registered ops
    # count as supported too, even when support is only partial — there is no
    # good way yet to check for full support.
    # TODO(justinchuby): Create a way to check if an op is fully supported.
    unsupported_ops = [
        kind
        for kind in (node.kind() for node in graph.nodes())
        if not kind.startswith(("onnx::", "prim::"))
        and not registration.registry.is_registered_op(
            kind.rstrip("_"), opset_version
        )
    ]
    return graph, unsupported_ops
Worker if not tracing_state: 1367*da0073e9SAndroid Build Coastguard Worker return 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker graph = tracing_state.graph() 1370*da0073e9SAndroid Build Coastguard Worker onnx_attrs = {} 1371*da0073e9SAndroid Build Coastguard Worker if hasattr(module, attr_name): 1372*da0073e9SAndroid Build Coastguard Worker onnx_attrs = getattr(module, attr_name) 1373*da0073e9SAndroid Build Coastguard Worker delattr(module, attr_name) 1374*da0073e9SAndroid Build Coastguard Worker 1375*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker for m in model.modules(): 1378*da0073e9SAndroid Build Coastguard Worker m.register_forward_hook(_track_module_attributes_forward_hook) 1379*da0073e9SAndroid Build Coastguard Worker m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) 1380*da0073e9SAndroid Build Coastguard Worker 1381*da0073e9SAndroid Build Coastguard Worker def _unqualified_variable_name(qualified_name: str) -> str: 1382*da0073e9SAndroid Build Coastguard Worker """ 1383*da0073e9SAndroid Build Coastguard Worker Parse qualified variable name and return the unqualified version. 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker Pure numeric atoms are considered inadequate, so this function will look past them, 1386*da0073e9SAndroid Build Coastguard Worker and start from the first non-numeric atom. 
1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Worker Example: 1389*da0073e9SAndroid Build Coastguard Worker >>> _unqualified_variable_name("__main__.Foo.bar") 1390*da0073e9SAndroid Build Coastguard Worker 'bar' 1391*da0073e9SAndroid Build Coastguard Worker >>> _unqualified_variable_name("__main__.Foo.bar.0") 1392*da0073e9SAndroid Build Coastguard Worker 'bar.0' 1393*da0073e9SAndroid Build Coastguard Worker """ 1394*da0073e9SAndroid Build Coastguard Worker name_atoms = qualified_name.split(".") 1395*da0073e9SAndroid Build Coastguard Worker for i, atom in reversed(list(enumerate(name_atoms))): 1396*da0073e9SAndroid Build Coastguard Worker if not atom.isnumeric(): 1397*da0073e9SAndroid Build Coastguard Worker return ".".join(name_atoms[i:]) 1398*da0073e9SAndroid Build Coastguard Worker return qualified_name 1399*da0073e9SAndroid Build Coastguard Worker 1400*da0073e9SAndroid Build Coastguard Worker trace_module_map = { 1401*da0073e9SAndroid Build Coastguard Worker _m: torch._C._jit_onnx_create_full_scope_name( 1402*da0073e9SAndroid Build Coastguard Worker torch.typename(type(_m)), _unqualified_variable_name(_n) 1403*da0073e9SAndroid Build Coastguard Worker ) 1404*da0073e9SAndroid Build Coastguard Worker for _n, _m in model.named_modules() 1405*da0073e9SAndroid Build Coastguard Worker } 1406*da0073e9SAndroid Build Coastguard Worker torch.jit._trace._trace_module_map = trace_module_map 1407*da0073e9SAndroid Build Coastguard Worker if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: 1408*da0073e9SAndroid Build Coastguard Worker module_typenames = {torch.typename(type(module)) for module in trace_module_map} 1409*da0073e9SAndroid Build Coastguard Worker elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker def _find_typename(v): 1412*da0073e9SAndroid Build Coastguard Worker if 
isinstance(v, type): 1413*da0073e9SAndroid Build Coastguard Worker return torch.typename(v) 1414*da0073e9SAndroid Build Coastguard Worker else: 1415*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1416*da0073e9SAndroid Build Coastguard Worker "Only type of the `nn.Module` should be " 1417*da0073e9SAndroid Build Coastguard Worker "passed in the set for argument `export_modules_as_functions`. " 1418*da0073e9SAndroid Build Coastguard Worker f"Got `{type(v).__name__}`." 1419*da0073e9SAndroid Build Coastguard Worker ) 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker module_typenames = {_find_typename(v) for v in export_modules_as_functions} 1422*da0073e9SAndroid Build Coastguard Worker else: 1423*da0073e9SAndroid Build Coastguard Worker module_typenames = set() 1424*da0073e9SAndroid Build Coastguard Worker 1425*da0073e9SAndroid Build Coastguard Worker if module_typenames: 1426*da0073e9SAndroid Build Coastguard Worker __register_attribute_hook() 1427*da0073e9SAndroid Build Coastguard Worker 1428*da0073e9SAndroid Build Coastguard Worker return module_typenames 1429*da0073e9SAndroid Build Coastguard Worker 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Workerdef _reset_trace_module_map(): 1432*da0073e9SAndroid Build Coastguard Worker torch.jit._trace._trace_module_map = None 1433*da0073e9SAndroid Build Coastguard Worker _C._jit_pass_onnx_clear_scope_records() 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Workerdef _get_module_attributes(module): 1437*da0073e9SAndroid Build Coastguard Worker annotations = typing.get_type_hints(type(module)) 1438*da0073e9SAndroid Build Coastguard Worker base_m_annotations = typing.get_type_hints(torch.nn.Module) 1439*da0073e9SAndroid Build Coastguard Worker [annotations.pop(k, None) for k in base_m_annotations] 1440*da0073e9SAndroid Build Coastguard Worker # Check 
whether module attributes can be accessed. Some classes 1441*da0073e9SAndroid Build Coastguard Worker # define attributes but don't provide access to them in their 1442*da0073e9SAndroid Build Coastguard Worker # constructor. 1443*da0073e9SAndroid Build Coastguard Worker # 1444*da0073e9SAndroid Build Coastguard Worker # For example, torch.nn.Embedding has the `freeze` variable and its 1445*da0073e9SAndroid Build Coastguard Worker # type specified in the class but the attribute is not created in the 1446*da0073e9SAndroid Build Coastguard Worker # constructor. In other words, there is no `self.freeze = <True | False>` 1447*da0073e9SAndroid Build Coastguard Worker # in the constructor. 1448*da0073e9SAndroid Build Coastguard Worker # 1449*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 1450*da0073e9SAndroid Build Coastguard Worker attrs = {} 1451*da0073e9SAndroid Build Coastguard Worker for k in annotations: 1452*da0073e9SAndroid Build Coastguard Worker try: 1453*da0073e9SAndroid Build Coastguard Worker attrs[k] = getattr(module, k) 1454*da0073e9SAndroid Build Coastguard Worker except AttributeError: 1455*da0073e9SAndroid Build Coastguard Worker torch.onnx.log(f"Skipping module attribute '{k}'") 1456*da0073e9SAndroid Build Coastguard Worker continue 1457*da0073e9SAndroid Build Coastguard Worker return attrs 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker 1460*da0073e9SAndroid Build Coastguard Workerdef _export( 1461*da0073e9SAndroid Build Coastguard Worker model, 1462*da0073e9SAndroid Build Coastguard Worker args, 1463*da0073e9SAndroid Build Coastguard Worker f, 1464*da0073e9SAndroid Build Coastguard Worker export_params=True, 1465*da0073e9SAndroid Build Coastguard Worker verbose=False, 1466*da0073e9SAndroid Build Coastguard Worker training=_C_onnx.TrainingMode.EVAL, 1467*da0073e9SAndroid Build Coastguard 
def _export(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=_C_onnx.TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    export_type=None,
    opset_version=None,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    fixed_batch_size=False,
    custom_opsets=None,
    add_node_names=True,
    onnx_shape_inference=True,
    export_modules_as_functions: Any = False,
    autograd_inlining=True,
):
    """Core TorchScript-based export: convert `model` to ONNX and write it to `f`.

    Returns the traced model outputs (`torch_out`) produced while building the
    graph.
    """
    assert GLOBALS.in_onnx_export is False

    if export_type is None:
        export_type = _exporter_states.ExportTypes.PROTOBUF_FILE

    if isinstance(model, torch.nn.DataParallel):
        raise ValueError(
            "torch.nn.DataParallel is not supported by ONNX "
            "exporter, please use 'attribute' module to "
            "unwrap model from torch.nn.DataParallel. Try "
            "torch.onnx.export(model.module, ...)"
        )

    GLOBALS.onnx_shape_inference = onnx_shape_inference

    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET

    # torch.onnx.export does not support opset versions >= 18.  Warn instead
    # of failing so users can still register custom symbolics for opset > 17.
    if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
        warnings.warn(
            f"Exporting to ONNX opset version {opset_version} is not supported. "
            f"by 'torch.onnx.export()'. "
            f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
            f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
            category=errors.OnnxExporterWarning,
        )

    if export_modules_as_functions and opset_version < 15:
        raise ValueError(
            "`export_modules_as_functions` is not supported for `opset_version` < 15."
            "This is because `opset_version` < 15 implies IR version < 8, which means "
            "no local function support. "
        )
    if not operator_export_type:
        operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    # training defaults to TrainingMode.EVAL because running a model in
    # training mode could update internal buffers, apply dropout, etc.
    # TRAINING / PRESERVE are opt-in for users who know what they are doing.
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    try:
        GLOBALS.in_onnx_export = True
        _autograd_inlining_previous = GLOBALS.autograd_inlining
        GLOBALS.autograd_inlining = autograd_inlining

        module_typenames_to_export_as_functions: set[str] = set()
        if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
            module_typenames_to_export_as_functions = _setup_trace_module_map(
                model, export_modules_as_functions
            )

        with exporter_context(model, training, verbose):
            val_keep_init_as_ip = _decide_keep_init_as_input(
                keep_initializers_as_inputs,
                operator_export_type,
                opset_version,
            )
            val_add_node_names = _decide_add_node_names(
                add_node_names, operator_export_type
            )
            val_do_constant_folding = _decide_constant_folding(
                do_constant_folding, operator_export_type, training
            )
            # f may be file-like, but the external-data format requires a real
            # path in `model_file_location`; code in export.cpp enforces this.
            model_file_location = f if isinstance(f, str) else ""
            args = _decide_input_format(model, args)
            if dynamic_axes is None:
                dynamic_axes = {}
            _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

            graph, params_dict, torch_out = _model_to_graph(
                model,
                args,
                verbose,
                input_names,
                output_names,
                operator_export_type,
                val_do_constant_folding,
                fixed_batch_size=fixed_batch_size,
                training=training,
                dynamic_axes=dynamic_axes,
            )

            # TODO: Don't allocate an in-memory string for the protobuf.
            defer_weight_export = (
                export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
            )
            if custom_opsets is None:
                custom_opsets = {}

            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
            node_attr_to_name = {}  # type: ignore[var-annotated]
            if module_typenames_to_export_as_functions:
                # NOTE: cannot run DCE after this pass — it would delete
                # function definition nodes.
                node_attr_to_name = _C._jit_pass_onnx_function_extraction(
                    graph,
                    module_typenames_to_export_as_functions,
                    list(params_dict.keys()),
                )

            if keep_initializers_as_inputs is not True:
                params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
                    graph,
                    params_dict,  # type: ignore[arg-type]
                    getattr(model, "training", False),  # type: ignore[arg-type]
                )
            _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)

            # The two historical export branches differed only in whether the
            # real initializers (and deferred weights) were passed along.
            initializers = params_dict if export_params else {}
            defer_weights = defer_weight_export if export_params else False
            (
                proto,
                export_map,
                val_use_external_data_format,
                node_names,
            ) = graph._export_onnx(  # type: ignore[attr-defined]
                initializers,
                opset_version,
                dynamic_axes,
                defer_weights,
                operator_export_type,
                not verbose,
                val_keep_init_as_ip,
                custom_opsets,
                val_add_node_names,
                model_file_location,
                node_attr_to_name,
            )
            # Insert any onnxscript function_protos into the model proto.
            proto = onnx_proto_utils._add_onnxscript_fn(
                proto,
                custom_opsets,
            )
            if verbose:
                torch.onnx.log("Exported graph: ", graph)
            onnx_proto_utils._export_file(proto, f, export_type, export_map)
    finally:
        assert GLOBALS.in_onnx_export
        GLOBALS.in_onnx_export = False
        GLOBALS.autograd_inlining = _autograd_inlining_previous
        _reset_trace_module_map()

    return torch_out
params[new_name] = params.pop(old_name) 1666*da0073e9SAndroid Build Coastguard Worker 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Workerdef _set_input_and_output_names(graph, input_names, output_names): 1669*da0073e9SAndroid Build Coastguard Worker def set_names(node_list, name_list, descriptor): 1670*da0073e9SAndroid Build Coastguard Worker if name_list is None: 1671*da0073e9SAndroid Build Coastguard Worker return 1672*da0073e9SAndroid Build Coastguard Worker if len(name_list) > len(node_list): 1673*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1674*da0073e9SAndroid Build Coastguard Worker "number of %s names provided (%d) exceeded number of %ss (%d)" 1675*da0073e9SAndroid Build Coastguard Worker % (descriptor, len(name_list), descriptor, len(node_list)) 1676*da0073e9SAndroid Build Coastguard Worker ) 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker # Mark if the output node DebugName is set before. 1679*da0073e9SAndroid Build Coastguard Worker output_node_set = set() 1680*da0073e9SAndroid Build Coastguard Worker for i, (name, node) in enumerate(zip(name_list, node_list)): 1681*da0073e9SAndroid Build Coastguard Worker # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). 
def _run_symbolic_method(g, op_name, symbolic_fn, args):
    r"""Trampoline invoked from C++ for every symbolic method call.

    Wraps ``g`` in a fresh ``GraphContext`` and dispatches to ``symbolic_fn``,
    tagging any ``TypeError`` with the operator name so failures are easier
    to attribute.
    """
    try:
        context = jit_utils.GraphContext(
            graph=g,
            block=g.block(),
            opset=GLOBALS.export_onnx_opset_version,
            original_node=None,  # type: ignore[arg-type]
            params_dict=_params_dict,
            env={},
            values_in_env=set(),
            new_nodes=[],
        )
        return context and symbolic_fn(context, *args)
    except TypeError as e:
        # A TypeError here usually means we failed to dispatch to
        # symbolic_fn; otherwise the backtrace already has the clues needed.
        e.args = (f"{e.args[0]} (occurred when translating {op_name})",)
        raise
1720*da0073e9SAndroid Build Coastguard Worker e.args = (f"{e.args[0]} (occurred when translating {op_name})",) 1721*da0073e9SAndroid Build Coastguard Worker raise 1722*da0073e9SAndroid Build Coastguard Worker 1723*da0073e9SAndroid Build Coastguard Worker 1724*da0073e9SAndroid Build Coastguard Workerdef _add_block(node: _C.Node) -> _C.Block: 1725*da0073e9SAndroid Build Coastguard Worker return node.addBlock() 1726*da0073e9SAndroid Build Coastguard Worker 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Workerdef _add_input_to_block(block: _C.Block): 1729*da0073e9SAndroid Build Coastguard Worker return block.addInputToBlock() # type: ignore[attr-defined] 1730*da0073e9SAndroid Build Coastguard Worker 1731*da0073e9SAndroid Build Coastguard Worker 1732*da0073e9SAndroid Build Coastguard Workerdef _add_output_to_block(block: _C.Block, value: _C.Value) -> int: 1733*da0073e9SAndroid Build Coastguard Worker return block.registerOutput(value) 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker 1736*da0073e9SAndroid Build Coastguard Workerdef _should_aten_fallback( 1737*da0073e9SAndroid Build Coastguard Worker name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes 1738*da0073e9SAndroid Build Coastguard Worker): 1739*da0073e9SAndroid Build Coastguard Worker # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, 1740*da0073e9SAndroid Build Coastguard Worker # an aten::ATen operator is created regardless of symbolics existence 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) 1743*da0073e9SAndroid Build Coastguard Worker is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN 1744*da0073e9SAndroid Build Coastguard Worker is_aten_fallback_export = ( 1745*da0073e9SAndroid Build Coastguard Worker 
def _get_aten_op_overload_name(n: _C.Node) -> str:
    """Returns the overload name parsed from an aten op's schema, or "" otherwise."""
    # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
    schema = n.schema()
    if not schema.startswith("aten::"):
        return ""
    return _C.parse_schema(schema).overload_name


def _run_symbolic_function(
    graph: _C.Graph,
    block: _C.Block,
    node: _C.Node,
    inputs: Any,
    env: dict[_C.Value, _C.Value],
    values_in_env: set[_C.Value],
    new_nodes: list[_C.Node],
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
) -> _C.Value | Sequence[_C.Value | None] | None:
    """Runs a symbolic function.

    The function is used in C++ to export the node to ONNX.

    Dispatch order:
        1. Direct ATen export (``aten_op``) when ``_should_aten_fallback``
           approves it for this op and export type.
        2. A symbolic function registered for the op at the current opset.
        3. For ops already in the ``onnx`` namespace, re-emit the node
           (cloning triggers ONNX shape inference).
        4. Otherwise raise ``UnsupportedOperatorError``; a ``RuntimeError``
           escaping steps 2-4 is converted per ``operator_export_type``.

    Args:
        graph: The graph being built.
        block: The block new nodes are added to.
        node: The original node being translated.
        inputs: Input values passed through to the symbolic function.
        env: Mapping from original values to values in the new graph.
        values_in_env: Values already materialized in the new graph.
        new_nodes: Accumulator list for nodes created during translation.
        operator_export_type: Controls fallthrough/ATen-fallback behavior.

    Returns:
        A single or a tuple of Values.
        None when the node gets cloned as is into the new graph.
    """

    opset_version = GLOBALS.export_onnx_opset_version

    # See Note [Export inplace]
    node_kind = node.kind()
    if node_kind.endswith("_"):
        # Treat relu_ -> relu; add_ -> add etc.
        ns_op_name = node_kind[:-1]
    else:
        ns_op_name = node_kind

    namespace, op_name = jit_utils.parse_node_kind(ns_op_name)

    graph_context = jit_utils.GraphContext(
        graph=graph,
        block=block,
        opset=opset_version,
        original_node=node,
        params_dict=_params_dict,
        env=env,
        values_in_env=values_in_env,
        new_nodes=new_nodes,
    )

    # Direct ATen export requested
    if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
        # Attribute keys are suffixed with the first character of the
        # attribute's kind string so the attribute type survives the trip.
        attrs = {
            k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
            for k in node.attributeNames()
        }
        outputs = node.outputsSize()
        attrs["outputs"] = outputs
        return graph_context.aten_op(
            op_name,
            *inputs,
            overload_name=_get_aten_op_overload_name(node),
            **attrs,
        )

    try:
        domain = namespace
        symbolic_function_name = f"{domain}::{op_name}"

        symbolic_function_group = registration.registry.get_function_group(
            symbolic_function_name
        )
        if symbolic_function_group is not None:
            symbolic_fn = symbolic_function_group.get(opset_version)
            if symbolic_fn is not None:
                # TODO Wrap almost identical attrs assignment or comment the difference.
                # NOTE: unlike the ATen paths, registered symbolic functions
                # receive attributes under their plain (unsuffixed) names.
                attrs = {
                    k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
                }
                return symbolic_fn(graph_context, *inputs, **attrs)

        attrs = {
            k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
            for k in node.attributeNames()
        }
        if namespace == "onnx":
            # Clone node to trigger ONNX shape inference
            return graph_context.op(
                op_name, *inputs, **attrs, outputs=node.outputsSize()
            )  # type: ignore[attr-defined]

        raise errors.UnsupportedOperatorError(
            symbolic_function_name,
            opset_version,
            symbolic_function_group.get_min_supported()
            if symbolic_function_group
            else None,
        )

    except RuntimeError:
        if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
            # Leave the node untranslated; the caller clones it as-is.
            return None
        elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
            # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
            attrs = {
                k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
                for k in node.attributeNames()
            }
            return graph_context.aten_op(
                op_name,
                *inputs,
                overload_name=_get_aten_op_overload_name(node),
                **attrs,
            )
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",)
        raise
" 1883*da0073e9SAndroid Build Coastguard Worker "The symbolic name must match the format domain::name, " 1884*da0073e9SAndroid Build Coastguard Worker "and should start with a letter and contain only " 1885*da0073e9SAndroid Build Coastguard Worker "alphanumerical characters" 1886*da0073e9SAndroid Build Coastguard Worker ) 1887*da0073e9SAndroid Build Coastguard Worker 1888*da0073e9SAndroid Build Coastguard Worker ns, _ = jit_utils.parse_node_kind(symbolic_name) 1889*da0073e9SAndroid Build Coastguard Worker if ns == "onnx": 1890*da0073e9SAndroid Build Coastguard Worker raise ValueError( 1891*da0073e9SAndroid Build Coastguard Worker f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." 1892*da0073e9SAndroid Build Coastguard Worker ) 1893*da0073e9SAndroid Build Coastguard Worker 1894*da0073e9SAndroid Build Coastguard Worker 1895*da0073e9SAndroid Build Coastguard Workerdef register_custom_op_symbolic( 1896*da0073e9SAndroid Build Coastguard Worker symbolic_name: str, 1897*da0073e9SAndroid Build Coastguard Worker symbolic_fn: Callable, 1898*da0073e9SAndroid Build Coastguard Worker opset_version: int, 1899*da0073e9SAndroid Build Coastguard Worker): 1900*da0073e9SAndroid Build Coastguard Worker """Registers a symbolic function for a custom operator. 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker When the user registers symbolic for custom/contrib ops, 1903*da0073e9SAndroid Build Coastguard Worker it is highly recommended to add shape inference for that operator via setType API, 1904*da0073e9SAndroid Build Coastguard Worker otherwise the exported graph may have incorrect shape inference in some extreme cases. 1905*da0073e9SAndroid Build Coastguard Worker An example of setType is `test_aten_embedding_2` in `test_operators.py`. 1906*da0073e9SAndroid Build Coastguard Worker 1907*da0073e9SAndroid Build Coastguard Worker See "Custom Operators" in the module documentation for an example usage. 
def register_custom_op_symbolic(
    symbolic_name: str,
    symbolic_fn: Callable,
    opset_version: int,
):
    """Registers a symbolic function for a custom operator.

    When the user registers symbolic for custom/contrib ops,
    it is highly recommended to add shape inference for that operator via setType API,
    otherwise the exported graph may have incorrect shape inference in some extreme cases.
    An example of setType is `test_aten_embedding_2` in `test_operators.py`.

    See "Custom Operators" in the module documentation for an example usage.

    Args:
        symbolic_name (str): The name of the custom operator in "<domain>::<op>"
            format.
        symbolic_fn (Callable): A function that takes in the ONNX graph and
            the input arguments to the current operator, and returns new
            operator nodes to add to the graph.
        opset_version (int): The ONNX opset version in which to register.
    """
    # A bare "::op" name is shorthand for the aten domain.
    if symbolic_name.startswith("::"):
        symbolic_name = "aten" + symbolic_name

    _verify_custom_op_name(symbolic_name)

    decorator = registration.custom_onnx_symbolic(symbolic_name, opset_version)
    decorator(symbolic_fn)
def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
    """Unregisters ``symbolic_name``.

    See "Custom Operators" in the module documentation for an example usage.

    Args:
        symbolic_name (str): The name of the custom operator in "<domain>::<op>"
            format.
        opset_version (int): The ONNX opset version in which to unregister.
    """
    # A bare "::op" name is shorthand for the aten domain.
    if symbolic_name.startswith("::"):
        symbolic_name = "aten" + symbolic_name

    _verify_custom_op_name(symbolic_name)

    registration.registry.unregister(symbolic_name, opset_version)
provided as a list rather than dictionary, they should 1958*da0073e9SAndroid Build Coastguard Worker # first get converted to a dictionary in expected format. If desired axes names 1959*da0073e9SAndroid Build Coastguard Worker # are not provided for dynamic axes, automatic names shall be generated for 1960*da0073e9SAndroid Build Coastguard Worker # provided dynamic axes of specified input/output 1961*da0073e9SAndroid Build Coastguard Worker for key, value in dynamic_axes.items(): 1962*da0073e9SAndroid Build Coastguard Worker if key not in valid_names: 1963*da0073e9SAndroid Build Coastguard Worker warnings.warn( 1964*da0073e9SAndroid Build Coastguard Worker f"Provided key {key} for dynamic axes is not a valid input/output name" 1965*da0073e9SAndroid Build Coastguard Worker ) 1966*da0073e9SAndroid Build Coastguard Worker if isinstance(value, list): 1967*da0073e9SAndroid Build Coastguard Worker warnings.warn( 1968*da0073e9SAndroid Build Coastguard Worker "No names were found for specified dynamic axes of provided input." 1969*da0073e9SAndroid Build Coastguard Worker f"Automatically generated names will be applied to each dynamic axes of input {key}" 1970*da0073e9SAndroid Build Coastguard Worker ) 1971*da0073e9SAndroid Build Coastguard Worker 1972*da0073e9SAndroid Build Coastguard Worker value_dict = {} 1973*da0073e9SAndroid Build Coastguard Worker for i, x in enumerate(value): 1974*da0073e9SAndroid Build Coastguard Worker if not isinstance(x, int): 1975*da0073e9SAndroid Build Coastguard Worker raise ValueError( 1976*da0073e9SAndroid Build Coastguard Worker "The type of axis index is expected to be an integer" 1977*da0073e9SAndroid Build Coastguard Worker ) 1978*da0073e9SAndroid Build Coastguard Worker if x in value_dict: 1979*da0073e9SAndroid Build Coastguard Worker warnings.warn( 1980*da0073e9SAndroid Build Coastguard Worker f"Duplicate dynamic axis index {x} was provided for input {key}." 
1981*da0073e9SAndroid Build Coastguard Worker ) 1982*da0073e9SAndroid Build Coastguard Worker else: 1983*da0073e9SAndroid Build Coastguard Worker value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1) 1984*da0073e9SAndroid Build Coastguard Worker dynamic_axes[key] = value_dict 1985*da0073e9SAndroid Build Coastguard Worker 1986*da0073e9SAndroid Build Coastguard Worker 1987*da0073e9SAndroid Build Coastguard Workerdef model_signature(model: torch.nn.Module | Callable) -> inspect.Signature: 1988*da0073e9SAndroid Build Coastguard Worker return inspect.signature( 1989*da0073e9SAndroid Build Coastguard Worker model.forward if isinstance(model, torch.nn.Module) else model 1990*da0073e9SAndroid Build Coastguard Worker ) 1991