1# mypy: allow-untyped-defs 2# mypy: disable-error-code=arg-type 3"""This file exports ONNX ops for opset 11.""" 4 5from __future__ import annotations 6 7import functools 8import sys 9import warnings 10from typing import Sequence 11 12import torch 13from torch import _C 14from torch._C import _onnx as _C_onnx 15from torch.onnx import ( 16 _type_utils, 17 errors, 18 symbolic_helper, 19 symbolic_opset10 as opset10, 20 symbolic_opset9 as opset9, 21 utils, 22) 23from torch.onnx._internal import jit_utils, registration 24 25 26# EDITING THIS FILE? READ THIS FIRST! 27# see Note [Edit Symbolic Files] in README.md 28 29__all__ = [ 30 "add", 31 "append", 32 "arange", 33 "argsort", 34 "atleast_1d", 35 "atleast_2d", 36 "atleast_3d", 37 "cat", 38 "chunk", 39 "clamp_max", 40 "clamp_min", 41 "clamp", 42 "constant_pad_nd", 43 "cumsum", 44 "Delete", 45 "embedding_bag", 46 "embedding_renorm", 47 "flatten", 48 "gather", 49 "hardtanh", 50 "hstack", 51 "im2col", 52 "index_fill", 53 "index", 54 "index_copy", 55 "index_put", 56 "insert", 57 "linalg_det", 58 "linalg_vector_norm", 59 "logdet", 60 "masked_scatter", 61 "masked_select", 62 "mm", 63 "narrow", 64 "normal", 65 "pad", 66 "pixel_shuffle", 67 "pop", 68 "prim_constant_chunk", 69 "reflection_pad", 70 "relu6", 71 "remainder", 72 "replication_pad", 73 "round", 74 "scatter", 75 "select", 76 "size", 77 "sort", 78 "split_with_sizes", 79 "split", 80 "squeeze", 81 "stack", 82 "topk", 83 "unbind", 84 "unique_dim", 85 "unsqueeze", 86 "vstack", 87] 88 89_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) 90 91 92@_onnx_symbolic("aten::hardtanh") 93@symbolic_helper.quantized_args(True) 94@symbolic_helper.parse_args("v", "f", "f") 95def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): 96 scalar_type = _type_utils.JitScalarType.from_value( 97 self, _type_utils.JitScalarType.FLOAT 98 ) 99 min_val = g.op( 100 "Constant", 101 value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), 102 ) 103 
max_val = g.op( 104 "Constant", 105 value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), 106 ) 107 return symbolic_helper._op_with_optional_float_cast( 108 g, "Clip", self, min_val, max_val, opset_before=12 109 ) 110 111 112@_onnx_symbolic("aten::clamp") 113def clamp(g: jit_utils.GraphContext, self, min, max): 114 def _cast_if_not_none(tensor, dtype): 115 if tensor is not None and not symbolic_helper._is_none(tensor): 116 return g.op( 117 "Cast", 118 tensor, 119 to_i=dtype.onnx_type(), 120 ) 121 else: 122 return tensor 123 124 scalar_type = _type_utils.JitScalarType.from_value( 125 self, _type_utils.JitScalarType.UNDEFINED 126 ) 127 if scalar_type != _type_utils.JitScalarType.UNDEFINED: 128 min = _cast_if_not_none(min, scalar_type) 129 max = _cast_if_not_none(max, scalar_type) 130 131 if symbolic_helper._is_none(min): 132 return clamp_max(g, self, max) 133 elif symbolic_helper._is_none(max): 134 return clamp_min(g, self, min) 135 else: 136 if ( 137 symbolic_helper._get_tensor_rank(min) == 0 138 and symbolic_helper._get_tensor_rank(max) == 0 139 ): 140 return symbolic_helper._op_with_optional_float_cast( 141 g, "Clip", self, min, max, opset_before=12 142 ) 143 else: 144 return clamp_max(g, clamp_min(g, self, min), max) 145 146 147@_onnx_symbolic("aten::clamp_min") 148@symbolic_helper.parse_args("v", "v") 149def clamp_min(g: jit_utils.GraphContext, self, min): 150 min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) 151 if symbolic_helper._get_tensor_rank(min) == 0: 152 max = opset9.unused(g) 153 return symbolic_helper._op_with_optional_float_cast( 154 g, "Clip", self, min, max, opset_before=12 155 ) 156 else: 157 return symbolic_helper._op_with_optional_float_cast( 158 g, "Max", self, min, opset_before=12 159 ) 160 161 162@_onnx_symbolic("aten::clamp_max") 163@symbolic_helper.parse_args("v", "v") 164def clamp_max(g: jit_utils.GraphContext, self, max): 165 max = g.op("Cast", max, 
to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) 166 if symbolic_helper._get_tensor_rank(max) == 0: 167 min = opset9.unused(g) 168 return symbolic_helper._op_with_optional_float_cast( 169 g, "Clip", self, min, max, opset_before=12 170 ) 171 else: 172 return symbolic_helper._op_with_optional_float_cast( 173 g, "Min", self, max, opset_before=12 174 ) 175 176 177@_onnx_symbolic("aten::relu6") 178def relu6(g: jit_utils.GraphContext, input): 179 scalar_type = _type_utils.JitScalarType.from_value( 180 input, _type_utils.JitScalarType.FLOAT 181 ) 182 min_val = g.op( 183 "Constant", 184 value_t=torch.tensor(0, dtype=scalar_type.dtype()), 185 ) 186 max_val = g.op( 187 "Constant", 188 value_t=torch.tensor(6, dtype=scalar_type.dtype()), 189 ) 190 return clamp(g, input, min_val, max_val) 191 192 193@_onnx_symbolic("aten::select") 194# Opset 11 gather accepts negative indices 195@symbolic_helper.quantized_args(True) 196@symbolic_helper.parse_args("v", "i", "v") 197def select(g: jit_utils.GraphContext, self, dim, index): 198 return g.op("Gather", self, index, axis_i=dim) 199 200 201@_onnx_symbolic("aten::index_put") 202def index_put( 203 g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False 204): 205 if symbolic_helper._is_packed_list(indices_list_value): 206 indices_list = symbolic_helper._unpack_list(indices_list_value) 207 else: 208 indices_list = [indices_list_value] 209 accumulate = symbolic_helper._parse_arg(accumulate, "b") 210 211 if len(indices_list) == 0: 212 return values 213 214 if len(indices_list) > 1: 215 for idx_ in range(len(indices_list)): 216 if symbolic_helper._is_bool(indices_list[idx_]): 217 indices_list[idx_] = g.op("NonZero", indices_list[idx_]) 218 index = indices_list[0] 219 220 for ind in indices_list[1:]: 221 index = opset9.add(g, index, ind) 222 broadcast_index_shape = g.op("Shape", index) 223 indices_list = [ 224 symbolic_helper._unsqueeze_helper( 225 g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] 226 
) 227 for ind in indices_list 228 ] 229 index = g.op("Concat", *indices_list, axis_i=-1) 230 else: 231 # Replace index_put node with masked_scatter or masked_fill 232 # when inputs to the index_put node contains a single boolean input. 233 # 234 # index_put -> masked_fill 235 # * input index contains single tensor of Bool type (e.g.: %24 <- %23). 236 # * input value contains single element (e.g.: %18). 237 # 238 # Torch IR 239 # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) 240 # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = 241 # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) 242 # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() 243 # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) 244 # %24 : Tensor?[] = prim::ListConstruct(%23) 245 # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = 246 # aten::index_put(%mask, %24, %18, %30) 247 # return (%25) 248 # 249 # 250 # index_put -> masked_scatter 251 # * input index contains single tensor of Bool type (e.g.: %32 <- %31). 252 # * input value contains multiple elements (e.g.: %28). 
253 # 254 # Torch IR 255 # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) 256 # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) 257 # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() 258 # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) 259 # = aten::ne(%mask, %some_const) 260 # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) 261 # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) 262 # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() 263 # %30 : int[] = prim::Constant[value=[-1]]() 264 # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) 265 # %32 : Tensor?[] = prim::ListConstruct(%31) 266 # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) 267 # = aten::index_put(%mask, %32, %28, %38) 268 # return (%33) 269 index = indices_list[0] 270 bool_inp = index 271 if symbolic_helper._is_bool(bool_inp): 272 rank = symbolic_helper._get_tensor_rank(values) 273 if rank is not None and rank == 0: 274 return opset9.masked_fill(g, self, bool_inp, values) 275 mask_rank = symbolic_helper._get_tensor_rank(bool_inp) 276 self_rank = symbolic_helper._get_tensor_rank(self) 277 if ( 278 mask_rank is not None 279 and self_rank is not None 280 and self_rank > mask_rank 281 ): 282 # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. 
283 bool_inp = symbolic_helper._unsqueeze_helper( 284 g, bool_inp, list(range(mask_rank, self_rank)) 285 ) 286 return masked_scatter(g, self, bool_inp, values) 287 broadcast_index_shape = g.op("Shape", index) 288 index = symbolic_helper._unsqueeze_helper(g, index, [-1]) 289 sub_data_shape = symbolic_helper._slice_helper( 290 g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] 291 ) 292 values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) 293 # Check if values is a singular value and expand accordingly 294 rank = symbolic_helper._get_tensor_rank(values) 295 if rank is not None and rank == 0: 296 values = opset9.expand(g, values, values_shape, None) 297 values = symbolic_helper._reshape_helper(g, values, values_shape) 298 299 self_scalar_type = _type_utils.JitScalarType.from_value( 300 self, _type_utils.JitScalarType.UNDEFINED 301 ) 302 if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: 303 values_scalar_type = _type_utils.JitScalarType.from_value( 304 values, _type_utils.JitScalarType.UNDEFINED 305 ) 306 if self_scalar_type != values_scalar_type: 307 values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) 308 elif accumulate: 309 raise errors.SymbolicValueError("self does not have a valid scalar type.", self) 310 311 if accumulate: 312 zeros = g.op( 313 "ConstantOfShape", 314 g.op("Shape", self), 315 value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), 316 ) 317 result = g.op("ScatterND", zeros, index, values) 318 result = add(g, self, result) 319 else: 320 result = g.op("ScatterND", self, index, values) 321 322 return result 323 324 325@_onnx_symbolic("aten::pixel_shuffle") 326@symbolic_helper.parse_args("v", "i") 327def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): 328 rank = symbolic_helper._get_tensor_rank(self) 329 if rank is not None and rank != 4: 330 return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") 331 return g.op("DepthToSpace", 
self, blocksize_i=upscale_factor, mode_s="CRD") 332 333 334@_onnx_symbolic( 335 "aten::upsample_nearest1d", 336 decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], 337) 338@_onnx_symbolic( 339 "aten::upsample_nearest2d", 340 decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], 341) 342@_onnx_symbolic( 343 "aten::upsample_nearest3d", 344 decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], 345) 346@_onnx_symbolic( 347 "aten::upsample_linear1d", 348 decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], 349) 350@_onnx_symbolic( 351 "aten::upsample_bilinear2d", 352 decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], 353) 354@_onnx_symbolic( 355 "aten::upsample_trilinear3d", 356 decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], 357) 358@_onnx_symbolic( 359 "aten::upsample_bicubic2d", 360 decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")], 361) 362def _interpolate(name: str, dim: int, interpolate_mode: str): 363 return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) 364 365 366@_onnx_symbolic("aten::__interpolate") 367@symbolic_helper.quantized_args(True, False, False, False, False, False, False) 368def __interpolate( 369 g: jit_utils.GraphContext, 370 input, 371 size, 372 scale_factor, 373 mode, 374 align_corners, 375 recompute_scale_factor, 376 antialias, 377): 378 return symbolic_helper.__interpolate_helper( 379 g, input, size, scale_factor, mode, align_corners, recompute_scale_factor 380 ) 381 382 383@_onnx_symbolic("aten::gather") 384@symbolic_helper.parse_args("v", "i", "v", "v") 385def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): 386 if symbolic_helper._maybe_get_const(sparse_grad, "i"): 387 return symbolic_helper._unimplemented("gather", "sparse_grad == True") 388 return g.op("GatherElements", self, index, axis_i=dim) 389 390 
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter`` as ``ScatterElements``.

    A scalar ``src`` is cast to the scalar type of ``self`` when the types
    differ and is then expanded to the shape of ``index``.
    """
    src_type = _type_utils.JitScalarType.from_value(src)
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    else:
        # Check if scalar "src" has same type as self (PyTorch allows different
        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )
        return g.op(
            "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim
        )


@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
    """Export ``aten::cumsum`` as ``CumSum`` along a constant ``dim``.

    The input is cast first only when ``dtype`` is truthy and its producing
    node is not a ``prim::Constant``.
    """
    dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    # NOTE(review): a dtype produced by a prim::Constant node is not applied
    # here — presumably that is how a None dtype arrives; confirm before
    # changing this condition.
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        cast = g.op(
            "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    else:
        cast = self
    csum = g.op("CumSum", cast, dim_tensor)
    return csum


@_onnx_symbolic("aten::masked_select")
def masked_select(g: jit_utils.GraphContext, self, mask):
    """Export ``aten::masked_select`` as ``GatherND`` over the mask's nonzero positions."""
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    return g.op("GatherND", self, index)


@_onnx_symbolic("aten::masked_scatter")
def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
    """Export ``aten::masked_scatter`` as ``ScatterND`` over the mask's nonzero positions."""
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
    source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    source = symbolic_helper._slice_helper(
        g,
        source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        # Keep as many source elements as there are selected positions.
        ends=opset9.size(g, index, torch.LongTensor([0])),
    )
    return g.op("ScatterND", self, index, source)


@_onnx_symbolic("aten::len")
def _len(g: jit_utils.GraphContext, self):
    """Export ``aten::len``.

    Tensor lists and ``SplitToSequence`` outputs use ``SequenceLength``;
    anything else returns the size of dimension 0, squeezed on axis 0.
    """
    if (
        symbolic_helper._is_tensor_list(self)
        or self.node().kind() == "onnx::SplitToSequence"
    ):
        return g.op("SequenceLength", self)
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, sz_0, [0])


