# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 7 to opset 8]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
  Expand

Updated operators:
  Min, Max, Sum, Mean: supports multidirectional broadcasting.
  MaxPool: added optional indices output.
  Scan
"""

import functools
import warnings

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


# Decorator that registers a symbolic function in this module for opset 7.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7)

# aten operators that do not get a real symbolic at opset 7; each one is
# registered with the placeholder from symbolic_helper._block_list_in_opset
# (see _register_block_listed_operators at the bottom of this module).
block_listed_operators = (
    "scan",
    "expand",
    "expand_as",
    "meshgrid",
    "adaptive_max_pool1d",
    "adaptive_max_pool2d",
    "adaptive_max_pool3d",
    "max_pool1d_with_indices",
    "max_pool2d_with_indices",
    "max_pool3d_with_indices",
)


def _warn_multidirectional_broadcasting(op_name: str) -> None:
    """Warn that exporting the binary form of ``op_name`` may be incorrect.

    Opset 7 lacks multidirectional broadcasting for these operators (it was
    added in opset 8, per the module note), so inputs with different shapes
    may yield an incorrect ONNX model.

    Args:
        op_name: Operator name interpolated into the message ("max"/"min").
    """
    warnings.warn(
        "Multidirectional broadcasting is not supported in opset 7. "
        "This might cause the onnx model to be incorrect, if inputs to "
        f"{op_name} operators have different shapes",
        # Attribute the warning to the calling symbolic (max/min), which is
        # where it was raised before this helper was factored out.
        stacklevel=2,
    )


# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
@_onnx_symbolic("aten::max")
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Symbolic for ``aten::max`` at opset 7.

    Warns when called in the ``torch.max(input, other)`` form (``keepdim`` is
    None and ``dim_or_y`` is given), because opset 7 cannot broadcast those
    inputs, then delegates unconditionally to ``opset9.max``.

    Returns:
        Whatever ``opset9.max`` returns for the same arguments.
    """
    # torch.max(input, other)
    if keepdim is None and dim_or_y is not None:
        _warn_multidirectional_broadcasting("max")
    return opset9.max(g, self, dim_or_y, keepdim)


@_onnx_symbolic("aten::min")
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Symbolic for ``aten::min`` at opset 7.

    Warns when called in the ``torch.min(input, other)`` form (``keepdim`` is
    None and ``dim_or_y`` is given), because opset 7 cannot broadcast those
    inputs, then delegates unconditionally to ``opset9.min``.

    Returns:
        Whatever ``opset9.min`` returns for the same arguments.
    """
    # torch.min(input, other)
    if keepdim is None and dim_or_y is not None:
        _warn_multidirectional_broadcasting("min")
    return opset9.min(g, self, dim_or_y, keepdim)


def _register_block_listed_operators() -> None:
    """Register the opset-7 block-list placeholder for each listed operator.

    Every ``aten::<name>`` in ``block_listed_operators`` is registered with
    the function returned by ``symbolic_helper._block_list_in_opset(name)``
    (presumably a symbolic that reports the op as unsupported at this opset;
    see symbolic_helper). Running the loop inside a function keeps its loop
    variable out of the module namespace and out of ``import *``.
    """
    for op_name in block_listed_operators:
        _onnx_symbolic(f"aten::{op_name}")(
            symbolic_helper._block_list_in_opset(op_name)
        )


_register_block_listed_operators()