# Owner(s): ["module: onnx"]

import onnx_test_common
import pytorch_test_common

import torch
import torch.utils.cpp_extension
from torch.onnx import symbolic_helper
from torch.testing._internal import common_utils

# These tests cover exporting custom torch.autograd.Function subclasses to
# ONNX: via a `symbolic` staticmethod on the Function itself, via
# torch.onnx.register_custom_op_symbolic, and as ONNX Runtime contrib ops
# in the com.microsoft domain.


class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
    opset_version = 9
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_symbolic(self):
        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

            @staticmethod
            def symbolic(g, input, scalar):
                return g.op("Clip", input, min_f=scalar)
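        # During export, the tracer records MyClip.apply as a PythonOp node,
        # and the exporter calls the `symbolic` staticmethod above to emit
        # ONNX ops in its place: `g` is the graph under construction, and
        # `min_f=scalar` sets the float attribute `min` on the Clip node
        # (in opset 9, Clip takes min/max as attributes, not inputs).
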
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.clip = MyClip.apply

            def forward(self, x):
                h = self.clip(x, 2)
                return h

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))
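        # run_model_test exports the module and compares ONNX Runtime output
        # with eager PyTorch. A rough standalone equivalent (hypothetical
        # output filename) would be:
        #   torch.onnx.export(model, (x,), "my_clip.onnx", opset_version=9)
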
    def test_register_op(self):
        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

        class MyRelu(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.clip = MyClip.apply
                self.relu = MyRelu.apply

            def forward(self, x):
                h = self.clip(x, 2)
                h = self.relu(h)
                return h

        def symbolic_pythonop(g, *args, **kwargs):
            name = kwargs["name"]
            if name == "MyClip":
                return g.op("Clip", args[0], min_f=args[1])
            elif name == "MyRelu":
                return g.op("Relu", args[0])
            else:
                return symbolic_helper._unimplemented(
                    "prim::PythonOp", "unknown node kind: " + name
                )

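        # Neither MyClip nor MyRelu defines a `symbolic` staticmethod, so
        # both appear in the traced graph as generic prim::PythonOp nodes.
        # A single symbolic registered for "prim::PythonOp" handles them
        # all, dispatching on kwargs["name"] (the Function's class name).
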
        from torch.onnx import register_custom_op_symbolic

        # Registering at opset_version=1 applies the symbolic from opset 1
        # onward, which covers this test class's opset 9.
        register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))


class TestExportAsContribOps(pytorch_test_common.ExportTestCase):
    opset_version = 14
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_contrib_op_with_loop(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gelu = torch.nn.GELU(approximate="none")

            def forward(self, x):
                res = []
                res2 = []
                for i in range(x.size(0)):
                    if len(res) > 0:
                        res2.append(res[0])
                    else:
                        res2.append(self.gelu(x[0]))
                    res.append(x[0])
                return torch.stack(res), torch.stack(res2)

        def symbolic_custom_gelu(g, input, approximate):
            return g.op("com.microsoft::Gelu", input).setType(input.type())
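        # com.microsoft::Gelu is an ONNX Runtime contrib op, unknown to
        # standard ONNX shape inference, so .setType(input.type()) copies
        # the input's type and shape onto the output to keep the exported
        # Loop body well-typed. The unused `approximate` parameter mirrors
        # the aten::gelu schema this symbolic replaces.
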
        from torch.onnx import register_custom_op_symbolic

        register_custom_op_symbolic("::gelu", symbolic_custom_gelu, 1)
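        # "::gelu" (no explicit namespace) matches aten::gelu; registering
        # at opset_version=1 applies the symbolic from opset 1 onward,
        # including this test class's opset 14.
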
        x = torch.randn(3, 3, 4, requires_grad=True)
        # Scripting (rather than tracing) preserves the data-dependent loop
        # and branch so they export as ONNX Loop/If constructs.
        model = torch.jit.script(M())
        onnx_test_common.run_model_test(self, model, input_args=(x,))


if __name__ == "__main__":
    common_utils.run_tests()