@_onnx_symbolic("aten::__getitem_")
def __getitem_(g: jit_utils.GraphContext, self, i):
    """Export ``aten::__getitem_``: ``SequenceAt`` for tensor lists, else the opset 9 symbolic."""
    if symbolic_helper._is_tensor_list(self):
        # SequenceAt requires that the input be a List of Tensors
        return g.op("SequenceAt", self, i)
    # The opset 9 module is already imported at file level as `opset9`.
    return opset9.__getitem_(g, self, i)


@_onnx_symbolic("aten::_set_item")
def _set_item(g: jit_utils.GraphContext, tensor_list, i, v):
    """Export ``aten::_set_item`` as ``SequenceErase`` at ``i`` followed by ``SequenceInsert`` at ``i``."""
    tensor_list = g.op("SequenceErase", tensor_list, i)
    return g.op("SequenceInsert", tensor_list, v, i)


@_onnx_symbolic("aten::append")
def append(g: jit_utils.GraphContext, self, tensor):
    """Export ``aten::append`` as ``SequenceInsert`` with no position input."""
    return g.op("SequenceInsert", self, tensor)


@_onnx_symbolic("aten::add")
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::add``.

    When ``self`` is a tensor list, every tensor of ``other`` is inserted
    into the sequence; ``other`` must come from a ``prim::ListConstruct``.
    All other inputs use the opset 9 symbolic.
    """
    if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
        tensor_list_node = other.node()
        if tensor_list_node.kind() != "prim::ListConstruct":
            return symbolic_helper._unimplemented(
                "add", "does not support adding dynamic tensor list to another"
            )
        sequence = self
        for tensor in symbolic_helper._unpack_list(other):
            sequence = g.op("SequenceInsert", sequence, tensor)
        return sequence

    return opset9.add(g, self, other, alpha)
@_onnx_symbolic("aten::insert")
def insert(g: jit_utils.GraphContext, self, pos, tensor):
    """Export ``aten::insert`` as ``SequenceInsert`` at ``pos``."""
    return g.op("SequenceInsert", self, tensor, pos)


@_onnx_symbolic("aten::pop")
def pop(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::pop`` as ``SequenceErase`` at ``dim``."""
    return g.op("SequenceErase", tensor_list, dim)


