# Owner(s): ["module: onnx"]

"""Test the support on onnxscript in PyTorch-ONNX converter."""

import io
from typing import List

import onnx

import onnxscript
from onnxscript.onnx_types import FLOAT

import torch
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils


class TestONNXScriptExport(common_utils.TestCase):
    # Opset 15 is used because:
    # 1. ONNX local functions are supported starting from opset 15;
    # 2. onnx-script requires users to pin the opset used inside a local
    #    function. The tests below import ``opset15`` for that purpose, so
    #    this value and those imports must be kept in sync.
    opset_version = 15

    def _register_custom_op_symbolic(self, symbolic_name, symbolic_fn):
        """Register ``symbolic_fn`` for ``symbolic_name`` for the current test only.

        ``torch.onnx.register_custom_op_symbolic`` updates a global registry.
        The matching unregister is scheduled with ``addCleanup`` so that the
        custom symbolic does not leak into tests that run afterwards.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_onnxscript_registration_with_multiple_models(self):
        """Export two models, each using a different onnxscript custom op.

        Each exported proto must carry exactly one local function: the one
        backing the op that model uses.
        """
        from onnxscript.onnx_opset import opset15 as op

        # 1. Register Selu onnxscript function as custom Op
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(X):
            # default value is not supported by onnxscript
            alpha = 1.67326  # auto wrapped as Constants
            gamma = 1.0507
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # setType keeps TorchScript shape/type inference working for the
            # custom op's output.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_custom_op_symbolic("aten::selu", custom_selu)

        # 2. Register layer_norm onnxscript function as custom Op
        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g: jit_utils.GraphContext,
            input,
            normalized_shape,
            weight,
            bias,
            eps,
            cudnn_enable,
        ):
            # onnxscript functions cannot contain comprehensions, so the
            # reduction axes (the trailing len(normalized_shape) dims, as
            # negative indices) are computed here and passed as an attribute.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_custom_op_symbolic("aten::layer_norm", custom_layer_norm)

        # 3. export two models
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model_selu = torch.nn.SELU()
        selu_onnx = io.BytesIO()
        torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)

        N, C = 3, 4
        y = torch.randn(N, C)
        model_layer_norm = torch.nn.LayerNorm(C)
        layer_norm_onnx = io.BytesIO()
        torch.onnx.export(
            model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
        )

        # 4. test on models
        selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue()))
        layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue()))

        self.assertEqual(len(selu_proto.functions), 1)
        self.assertEqual(len(layer_norm_proto.functions), 1)
        self.assertEqual(selu_proto.functions[0].name, "Selu")
        self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm")

    def test_loop_registration(self):
        """Check that an onnxscript op used inside control flow is exported.

        This exercises _find_onnxscript_op in torch/onnx/utils.py. That
        function recursively visits every node that has a subgraph in the
        model proto, so a custom op nested in a Loop/If must still be found.
        """

        class NestedLoopsModel(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.selu = torch.nn.SELU()

            @torch.jit.script_method
            def forward(self, x):
                y = x
                for i in range(x.size(3)):
                    if i == 0:
                        y = self.selu(x)
                    else:
                        y += i
                return y

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, 4)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2)

        @onnxscript.script(custom_opset)
        def Selu(X):
            alpha = 1.6732632423543772848170429916717
            gamma = 1.0507009873554804934193349852946
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # The domain of the op comes from the onnx-script custom_opset.
            # setType is required for the custom op to support TorchScript
            # shape/type inference.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_custom_op_symbolic("aten::selu", custom_selu)

        saved_model = io.BytesIO()
        torch.onnx.export(
            torch.jit.script(model),
            inputs,
            f=saved_model,
            opset_version=self.opset_version,
        )
        loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
        self.assertEqual(len(loop_selu_proto.functions), 1)