# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 11."""

from __future__ import annotations

import functools
import sys
import warnings
from typing import Sequence

import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset10 as opset10,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

__all__ = [
    "add",
    "append",
    "arange",
    "argsort",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "cat",
    "chunk",
    "clamp_max",
    "clamp_min",
    "clamp",
    "constant_pad_nd",
    "cumsum",
    "Delete",
    "embedding_bag",
    "embedding_renorm",
    "flatten",
    "gather",
    "hardtanh",
    "hstack",
    "im2col",
    "index_fill",
    "index",
    "index_copy",
    "index_put",
    "insert",
    "linalg_det",
    "linalg_vector_norm",
    "logdet",
    "masked_scatter",
    "masked_select",
    "mm",
    "narrow",
    "normal",
    "pad",
    "pixel_shuffle",
    "pop",
    "prim_constant_chunk",
    "reflection_pad",
    "relu6",
    "remainder",
    "replication_pad",
    "round",
    "scatter",
    "select",
    "size",
    "sort",
    "split_with_sizes",
    "split",
    "squeeze",
    "stack",
    "topk",
    "unbind",
    "unique_dim",
    "unsqueeze",
    "vstack",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)


@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    min_val = g.op(
        "Constant",
        value_t=torch.tensor(min_val, dtype=scalar_type.dtype()),
    )
    max_val = g.op(
        "Constant",
        value_t=torch.tensor(max_val, dtype=scalar_type.dtype()),
    )
    return symbolic_helper._op_with_optional_float_cast(
        g, "Clip", self, min_val, max_val, opset_before=12
    )

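# A minimal sketch of how the hardtanh symbolic above is exercised (the module
# and file name below are illustrative assumptions, not part of this file):
#
#   import torch
#
#   model = torch.nn.Hardtanh(min_val=-1.0, max_val=2.0)
#   torch.onnx.export(model, (torch.randn(2, 3),), "hardtanh.onnx", opset_version=11)
#
# At opset 11 this should lower to a single ONNX Clip node whose min/max inputs
# are the Constant nodes built above.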

@_onnx_symbolic("aten::clamp")
def clamp(g: jit_utils.GraphContext, self, min, max):
    def _cast_if_not_none(tensor, dtype):
        if tensor is not None and not symbolic_helper._is_none(tensor):
            return g.op(
                "Cast",
                tensor,
                to_i=dtype.onnx_type(),
            )
        else:
            return tensor

    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        min = _cast_if_not_none(min, scalar_type)
        max = _cast_if_not_none(max, scalar_type)

    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    elif symbolic_helper._is_none(max):
        return clamp_min(g, self, min)
    else:
        if (
            symbolic_helper._get_tensor_rank(min) == 0
            and symbolic_helper._get_tensor_rank(max) == 0
        ):
            return symbolic_helper._op_with_optional_float_cast(
                g, "Clip", self, min, max, opset_before=12
            )
        else:
            return clamp_max(g, clamp_min(g, self, min), max)


@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
def clamp_min(g: jit_utils.GraphContext, self, min):
    min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
    if symbolic_helper._get_tensor_rank(min) == 0:
        max = opset9.unused(g)
        return symbolic_helper._op_with_optional_float_cast(
            g, "Clip", self, min, max, opset_before=12
        )
    else:
        return symbolic_helper._op_with_optional_float_cast(
            g, "Max", self, min, opset_before=12
        )


@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
def clamp_max(g: jit_utils.GraphContext, self, max):
    max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
    if symbolic_helper._get_tensor_rank(max) == 0:
        min = opset9.unused(g)
        return symbolic_helper._op_with_optional_float_cast(
            g, "Clip", self, min, max, opset_before=12
        )
    else:
        return symbolic_helper._op_with_optional_float_cast(
            g, "Min", self, max, opset_before=12
        )

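# A minimal sketch of a model that reaches the clamp symbolics above
# (illustrative only): scalar bounds take the single-Clip path, while
# tensor-valued bounds fall back to the Max/Min composition.
#
#   import torch
#
#   class Clamp(torch.nn.Module):
#       def forward(self, x):
#           return x.clamp(min=-0.5, max=0.5)
#
#   torch.onnx.export(Clamp(), (torch.randn(4),), "clamp.onnx", opset_version=11)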

@_onnx_symbolic("aten::relu6")
def relu6(g: jit_utils.GraphContext, input):
    scalar_type = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.FLOAT
    )
    min_val = g.op(
        "Constant",
        value_t=torch.tensor(0, dtype=scalar_type.dtype()),
    )
    max_val = g.op(
        "Constant",
        value_t=torch.tensor(6, dtype=scalar_type.dtype()),
    )
    return clamp(g, input, min_val, max_val)


@_onnx_symbolic("aten::select")
# Opset 11 gather accepts negative indices
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
def select(g: jit_utils.GraphContext, self, dim, index):
    return g.op("Gather", self, index, axis_i=dim)


@_onnx_symbolic("aten::index_put")
def index_put(
    g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False
):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if symbolic_helper._is_bool(indices_list[idx_]):
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain a single boolean input.
        #
        # index_put -> masked_fill
        #   * the index input contains a single tensor of Bool type (e.g.: %24 <- %23).
        #   * the value input contains a single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * the index input contains a single tensor of Bool type (e.g.: %32 <- %31).
        #   * the value input contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if symbolic_helper._is_bool(bool_inp):
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            mask_rank = symbolic_helper._get_tensor_rank(bool_inp)
            self_rank = symbolic_helper._get_tensor_rank(self)
            if (
                mask_rank is not None
                and self_rank is not None
                and self_rank > mask_rank
            ):
                # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'.
                bool_inp = symbolic_helper._unsqueeze_helper(
                    g, bool_inp, list(range(mask_rank, self_rank))
                )
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if "values" is a single scalar value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    self_scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_scalar_type != _type_utils.JitScalarType.UNDEFINED:
        values_scalar_type = _type_utils.JitScalarType.from_value(
            values, _type_utils.JitScalarType.UNDEFINED
        )
        if self_scalar_type != values_scalar_type:
            values = g.op("Cast", values, to_i=self_scalar_type.onnx_type())
    elif accumulate:
        raise errors.SymbolicValueError("self does not have a valid scalar type.", self)

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=self_scalar_type.dtype()),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

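# A minimal sketch of an in-place masked assignment that traces to
# aten::index_put and takes the masked_fill branch above (illustrative only):
#
#   import torch
#
#   class MaskedFill(torch.nn.Module):
#       def forward(self, x, mask):
#           y = x.clone()
#           y[mask] = 0.0
#           return y
#
#   x = torch.randn(2, 2)
#   mask = torch.tensor([[True, False], [False, True]])
#   torch.onnx.export(MaskedFill(), (x, mask), "masked_fill.onnx", opset_version=11)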

@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None and rank != 4:
        return symbolic_helper._unimplemented("pixel_shuffle", "only supports 4D input")
    return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")

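# A minimal sketch (illustrative only): PixelShuffle on an (N, C*r*r, H, W)
# input maps directly onto ONNX DepthToSpace in "CRD" (column-row-depth) mode.
#
#   import torch
#
#   model = torch.nn.PixelShuffle(upscale_factor=2)
#   x = torch.randn(1, 8, 3, 3)  # 8 == 2 * 2 * 2 channels
#   torch.onnx.export(model, (x,), "pixel_shuffle.onnx", opset_version=11)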

@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bicubic2d",
    decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")],
)
def _interpolate(name: str, dim: int, interpolate_mode: str):
    return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)


@_onnx_symbolic("aten::__interpolate")
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    return symbolic_helper.__interpolate_helper(
        g, input, size, scale_factor, mode, align_corners, recompute_scale_factor
    )


@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True")
    return g.op("GatherElements", self, index, axis_i=dim)


@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    src_type = _type_utils.JitScalarType.from_value(src)
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    else:
        # Check if the scalar "src" has the same type as self. PyTorch allows a
        # different type for a scalar src (but not when src is a tensor); if the
        # types differ, insert a Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )
        return g.op(
            "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim
        )

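# A minimal sketch (illustrative only): aten::scatter with a tensor src maps
# to ONNX ScatterElements; a scalar src is expanded to the index shape first.
#
#   import torch
#
#   class Scatter(torch.nn.Module):
#       def forward(self, x, index, src):
#           return x.scatter(1, index, src)
#
#   x = torch.zeros(2, 4)
#   index = torch.tensor([[0, 1], [2, 3]])
#   src = torch.ones(2, 2)
#   torch.onnx.export(Scatter(), (x, index, src), "scatter.onnx", opset_version=11)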

@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
    dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        cast = g.op(
            "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    else:
        cast = self
    csum = g.op("CumSum", cast, dim_tensor)
    return csum

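# A minimal sketch (illustrative only): ONNX CumSum is available from opset 11,
# so torch.cumsum exports as a single node here.
#
#   import torch
#
#   class CumSum(torch.nn.Module):
#       def forward(self, x):
#           return torch.cumsum(x, dim=1)
#
#   torch.onnx.export(CumSum(), (torch.randn(2, 5),), "cumsum.onnx", opset_version=11)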

@_onnx_symbolic("aten::masked_select")
def masked_select(g: jit_utils.GraphContext, self, mask):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    return g.op("GatherND", self, index)


@_onnx_symbolic("aten::masked_scatter")
def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice the source tensor.
    source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    source = symbolic_helper._slice_helper(
        g,
        source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=opset9.size(g, index, torch.LongTensor([0])),
    )
    return g.op("ScatterND", self, index, source)

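# A minimal sketch (illustrative only): masked_scatter becomes NonZero +
# ScatterND, with the source flattened and trimmed to the number of True
# entries in the mask.
#
#   import torch
#
#   class MaskedScatter(torch.nn.Module):
#       def forward(self, x, mask, source):
#           return x.masked_scatter(mask, source)
#
#   x = torch.zeros(2, 3)
#   mask = torch.tensor([[True, False, True], [False, True, False]])
#   source = torch.arange(6.0)
#   torch.onnx.export(
#       MaskedScatter(), (x, mask, source), "masked_scatter.onnx", opset_version=11
#   )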

@_onnx_symbolic("aten::len")
def _len(g: jit_utils.GraphContext, self):
    if (
        symbolic_helper._is_tensor_list(self)
        or self.node().kind() == "onnx::SplitToSequence"
    ):
        return g.op("SequenceLength", self)
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, sz_0, [0])


@_onnx_symbolic("aten::__getitem_")
def __getitem_(g: jit_utils.GraphContext, self, i):
    if symbolic_helper._is_tensor_list(self):
        # SequenceAt requires that the input be a List of Tensors
        return g.op("SequenceAt", self, i)
    else:
        from torch.onnx.symbolic_opset9 import __getitem_ as getitem

        return getitem(g, self, i)


@_onnx_symbolic("aten::_set_item")
def _set_item(g: jit_utils.GraphContext, tensor_list, i, v):
    tensor_list = g.op("SequenceErase", tensor_list, i)
    return g.op("SequenceInsert", tensor_list, v, i)


@_onnx_symbolic("aten::append")
def append(g: jit_utils.GraphContext, self, tensor):
    return g.op("SequenceInsert", self, tensor)


@_onnx_symbolic("aten::add")
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
        tensor_list_node = other.node()
        if tensor_list_node.kind() != "prim::ListConstruct":
            return symbolic_helper._unimplemented(
                "add", "does not support adding dynamic tensor list to another"
            )
        tensors = symbolic_helper._unpack_list(other)
        l = self
        for t in tensors:
            l = g.op("SequenceInsert", l, t)
        return l

    return opset9.add(g, self, other, alpha)


@_onnx_symbolic("aten::insert")
def insert(g: jit_utils.GraphContext, self, pos, tensor):
    return g.op("SequenceInsert", self, tensor, pos)


@_onnx_symbolic("aten::pop")
def pop(g: jit_utils.GraphContext, tensor_list, dim):
    return g.op("SequenceErase", tensor_list, dim)


@_onnx_symbolic("aten::Delete")
def Delete(g: jit_utils.GraphContext, tensor_list, dim):
    return g.op("SequenceErase", tensor_list, dim)


@_onnx_symbolic("aten::cat")
@symbolic_helper.quantized_args(True)
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    if symbolic_helper._is_packed_list(tensor_list):
        return opset9.cat(g, tensor_list, dim)
    else:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim)


@_onnx_symbolic("aten::stack")
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    if symbolic_helper._is_packed_list(tensor_list):
        return opset9.stack(g, tensor_list, dim)
    else:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1)


@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
    u, indices, inverse_indices, counts = g.op(
        "Unique", self, sorted_i=sorted, outputs=4
    )
    return u, inverse_indices, counts


@_onnx_symbolic("aten::unique_dim")
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
def unique_dim(
    g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
):
    u, indices, inverse_indices, counts = g.op(
        "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
    )
    return u, inverse_indices, counts


@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    return symbolic_helper._topk_helper(
        g, self, k, dim, largest=largest, sorted=sorted, out=out
    )


@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)


@_onnx_symbolic("aten::argsort")
@symbolic_helper.parse_args("v", "i", "i", "none")
def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    _, indices = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return indices


@_onnx_symbolic("aten::round")
@symbolic_helper.parse_args("v", "i")
def round(g: jit_utils.GraphContext, self, decimals=0):
    if not symbolic_helper._is_fp(self):
        return self
    if decimals == 0:
        return g.op("Round", self)
    mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals))))
    round = g.op("Round", mul)
    return g.op(
        "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals)))
    )


@_onnx_symbolic("aten::remainder")
def remainder(g: jit_utils.GraphContext, input, other):
    if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
        return opset9.remainder(g, input, other)
    return g.op("Mod", input, other, fmod_i=0)


@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                end = g.op(
                    "Add", start, split_sizes[i]
                )  # split_sizes is a list of same length as _outputs
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]
    else:
        return opset9.split(g, self, split_size_or_sizes, dim, _outputs)

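# A minimal sketch (illustrative only): with statically known sizes this takes
# the opset 9 path; otherwise it emits the SplitToSequence / Slice nodes above.
#
#   import torch
#
#   class Split(torch.nn.Module):
#       def forward(self, x):
#           return torch.split(x, [2, 3], dim=0)
#
#   torch.onnx.export(Split(), (torch.randn(5, 4),), "split.onnx", opset_version=11)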

@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )
    else:
        return opset9.unbind(g, self, dim, _outputs)


def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
    """Generate paddings in ONNX order based on the pad values in PyTorch.

    Args:
        input: the input tensor.
        pad: the paddings in PyTorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].
    """
    if (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    ):
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # PyTorch's "pad" covers only the trailing dimensions, so extend it with
    # zeros; after the flip below, the appended zeros land on the leading dimensions.
    pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (dim * 2 - len(pad))
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
    extension = g.op(
        "Sub",
        g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
        pad_len,
    )
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64)
    paddings = g.op(
        "Concat",
        pad,
        g.op(
            "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
        ),
        axis_i=0,
    )
    # Reshape, reverse the order, and collate first all beginnings and then all ends:
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #               [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
    )
    paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
    )
    padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64)
    return padding_c


@_onnx_symbolic("aten::constant_pad_nd")
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
    mode = "constant"
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(value, input)
    pad = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, pad, value, mode_s=mode)


@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    mode = "reflect"
    paddings = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, paddings, mode_s=mode)


@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
def replication_pad(g: jit_utils.GraphContext, input, padding):
    mode = "edge"
    paddings = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, paddings, mode_s=mode)


@_onnx_symbolic("aten::pad")
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    mode = symbolic_helper._parse_arg(mode, "s")
    if mode == "replicate":
        return replication_pad(g, input, pad)
    elif mode == "reflect":
        return reflection_pad(g, input, pad)
    elif mode == "constant":
        return constant_pad_nd(g, input, pad, value)
    elif mode == "circular":
        return opset9._pad_circular(g, input, pad)
    else:
        raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)

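# A worked example of the reordering above (values are illustrative): for a
# 4-D input and PyTorch pad = (1, 2, 3, 4) (last dim begin/end = 1/2, second
# to last = 3/4), the ONNX Pad input becomes
#   [0, 0, 3, 1, 0, 0, 4, 2]
# i.e. all per-dimension "begins" first, then all "ends". A sketch that
# exercises the constant path:
#
#   import torch
#
#   class Pad(torch.nn.Module):
#       def forward(self, x):
#           return torch.nn.functional.pad(x, (1, 2, 3, 4), mode="constant", value=0.0)
#
#   torch.onnx.export(Pad(), (torch.randn(1, 1, 5, 5),), "pad.onnx", opset_version=11)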

@_onnx_symbolic("aten::linalg_det")
def linalg_det(g: jit_utils.GraphContext, self):
    return g.op("Det", self)


@_onnx_symbolic("aten::logdet")
def logdet(g: jit_utils.GraphContext, input):
    return opset9.log(g, linalg_det(g, input))


@_onnx_symbolic("aten::arange")
def arange(g: jit_utils.GraphContext, *args):
    def _get_arange_dtype(dtype):
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype

    if len(args) == 2 and all(isinstance(val, int) for val in args):
        # aten::arange(Scalar start, Scalar end)
        dtype = torch.int64
        # Start index.
        start = g.op(
            "Constant",
            value_t=torch.tensor(args[0], dtype=dtype),
        )
        # End (exclusive) index.
        end = g.op(
            "Constant",
            value_t=torch.tensor(args[1], dtype=dtype),
        )
        # Step size from start to end indexes.
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=dtype),
        )
        return g.op("Range", start, end, delta_default)
    elif len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        type_, end, start, step = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        start_default = g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=type_.dtype()),
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start_default, end, delta_default)
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        _, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        return g.op("Range", start, end, step)
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        type_, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start, end, delta_default)
    else:
        return symbolic_helper._unimplemented(
            "aten::arange", f"with {len(args)} arguments"
        )

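# A minimal sketch (illustrative only): opset 11 can express torch.arange
# directly as an ONNX Range node.
#
#   import torch
#
#   class Arange(torch.nn.Module):
#       def forward(self, x):
#           return x + torch.arange(0, 10, 2, dtype=x.dtype)
#
#   torch.onnx.export(Arange(), (torch.randn(5),), "arange.onnx", opset_version=11)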

@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    like_shape = g.op("Shape", like)
    stop = g.op(
        "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
    )
    return arange(g, stop, 4, None, None, None)


@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
def size(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Shape", self)
    return symbolic_helper._size_helper(g, self, dim)


@_onnx_symbolic("aten::squeeze")
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Squeeze", self)

    # dim as a tensor
    if not symbolic_helper._is_constant(dim):
        return symbolic_helper._squeeze_helper(g, self, [dim])

    dim = symbolic_helper._get_const(dim, "i", "dim")

    input_rank = symbolic_helper._get_tensor_rank(self)
    adjusted_dim = dim
    if input_rank is not None and dim < 0:
        adjusted_dim += input_rank
    dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim)
    if (dim < 0 and input_rank is None) or dim_size is None:
        # If ONNX shape inference is not on, always export as dynamic, because
        # we cannot tell whether the observed static shape is also static at runtime.
        # Create the "cond" node (condition is shape[i] == 1).
        dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
        size = symbolic_helper._size_helper(g, self, dim_constant)
        const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        cond = g.op("Equal", size, const_one)
        # Create the "If" node and add the "then" and "else" blocks to it.
        if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
            g, "If", cond, n_blocks=2
        )
        squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim])
        utils._add_output_to_block(if_context.block, squeeze_)
        identity_ = else_context.op("Identity", self)
        utils._add_output_to_block(else_context.block, identity_)
        return if_op

    # For static input shape
    dim = adjusted_dim
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please export with dynamic_axes argument."
        )
        return self
    return symbolic_helper._squeeze_helper(g, self, [dim])

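# A minimal sketch (illustrative only): squeezing a statically known singleton
# dimension exports as a plain Squeeze; when the size of that dimension is
# unknown, the If-based graph above is emitted instead.
#
#   import torch
#
#   class Squeeze(torch.nn.Module):
#       def forward(self, x):
#           return x.squeeze(1)
#
#   torch.onnx.export(Squeeze(), (torch.randn(3, 1, 5),), "squeeze.onnx", opset_version=11)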

@_onnx_symbolic("aten::unsqueeze")
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    if symbolic_helper._is_constant(dim):
        dim = symbolic_helper._get_const(dim, "i", "dim")

    return symbolic_helper._unsqueeze_helper(g, self, [dim])


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::index")
def index(g: jit_utils.GraphContext, self, index):
    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not symbolic_helper._is_none(index) and (
            symbolic_helper._is_bool(index)
            or _type_utils.JitScalarType.from_value(index)
            == _type_utils.JitScalarType.UINT8
        ):
            index = opset9.nonzero(g, index)
            return g.op("GatherND", self, index)
    return opset9.index(g, self, index)


@_onnx_symbolic("aten::index_fill")
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    dim_value = symbolic_helper._parse_arg(dim, "i")
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(value, self)
    expanded_value = opset9.expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)


@_onnx_symbolic("aten::index_copy")
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    dim_value = symbolic_helper._parse_arg(dim, "i")
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, expanded_index, source)


@_onnx_symbolic("aten::bitwise_right_shift")
@_onnx_symbolic("aten::__rshift_")
def __rshift_(g: jit_utils.GraphContext, self, other):
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(self):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )

    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.UINT8
    ):
        return g.op("BitShift", self, other, direction_s="RIGHT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
    )
    rshift = g.op("Div", self, two_pow)
    return rshift


@_onnx_symbolic("aten::bitwise_left_shift")
@_onnx_symbolic("aten::__lshift_")
def __lshift_(g: jit_utils.GraphContext, self, other):
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(self):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )

    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.UINT8
    ):
        return g.op("BitShift", self, other, direction_s="LEFT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
    )
    lshift = g.op("Mul", self, two_pow)
    return lshift


def _get_im2col_indices_along_dim(
    g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):
    # Input is always 4-D (N, C, H, W)
    # Calculate indices of sliding blocks along spatial dimension
    # Slide kernel over input along each dim d:
    # each dimension d ranges from 0 to
    # input[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1),
    # with steps = stride

    blocks_d = g.op(
        "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2))
    )
    blocks_d = g.op(
        "Sub",
        blocks_d,
        g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))),
    )

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = g.op(
        "Range",
        g.op("Constant", value_t=torch.tensor(0)),
        blocks_d,
        g.op("Constant", value_t=torch.tensor(stride_d)),
    )

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d)
    kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0))

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    blocks_d_indices = symbolic_helper._unsqueeze_helper(
        g, blocks_d_indices, [0]
    )  # Reshape to [1, -1]
    kernel_mask = symbolic_helper._reshape_helper(
        g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1]))
    )
    block_mask = g.op("Add", blocks_d_indices, kernel_mask)

    return block_mask


def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w):
    # Input is always 4-D tensor (N, C, H, W)
    # Padding tensor has the following format: (padding_h, padding_w)
    # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
    pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
    return g.op("Pad", input, pad)


def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w):
    batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    channel_unfolded = g.op(
        "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
    )

    return g.op(
        "Concat",
        symbolic_helper._unsqueeze_helper(g, batch_dim, [0]),
        symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]),
        g.op("Constant", value_t=torch.tensor([-1])),
        axis_i=0,
    )


@_onnx_symbolic("aten::im2col")
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride):
    # Input is always 4-D tensor (N, C, H, W)
    # All other args are int[2]

    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    blocks_row_indices = _get_im2col_indices_along_dim(
        g, input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    blocks_col_indices = _get_im2col_indices_along_dim(
        g, input_w, kernel_w, dilation_w, padding_w, stride_w
    )

    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4-D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6.]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
    #  [[[1., 2., 4., 5.],
    #    [2., 3., 5., 6.],
    #    [4., 5., 7., 8.],
    #    [5., 6., 8., 9.]]]
    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
    return symbolic_helper._reshape_helper(g, output, output_shape)

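# A minimal sketch (illustrative only): torch.nn.functional.unfold lowers
# through aten::im2col, so it exercises the Gather/Transpose/Reshape pattern
# above.
#
#   import torch
#
#   class Unfold(torch.nn.Module):
#       def forward(self, x):
#           return torch.nn.functional.unfold(x, kernel_size=2, stride=1)
#
#   torch.onnx.export(Unfold(), (torch.randn(1, 1, 3, 3),), "unfold.onnx", opset_version=11)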

@_onnx_symbolic("aten::narrow")
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    end = g.op("Add", start, length)
    return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end)


@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    dim = symbolic_helper._get_tensor_rank(input)
    if dim == 1:
        return input
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim == 1:
        if end_dim == -1 or (dim is not None and end_dim == dim - 1):
            return g.op("Flatten", input, axis_i=start_dim)
    elif start_dim == 0:
        if end_dim == -2 or (dim is not None and end_dim == dim - 2):
            return g.op("Flatten", input, axis_i=end_dim + 1)
    if dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )
    # if end_dim is negative add dim
    if end_dim < 0:
        end_dim = dim + end_dim

    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)


@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self,
    ord,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype,
):
    return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    return symbolic_helper._embedding_bag_helper(
        g,
        embedding_matrix,
        indices,
        offsets,
        scale_grad_by_freq,
        mode,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    )


@_onnx_symbolic("aten::embedding_renorm")
@symbolic_helper.parse_args("v", "v", "f", "f")
def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
    unique_indices = g.op("Unique", indices)
    partial_weight = g.op("Gather", weight, unique_indices)
    norm_i = int(norm_type)
    if norm_i == 1:
        norm_type = "ReduceL1"
    elif norm_i == 2:
        norm_type = "ReduceL2"
    else:
        raise errors.SymbolicValueError(
            f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. "
            "Only 1. and 2. are supported.",
            weight,
        )
    partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1)
    # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177
    # Add 1e-7 to prevent division by zero.
    partial_weight_norm_ = g.op(
        "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7))
    )
    max_norm = torch.tensor(max_norm)
    scales = g.op("Div", max_norm, partial_weight_norm_)
    partial_weight_renorm = g.op("Mul", partial_weight, scales)
    partial_weight_renorm = g.op(
        "Where",
        g.op("Greater", partial_weight_norm, max_norm),
        partial_weight_renorm,
        partial_weight,
    )
    return g.op(
        "ScatterND",
        weight,
        symbolic_helper._unsqueeze_helper(g, unique_indices, [1]),
        partial_weight_renorm,
    )


@_onnx_symbolic("aten::chunk")
def chunk(g: jit_utils.GraphContext, self, chunks, dim):
    # Calculate chunk size for dynamic chunk
    dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
    chunk_size_s = g.op(
        "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long))
    )
    chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks)
    # Create splits vector
    chunk_vec = [
        opset9.expand(g, chunk_size, chunk_size_s, None),
        g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)),
    ]
    chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
    return split(g, self, chunk_vec, dim)

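# A minimal sketch (illustrative only): under tracing, torch.chunk with a
# constant chunk count typically records prim::ConstantChunk (handled by
# prim_constant_chunk further below), while the aten::chunk symbolic above is
# reached from scripted modules with a dynamic chunk count.
#
#   import torch
#
#   class Chunk(torch.nn.Module):
#       def forward(self, x):
#           return torch.chunk(x, 3, dim=0)
#
#   torch.onnx.export(Chunk(), (torch.randn(7, 2),), "chunk.onnx", opset_version=11)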

@_onnx_symbolic("aten::normal")
def normal(
    g: jit_utils.GraphContext,
    mean,
    std,
    sizes=None,
    generator=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
    # If you can sample from a given distribution with mean 0 and variance 1, then you can easily
    # sample from a scale-location transformation of that distribution, which has mean mu and
    # variance sigma^2. If x is a sample from a distribution with mean 0 and variance 1, then
    #       sigma * x + mu
    # is a sample with mean mu and variance sigma^2.
    if sizes is not None and not symbolic_helper._is_none(sizes):
        mean = opset9.expand(g, mean, sizes, None)
    result = opset9.mul(g, std, g.op("RandomNormalLike", mean))
    return add(g, result, mean)


@_onnx_symbolic("aten::atleast_1d")
def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value):
    # NOTE: If it's 0D, reshape to 1D

    # NOTE: self could be a packed list or a tensor
    if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self):
        tensor_list = symbolic_helper._unpack_list(self)
        new_tensor_list = []
        for tensor in tensor_list:
            new_tensor = tensor
            tensor_rank = symbolic_helper._get_tensor_rank(tensor)
            if tensor_rank == 0:
                new_tensor = symbolic_helper._reshape_helper(
                    g, new_tensor, g.op("Constant", value_t=torch.tensor([1]))
                )
            new_tensor_list.append(new_tensor)
        return g.op("SequenceConstruct", *new_tensor_list)

    tensor_rank = symbolic_helper._get_tensor_rank(self)
    if tensor_rank == 0:
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([1]))
        )
    return self


@_onnx_symbolic("aten::atleast_2d")
def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value):
    # NOTE: If it's 0D, reshape to 2D
    #       If it's 1D, unsqueeze to 2D

    # NOTE: self could be a packed list or a tensor
    if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self):
        tensor_list = symbolic_helper._unpack_list(self)
        new_tensor_list = []
        for tensor in tensor_list:
            new_tensor = tensor
            tensor_rank = symbolic_helper._get_tensor_rank(tensor)
            if tensor_rank == 0:
                new_tensor = symbolic_helper._reshape_helper(
                    g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1]))
                )
            elif tensor_rank == 1:
                new_tensor = symbolic_helper._unsqueeze_helper(
                    g, new_tensor, axes_i=[0]
                )
            new_tensor_list.append(new_tensor)
        return g.op("SequenceConstruct", *new_tensor_list)

    tensor_rank = symbolic_helper._get_tensor_rank(self)
    if tensor_rank == 0:
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([1, 1]))
        )
    elif tensor_rank == 1:
        self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0])
    return self


@_onnx_symbolic("aten::atleast_3d")
def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value):
    # NOTE: If it's 0D, reshape to 3D
    #       If it's 1D, unsqueeze to 3D
    #       If it's 2D, unsqueeze to 3D

    # NOTE: self could be a packed list or a tensor
    if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self):
        tensor_list = symbolic_helper._unpack_list(self)
        new_tensor_list = []
        for tensor in tensor_list:
            new_tensor = tensor
            tensor_rank = symbolic_helper._get_tensor_rank(tensor)
            if tensor_rank == 0:
                new_tensor = symbolic_helper._reshape_helper(
                    g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1]))
                )
            elif tensor_rank == 1:
                new_tensor = symbolic_helper._unsqueeze_helper(
                    g, new_tensor, axes_i=[0]
                )
                new_tensor = symbolic_helper._unsqueeze_helper(
                    g, new_tensor, axes_i=[-1]
                )
            elif tensor_rank == 2:
                new_tensor = symbolic_helper._unsqueeze_helper(
                    g, new_tensor, axes_i=[-1]
                )
            new_tensor_list.append(new_tensor)
        return g.op("SequenceConstruct", *new_tensor_list)

    tensor_rank = symbolic_helper._get_tensor_rank(self)
    if tensor_rank == 0:
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([1, 1, 1]))
        )
    elif tensor_rank == 1:
        self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0])
        self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1])
    elif tensor_rank == 2:
        self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1])
    return self


@_onnx_symbolic("prim::ConstantChunk")
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    input_shape = g.op("Shape", self)
    axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
    input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
    start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
    chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
    chunk_size_minus_1 = g.op(
        "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)
    )
    input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
    chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
    res = []
    for i in range(chunks):
        index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
        end = g.op("Mul", chunk_dim, index)
        res.append(g.op("Slice", self, start, end, axis))
        start = end
    return res


@_onnx_symbolic("aten::hstack")
def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
    tensor_list = atleast_1d(g, tensor_list)
    first_tensor = g.op(
        "SequenceAt",
        tensor_list,
        g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)),
    )
    first_tensor_shape = g.op("Shape", first_tensor)
    first_tensor_dim = g.op("Size", first_tensor_shape)

    const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))
    equal_to_one = g.op("Equal", first_tensor_dim, const_one)

    (
        if_op_greater,
        (if_context_equal, else_context_equal),
        _,
    ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1)
    result_if = if_context_equal.op(
        "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0
    )
    utils._add_output_to_block(if_context_equal.block, result_if)
    result_else = else_context_equal.op(
        "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0
    )
    utils._add_output_to_block(else_context_equal.block, result_else)
    result = if_op_greater.node().output()

    return result


@_onnx_symbolic("aten::vstack")
def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
    tensor_list = atleast_2d(g, tensor_list)
    return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0)
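
# A minimal sketch (illustrative only): hstack concatenates along axis 0 for
# 1-D inputs and axis 1 otherwise, which is why the If block above switches on
# the rank of the first tensor; vstack always concatenates along axis 0 after
# atleast_2d.
#
#   import torch
#
#   class HStack(torch.nn.Module):
#       def forward(self, a, b):
#           return torch.hstack([a, b])
#
#   torch.onnx.export(
#       HStack(), (torch.randn(3), torch.randn(3)), "hstack.onnx", opset_version=11
#   )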
1468