@_onnx_symbolic("aten::Delete")
def Delete(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::Delete`` as ``SequenceErase`` at ``dim``."""
    return g.op("SequenceErase", tensor_list, dim)


@_onnx_symbolic("aten::cat")
@symbolic_helper.quantized_args(True)
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::cat``: opset 9 symbolic for packed lists, else ``ConcatFromSequence``."""
    if symbolic_helper._is_packed_list(tensor_list):
        return opset9.cat(g, tensor_list, dim)
    else:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim)


@_onnx_symbolic("aten::stack")
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::stack``: opset 9 symbolic for packed lists, else ``ConcatFromSequence`` with a new axis."""
    if symbolic_helper._is_packed_list(tensor_list):
        return opset9.stack(g, tensor_list, dim)
    else:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1)


@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
    """Export ``aten::_unique2`` as ``Unique``; returns (values, inverse indices, counts)."""
    # The second of Unique's four outputs is not returned.
    u, _, inverse_indices, counts = g.op("Unique", self, sorted_i=sorted, outputs=4)
    return u, inverse_indices, counts


@_onnx_symbolic("aten::unique_dim")
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
def unique_dim(
    g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
):
    """Export ``aten::unique_dim`` as ``Unique`` along ``dim``; returns (values, inverse indices, counts)."""
    # The second of Unique's four outputs is not returned.
    u, _, inverse_indices, counts = g.op(
        "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
    )
    return u, inverse_indices, counts


@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` through the shared top-k helper."""
    return symbolic_helper._topk_helper(
        g, self, k, dim, largest=largest, sorted=sorted, out=out
    )


