xref: /aosp_15_r20/external/pytorch/test/onnx/test_onnxscript_no_runtime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3"""Test the support on onnxscript in PyTorch-ONNX converter."""
4
5import io
6from typing import List
7
8import onnx
9
10import onnxscript
11from onnxscript.onnx_types import FLOAT
12
13import torch
14from torch.onnx._internal import jit_utils
15from torch.testing._internal import common_utils
16
17
18class TestONNXScriptExport(common_utils.TestCase):
19    # opset version is
20    # 1. local function is supported after opset 15
21    # 2. onnx-script requires users to determine opset in local function
22    opset_version = 15
23
24    def test_onnxscript_registration_with_multiple_models(self):
25        from onnxscript.onnx_opset import opset15 as op
26
27        # 1. Register Selu onnxscript function as custom Op
28        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
29
30        @onnxscript.script(custom_opset)
31        def Selu(X):
32            # default value is not supported by onnxscript
33            alpha = 1.67326  # auto wrapped as Constants
34            gamma = 1.0507
35            alphaX = op.CastLike(alpha, X)
36            gammaX = op.CastLike(gamma, X)
37            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
38            pos = gammaX * X
39            zero = op.CastLike(0, X)
40            return op.Where(X <= zero, neg, pos)
41
42        def custom_selu(g: jit_utils.GraphContext, X):
43            return g.onnxscript_op(Selu, X).setType(X.type())
44
45        torch.onnx.register_custom_op_symbolic(
46            symbolic_name="aten::selu",
47            symbolic_fn=custom_selu,
48            opset_version=self.opset_version,
49        )
50
51        # 2. Register layer_norm onnxscript function as custom Op
52        @onnxscript.script(custom_opset)
53        def layer_norm(
54            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
55        ):
56            mean = op.ReduceMean(X, axes=axes)
57            D = X - mean  # op.Sub(X, mean)
58            DD = D * D  # op.Mul(D, D)
59            var = op.ReduceMean(DD, axes=axes)
60            vareps = var + eps  # op.Add(var, eps)
61            stddev = op.Sqrt(vareps)
62            invstddev = op.Reciprocal(stddev)
63            normalized = D * invstddev  # op.Mul(D, invstddev)
64            normalizedw = op.CastLike(
65                normalized, weight
66            )  # Type issue if missing this Op
67            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
68            return normalizedscaled + bias
69
70        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
71        def custom_layer_norm(
72            g, input, normalized_shape, weight, bias, eps, cudnn_enable
73        ):
74            # comprehension is not supported by onnxscript
75            axes = [-i for i in range(len(normalized_shape), 0, -1)]
76            return g.onnxscript_op(
77                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
78            ).setType(input.type())
79
80        torch.onnx.register_custom_op_symbolic(
81            symbolic_name="aten::layer_norm",
82            symbolic_fn=custom_layer_norm,
83            opset_version=self.opset_version,
84        )
85
86        # 3. export two models
87        x = torch.randn(1, 2, 3, 4, requires_grad=True)
88        model_selu = torch.nn.SELU()
89        selu_onnx = io.BytesIO()
90        torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)
91
92        N, C = 3, 4
93        y = torch.randn(N, C)
94        model_layer_norm = torch.nn.LayerNorm(C)
95        layer_norm_onnx = io.BytesIO()
96        torch.onnx.export(
97            model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
98        )
99
100        # 4. test on models
101        selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue()))
102        layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue()))
103
104        self.assertEqual(len(selu_proto.functions), 1)
105        self.assertEqual(len(layer_norm_proto.functions), 1)
106        self.assertEqual(selu_proto.functions[0].name, "Selu")
107        self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm")
108
109    def test_loop_registration(self):
110        # Control flow is tested for _find_onnxscript_op function in torch/onnx/utils.py,
111        # which has recursive logic to go through every nodes with subgraph in model proto
112        class NestedLoopsModel(torch.jit.ScriptModule):
113            def __init__(self) -> None:
114                super().__init__()
115                self.selu = torch.nn.SELU()
116
117            @torch.jit.script_method
118            def forward(self, x):
119                y = x
120                for i in range(x.size(3)):
121                    if i == 0:
122                        y = self.selu(x)
123                    else:
124                        y += i
125                return y
126
127        model = NestedLoopsModel()
128        inputs = torch.zeros(1, 2, 3, 4)
129
130        from onnxscript.onnx_opset import opset15 as op
131
132        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2)
133
134        @onnxscript.script(custom_opset)
135        def Selu(X):
136            alpha = 1.6732632423543772848170429916717
137            gamma = 1.0507009873554804934193349852946
138            alphaX = op.CastLike(alpha, X)
139            gammaX = op.CastLike(gamma, X)
140            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
141            pos = gammaX * X
142            zero = op.CastLike(0, X)
143            return op.Where(X <= zero, neg, pos)
144
145        def custom_selu(g, X):
146            # domain of the Op should be aligned with onnx-script
147            # setType API is required for custom Op to support
148            # torchscript shape type inference
149            print("custom_selu is used!")
150            return g.onnxscript_op(Selu, X).setType(X.type())
151
152        torch.onnx.register_custom_op_symbolic(
153            symbolic_name="aten::selu",
154            symbolic_fn=custom_selu,
155            opset_version=15,
156        )
157
158        saved_model = io.BytesIO()
159        torch.onnx.export(
160            torch.jit.script(model), inputs, f=saved_model, opset_version=15
161        )
162        loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
163        self.assertEqual(len(loop_selu_proto.functions), 1)
164