@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` through the shared sort helper."""
    return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)


@_onnx_symbolic("aten::argsort")
@symbolic_helper.parse_args("v", "i", "i", "none")
def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::argsort``: the second output of the shared sort helper."""
    _, indices = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return indices


@_onnx_symbolic("aten::round")
@symbolic_helper.parse_args("v", "i")
def round(g: jit_utils.GraphContext, self, decimals=0):
    """Export ``aten::round``.

    Non-floating-point inputs are returned unchanged. A non-zero ``decimals``
    scales by ``10**decimals``, rounds, then scales by ``10**-decimals``.
    """
    if not symbolic_helper._is_fp(self):
        return self
    if decimals == 0:
        return g.op("Round", self)
    scaled = g.op(
        "Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))
    )
    rounded = g.op("Round", scaled)
    return g.op(
        "Mul", rounded, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals)))
    )


@_onnx_symbolic("aten::remainder")
def remainder(g: jit_utils.GraphContext, input, other):
    """Export ``aten::remainder``: opset 9 symbolic if either operand is floating point, else ``Mod``."""
    if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
        return opset9.remainder(g, input, other)
    return g.op("Mod", input, other, fmod_i=0)


@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split``.

    Static splits use the opset 9 symbolic. Otherwise a ``SplitToSequence``
    is emitted and, when ``_outputs`` is known, unpacked either into
    ``Slice`` nodes (packed size list of matching length) or ``SequenceAt``
    nodes.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            # split_sizes is a list of same length as _outputs
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for split_size in split_sizes:
                end = g.op("Add", start, split_size)
                res.append(g.op("Slice", self, start, end, axis))
                # The next slice starts where this one ended.
                start = end
            return res
        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]
    else:
        return opset9.split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` by delegating to ``split``."""
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``.

    Without a known output count this is a ``SplitToSequence`` with split
    size 1 and ``keepdims=0``; otherwise the opset 9 symbolic is used.
    """
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )
    else:
        return opset9.unbind(g, self, dim, _outputs)


def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
    """Generate paddings in ONNX order based on pad in pytorch.

    Args:
        input: the input tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].
    """
    if (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    ):
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
    pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (dim * 2 - len(pad))
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
    extension = g.op(
        "Sub",
        g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
        pad_len,
    )
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64)
    paddings = g.op(
        "Concat",
        pad,
        g.op(
            "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
        ),
        axis_i=0,
    )
    # Reshape and reverse order and collate first beginnings and then ends
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #               [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
    )
    paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
    )
    padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64)
    return padding_c


@_onnx_symbolic("aten::constant_pad_nd")
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
    """Export ``aten::constant_pad_nd`` as ``Pad`` in ``constant`` mode with a fill value."""
    mode = "constant"
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(value, input)
    pad = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, pad, value, mode_s=mode)


@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    """Export the ``aten::reflection_pad*`` ops as ``Pad`` in ``reflect`` mode."""
    mode = "reflect"
    paddings = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, paddings, mode_s=mode)


@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
def replication_pad(g: jit_utils.GraphContext, input, padding):
    """Export the ``aten::replication_pad*`` ops as ``Pad`` in ``edge`` mode."""
    mode = "edge"
    paddings = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, paddings, mode_s=mode)


@_onnx_symbolic("aten::pad")
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    """Export ``aten::pad`` by dispatching on the constant string ``mode``.

    Raises:
        errors.SymbolicValueError: if ``mode`` is not one of ``replicate``,
            ``reflect``, ``constant`` or ``circular``.
    """
    mode = symbolic_helper._parse_arg(mode, "s")
    if mode == "replicate":
        return replication_pad(g, input, pad)
    elif mode == "reflect":
        return reflection_pad(g, input, pad)
    elif mode == "constant":
        return constant_pad_nd(g, input, pad, value)
    elif mode == "circular":
        return opset9._pad_circular(g, input, pad)
    else:
        raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)


@_onnx_symbolic("aten::linalg_det")
def linalg_det(g: jit_utils.GraphContext, self):
    """Export ``aten::linalg_det`` as ``Det``."""
    return g.op("Det", self)


@_onnx_symbolic("aten::logdet")
def logdet(g: jit_utils.GraphContext, input):
    """Export ``aten::logdet`` as the opset 9 ``log`` of ``linalg_det``."""
    return opset9.log(g, linalg_det(g, input))


@_onnx_symbolic("aten::arange")
def arange(g: jit_utils.GraphContext, *args):
    """Export ``aten::arange`` as ``Range``, selecting the overload by argument count.

    Unsupported argument counts are reported as unimplemented.
    """

    def _get_arange_dtype(dtype):
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype

    if len(args) == 2 and all(isinstance(val, int) for val in args):
        # aten::arange(Scalar start, Scalar end)
        dtype = torch.int64
        # Start index.
        start = g.op(
            "Constant",
            value_t=torch.tensor(args[0], dtype=dtype),
        )
        # End (exclusive) index.
        end = g.op(
            "Constant",
            value_t=torch.tensor(args[1], dtype=dtype),
        )
        # Step size from start to end indexes.
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=dtype),
        )
        return g.op("Range", start, end, delta_default)
    elif len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        # Only the resolved type and the cast `end` are needed here.
        type_, end, _, _ = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        start_default = g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=type_.dtype()),
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start_default, end, delta_default)
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        _, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        return g.op("Range", start, end, step)
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        type_, end, start, _ = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start, end, delta_default)
    else:
        return symbolic_helper._unimplemented(
            "aten::arange", f"with {len(args)} arguments"
        )


@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    """Export ``aten::_dim_arange`` as ``arange`` up to the size of ``like`` along ``dim``."""
    like_shape = g.op("Shape", like)
    stop = g.op(
        "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
    )
    # NOTE(review): 4 sits in the dtype slot of the five-argument overload —
    # presumably the scalar-type code for int64; confirm.
    return arange(g, stop, 4, None, None, None)


@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
def size(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::size``: ``Shape`` when ``dim`` is None, else the size helper."""
    if dim is None:
        return g.op("Shape", self)
    return symbolic_helper._size_helper(g, self, dim)


@_onnx_symbolic("aten::squeeze")
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::squeeze``.

    * ``dim`` is None: plain ``Squeeze``.
    * ``dim`` is not constant: squeeze on that dynamic axis.
    * Dimension size unknown (or negative ``dim`` with unknown rank): an
      ``If`` node that squeezes only when the runtime size equals 1.
    * Static size greater than 1: warn and return the input unchanged.
    * Otherwise: squeeze on the normalized axis.
    """
    if dim is None:
        return g.op("Squeeze", self)

    # dim as a tensor
    if not symbolic_helper._is_constant(dim):
        return symbolic_helper._squeeze_helper(g, self, [dim])

    dim = symbolic_helper._get_const(dim, "i", "dim")

    input_rank = symbolic_helper._get_tensor_rank(self)
    adjusted_dim = dim
    if input_rank is not None and dim < 0:
        adjusted_dim += input_rank
    dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim)
    if (dim < 0 and input_rank is None) or dim_size is None:
        # If onnx shape inference is not on, export always as dynamic.
        # Because we cannot tell if observed static shape is also static at runtime.
        # create "cond" node (condition is shape[i]==1)
        dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
        runtime_size = symbolic_helper._size_helper(g, self, dim_constant)
        const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        cond = g.op("Equal", runtime_size, const_one)
        # create the "If" node and add the "then" and "else" blocks to it.
        if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
            g, "If", cond, n_blocks=2
        )
        squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim])
        utils._add_output_to_block(if_context.block, squeeze_)
        identity_ = else_context.op("Identity", self)
        utils._add_output_to_block(else_context.block, identity_)
        return if_op

    # For static input shape
    dim = adjusted_dim
    if dim_size > 1:
        warnings.warn(
            f"This model contains a squeeze operation on dimension {dim}. "
            f"The size of this dimension in the given input is {dim_size}. "
            "The model will be exported without the squeeze node. "
            "If the model is intended to be used with dynamic "
            "input shapes, please export with dynamic_axes argument."
        )
        return self
    return symbolic_helper._squeeze_helper(g, self, [dim])


@_onnx_symbolic("aten::unsqueeze")
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Export ``aten::unsqueeze``; a constant ``dim`` is converted to an int first."""
    if symbolic_helper._is_constant(dim):
        dim = symbolic_helper._get_const(dim, "i", "dim")

    return symbolic_helper._unsqueeze_helper(g, self, [dim])


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as a two-input ``Gemm`` with ``alpha=1`` and ``beta=0``."""
    return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::index")
def index(g: jit_utils.GraphContext, self, index):
    """Export ``aten::index``.

    A single bool or uint8 mask index becomes ``GatherND`` over its nonzero
    positions; everything else uses the opset 9 symbolic.
    """
    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        # Note: `index` is rebound here and is what reaches opset9.index below
        # when the single index is not a mask.
        index = indices[0]
        if not symbolic_helper._is_none(index) and (
            symbolic_helper._is_bool(index)
            or _type_utils.JitScalarType.from_value(index)
            == _type_utils.JitScalarType.UINT8
        ):
            index = opset9.nonzero(g, index)
            return g.op("GatherND", self, index)
    return opset9.index(g, self, index)


@_onnx_symbolic("aten::index_fill")
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    """Export ``aten::index_fill`` as a ``scatter`` of the expanded fill value."""
    # Result unused; the call is kept so that any error it raises for `dim`
    # still surfaces at this point.
    symbolic_helper._parse_arg(dim, "i")
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(value, self)
    expanded_value = opset9.expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)


@_onnx_symbolic("aten::index_copy")
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    """Export ``aten::index_copy`` as a ``scatter`` of ``source`` at the expanded index."""
    # Result unused; the call is kept so that any error it raises for `dim`
    # still surfaces at this point.
    symbolic_helper._parse_arg(dim, "i")
    _, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, expanded_index, source)


@_onnx_symbolic("aten::bitwise_right_shift")
@_onnx_symbolic("aten::__rshift_")
def __rshift_(g: jit_utils.GraphContext, self, other):
    """Export a right shift.

    uint8 inputs use ``BitShift``; other types divide ``self`` by
    ``2 ** other``.
    """
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(self):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )

    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.UINT8
    ):
        return g.op("BitShift", self, other, direction_s="RIGHT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
    )
    rshift = g.op("Div", self, two_pow)
    return rshift


@_onnx_symbolic("aten::bitwise_left_shift")
@_onnx_symbolic("aten::__lshift_")
def __lshift_(g: jit_utils.GraphContext, self, other):
    """Export a left shift.

    uint8 inputs use ``BitShift``; other types multiply ``self`` by
    ``2 ** other``.
    """
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(self):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )

    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.UINT8
    ):
        return g.op("BitShift", self, other, direction_s="LEFT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
    )
    lshift = g.op("Mul", self, two_pow)
    return lshift


def _get_im2col_indices_along_dim(
    g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):
    """Build the gather indices of the sliding blocks along one spatial dimension.

    Returns the sum of a column of dilated kernel offsets and a row of block
    start positions.
    """
    # Input is always 4-D (N, C, H, W)
    # Calculate indices of sliding blocks along spatial dimension
    # Slide kernel over input each dim d:
    # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
    # with steps = stride

    blocks_d = g.op(
        "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2))
    )
    blocks_d = g.op(
        "Sub",
        blocks_d,
        g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))),
    )

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = g.op(
        "Range",
        g.op("Constant", value_t=torch.tensor(0)),
        blocks_d,
        g.op("Constant", value_t=torch.tensor(stride_d)),
    )

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d)
    kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0))

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    blocks_d_indices = symbolic_helper._unsqueeze_helper(
        g, blocks_d_indices, [0]
    )  # Reshape to [1, -1]
    kernel_mask = symbolic_helper._reshape_helper(
        g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1]))
    )
    block_mask = g.op("Add", blocks_d_indices, kernel_mask)

    return block_mask


def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w):
    """Pad the two trailing dimensions of ``input`` by ``padding_h`` / ``padding_w`` on both sides."""
    # Input is always 4-D tensor (N, C, H, W)
    # Padding tensor has the following format: (padding_h, padding_w)
    # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
    pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
    return g.op("Pad", input, pad)


def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w):
    """Build the shape ``[size(0), size(1) * kernel_h * kernel_w, -1]`` as a 1-D value."""
    batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    channel_unfolded = g.op(
        "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
    )

    return g.op(
        "Concat",
        symbolic_helper._unsqueeze_helper(g, batch_dim, [0]),
        symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]),
        g.op("Constant", value_t=torch.tensor([-1])),
        axis_i=0,
    )


@_onnx_symbolic("aten::im2col")
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride):
    # Input is always 4-D tensor (N, C, H, W)
    # All other args are int[2]

    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    blocks_row_indices = _get_im2col_indices_along_dim(
        g, input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    blocks_col_indices = _get_im2col_indices_along_dim(
        g, input_w, kernel_w, dilation_w, padding_w, stride_w
    )

    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows
(dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: 1131 # [[[[[1., 2., 3.], 1132 # [4., 5., 6.]], 1133 # [[4., 5., 6.], 1134 # [7., 8., 9.]]]]] 1135 # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: 1136 # [[[[[[1., 2.], 1137 # [4., 5.]], 1138 # [[2., 3.], 1139 # [5., 6]]], 1140 # [[[4., 5.], 1141 # [7., 8.]], 1142 # [[5., 6.], 1143 # [8., 9.]]]]]] 1144 # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: 1145 # [[[1., 2., 4., 5.], 1146 # [2., 3., 5., 6.], 1147 # [4., 5., 7., 8.], 1148 # [5., 6., 8., 9.]]] 1149 output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) 1150 output = g.op("Gather", output, blocks_col_indices, axis_i=4) 1151 output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) 1152 return symbolic_helper._reshape_helper(g, output, output_shape) 1153 1154 1155@_onnx_symbolic("aten::narrow") 1156def narrow(g: jit_utils.GraphContext, input, dim, start, length): 1157 end = g.op("Add", start, length) 1158 return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) 1159 1160 1161@_onnx_symbolic("aten::flatten") 1162@symbolic_helper.quantized_args(True, False, False) 1163@symbolic_helper.parse_args("v", "i", "i") 1164def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): 1165 dim = symbolic_helper._get_tensor_rank(input) 1166 if dim == 1: 1167 return input 1168 # use ONNX's Flatten operator for cases where the output shape is 2D 1169 if start_dim == 1: 1170 if end_dim == -1 or (dim is not None and end_dim == dim - 1): 1171 return g.op("Flatten", input, axis_i=start_dim) 1172 elif start_dim == 0: 1173 if end_dim == -2 or (dim is not None and end_dim == dim - 2): 1174 return g.op("Flatten", input, axis_i=end_dim + 1) 1175 if dim is None: 1176 return symbolic_helper._unimplemented( 1177 "dim", 1178 "ONNX and PyTorch use different strategies to split the input. 
" 1179 "Input rank must be known at export time.", 1180 ) 1181 # if end_dim is negative add dim 1182 if end_dim < 0: 1183 end_dim = dim + end_dim 1184 1185 return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) 1186 1187 1188@_onnx_symbolic("aten::linalg_vector_norm") 1189@symbolic_helper.parse_args("v", "f", "is", "b", "v") 1190def linalg_vector_norm( 1191 g: jit_utils.GraphContext, 1192 self, 1193 ord, 1194 dim: Sequence[int] | None, 1195 keepdim: bool, 1196 dtype, 1197): 1198 return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) 1199 1200 1201@_onnx_symbolic("aten::embedding_bag") 1202@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") 1203def embedding_bag( 1204 g: jit_utils.GraphContext, 1205 embedding_matrix, 1206 indices, 1207 offsets, 1208 scale_grad_by_freq, 1209 mode, 1210 sparse, 1211 per_sample_weights, 1212 include_last_offset, 1213 padding_idx, 1214): 1215 return symbolic_helper._embedding_bag_helper( 1216 g, 1217 embedding_matrix, 1218 indices, 1219 offsets, 1220 scale_grad_by_freq, 1221 mode, 1222 sparse, 1223 per_sample_weights, 1224 include_last_offset, 1225 padding_idx, 1226 ) 1227 1228 1229@_onnx_symbolic("aten::embedding_renorm") 1230@symbolic_helper.parse_args("v", "v", "f", "f") 1231def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): 1232 unique_indices = g.op("Unique", indices) 1233 partial_weight = g.op("Gather", weight, unique_indices) 1234 norm_i = int(norm_type) 1235 if norm_i == 1: 1236 norm_type = "ReduceL1" 1237 elif norm_i == 2: 1238 norm_type = "ReduceL2" 1239 else: 1240 raise errors.SymbolicValueError( 1241 f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " 1242 "Only 1. and 2. 
are supported.", 1243 weight, 1244 ) 1245 partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) 1246 # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 1247 # Add 1e-7 to prevent division by zero. 1248 partial_weight_norm_ = g.op( 1249 "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) 1250 ) 1251 max_norm = torch.tensor(max_norm) 1252 scales = g.op("Div", max_norm, partial_weight_norm_) 1253 partial_weight_renorm = g.op("Mul", partial_weight, scales) 1254 partial_weight_renorm = g.op( 1255 "Where", 1256 g.op("Greater", partial_weight_norm, max_norm), 1257 partial_weight_renorm, 1258 partial_weight, 1259 ) 1260 return g.op( 1261 "ScatterND", 1262 weight, 1263 symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), 1264 partial_weight_renorm, 1265 ) 1266 1267 1268@_onnx_symbolic("aten::chunk") 1269def chunk(g: jit_utils.GraphContext, self, chunks, dim): 1270 # Calculate chunk size for dynamic chunk 1271 dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) 1272 chunk_size_s = g.op( 1273 "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) 1274 ) 1275 chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) 1276 # Create splits vector 1277 chunk_vec = [ 1278 opset9.expand(g, chunk_size, chunk_size_s, None), 1279 g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), 1280 ] 1281 chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) 1282 return split(g, self, chunk_vec, dim) 1283 1284 1285@_onnx_symbolic("aten::normal") 1286def normal( 1287 g: jit_utils.GraphContext, 1288 mean, 1289 std, 1290 sizes=None, 1291 generator=None, 1292 dtype=None, 1293 layout=None, 1294 device=None, 1295 pin_memory=None, 1296): 1297 # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a 1298 # scale-location transformation of that distribution, which has mean mu and 
variance sigma's square. If x is a sample 1299 # from a mean 0 and variance 1 distribution then 1300 # sigma x+mu 1301 # is a sample with mean mu and variance sigma's square. 1302 if sizes is not None and not symbolic_helper._is_none(sizes): 1303 mean = opset9.expand(g, mean, sizes, None) 1304 result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) 1305 return add(g, result, mean) 1306 1307 1308@_onnx_symbolic("aten::atleast_1d") 1309def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): 1310 # NOTE: If it's 0D, reshape to 1D 1311 1312 # NOTE: self could be a packed list or a tensor 1313 if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): 1314 tensor_list = symbolic_helper._unpack_list(self) 1315 new_tensor_list = [] 1316 for tensor in tensor_list: 1317 new_tensor = tensor 1318 tensor_rank = symbolic_helper._get_tensor_rank(tensor) 1319 if tensor_rank == 0: 1320 new_tensor = symbolic_helper._reshape_helper( 1321 g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) 1322 ) 1323 new_tensor_list.append(new_tensor) 1324 return g.op("SequenceConstruct", *new_tensor_list) 1325 1326 tensor_rank = symbolic_helper._get_tensor_rank(self) 1327 if tensor_rank == 0: 1328 self = symbolic_helper._reshape_helper( 1329 g, self, g.op("Constant", value_t=torch.tensor([1])) 1330 ) 1331 return self 1332 1333 1334@_onnx_symbolic("aten::atleast_2d") 1335def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): 1336 # NOTE: If it's 0D, reshape to 2D 1337 # If it's 1D, unsqueeze to 2D 1338 1339 # NOTE: self could be a packed list or a tensor 1340 if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): 1341 tensor_list = symbolic_helper._unpack_list(self) 1342 new_tensor_list = [] 1343 for tensor in tensor_list: 1344 new_tensor = tensor 1345 tensor_rank = symbolic_helper._get_tensor_rank(tensor) 1346 if tensor_rank == 0: 1347 new_tensor = symbolic_helper._reshape_helper( 1348 g, new_tensor, g.op("Constant", 
value_t=torch.tensor([1, 1])) 1349 ) 1350 elif tensor_rank == 1: 1351 new_tensor = symbolic_helper._unsqueeze_helper( 1352 g, new_tensor, axes_i=[0] 1353 ) 1354 new_tensor_list.append(new_tensor) 1355 return g.op("SequenceConstruct", *new_tensor_list) 1356 1357 tensor_rank = symbolic_helper._get_tensor_rank(self) 1358 if tensor_rank == 0: 1359 self = symbolic_helper._reshape_helper( 1360 g, self, g.op("Constant", value_t=torch.tensor([1, 1])) 1361 ) 1362 elif tensor_rank == 1: 1363 self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) 1364 return self 1365 1366 1367@_onnx_symbolic("aten::atleast_3d") 1368def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): 1369 # NOTE: If it's 0D, reshape to 3D 1370 # If it's 1D, unsqueeze to 3D 1371 # If it's 2D, unsqueeze to 3D 1372 1373 # NOTE: self could be a packed list or a tensor 1374 if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): 1375 tensor_list = symbolic_helper._unpack_list(self) 1376 new_tensor_list = [] 1377 for tensor in tensor_list: 1378 new_tensor = tensor 1379 tensor_rank = symbolic_helper._get_tensor_rank(tensor) 1380 if tensor_rank == 0: 1381 new_tensor = symbolic_helper._reshape_helper( 1382 g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) 1383 ) 1384 elif tensor_rank == 1: 1385 new_tensor = symbolic_helper._unsqueeze_helper( 1386 g, new_tensor, axes_i=[0] 1387 ) 1388 new_tensor = symbolic_helper._unsqueeze_helper( 1389 g, new_tensor, axes_i=[-1] 1390 ) 1391 elif tensor_rank == 2: 1392 new_tensor = symbolic_helper._unsqueeze_helper( 1393 g, new_tensor, axes_i=[-1] 1394 ) 1395 new_tensor_list.append(new_tensor) 1396 return g.op("SequenceConstruct", *new_tensor_list) 1397 1398 tensor_rank = symbolic_helper._get_tensor_rank(self) 1399 if tensor_rank == 0: 1400 self = symbolic_helper._reshape_helper( 1401 g, self, g.op("Constant", value_t=torch.tensor([1, 1, 1])) 1402 ) 1403 elif tensor_rank == 1: 1404 self = symbolic_helper._unsqueeze_helper(g, 
self, axes_i=[0]) 1405 self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) 1406 elif tensor_rank == 2: 1407 self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) 1408 return self 1409 1410 1411@_onnx_symbolic("prim::ConstantChunk") 1412def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): 1413 input_shape = g.op("Shape", self) 1414 axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) 1415 input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) 1416 start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) 1417 chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) 1418 chunk_size_minus_1 = g.op( 1419 "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) 1420 ) 1421 input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) 1422 chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) 1423 res = [] 1424 for i in range(chunks): 1425 index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) 1426 end = g.op("Mul", chunk_dim, index) 1427 res.append(g.op("Slice", self, start, end, axis)) 1428 start = end 1429 return res 1430 1431 1432@_onnx_symbolic("aten::hstack") 1433def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): 1434 tensor_list = atleast_1d(g, tensor_list) 1435 first_tensor = g.op( 1436 "SequenceAt", 1437 tensor_list, 1438 g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), 1439 ) 1440 first_tensor_shape = g.op("Shape", first_tensor) 1441 first_tensor_dim = g.op("Size", first_tensor_shape) 1442 1443 const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) 1444 equal_to_one = g.op("Equal", first_tensor_dim, const_one) 1445 1446 ( 1447 if_op_greater, 1448 (if_context_equal, else_context_equal), 1449 _, 1450 ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) 1451 result_if = if_context_equal.op( 1452 "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 1453 ) 
1454 utils._add_output_to_block(if_context_equal.block, result_if) 1455 result_else = else_context_equal.op( 1456 "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 1457 ) 1458 utils._add_output_to_block(else_context_equal.block, result_else) 1459 result = if_op_greater.node().output() 1460 1461 return result 1462 1463 1464@_onnx_symbolic("aten::vstack") 1465def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): 1466 tensor_list = atleast_2d(g, tensor_list) 1467 return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) 1468