# Owner(s): ["module: onnx"]

import pytorch_test_common
from onnx_test_common import run_model_test

import torch
from torch.onnx import OperatorExportTypes
from torch.onnx._globals import GLOBALS
from torch.onnx.utils import _model_to_graph
from torch.testing._internal import common_utils


class TestAutogradFuns(pytorch_test_common.ExportTestCase):
    opset_version = GLOBALS.export_onnx_opset_version
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

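    # Export a module that calls a single-output autograd.Function and
    # verify it via run_model_test.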
    def test_single_output(self):
        class SingleOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.exp()
                result = result.log()
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                result = input + 5
                return SingleOut.apply(result) + 3

        model = Caller()
        input = torch.ones(1)
        run_model_test(self, model, input_args=(input,))

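    # Same as test_single_output, but the Function returns a tuple of
    # outputs.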
    def test_multi_output(self):
        class MultiOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = result_exp.log()
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # Two forward outputs mean two incoming gradients and two
                # saved tensors.
                result_exp, result_log = ctx.saved_tensors
                # d(exp(i))/di = exp(i); d(log(exp(i)))/di = 1.
                return grad_exp * result_exp + grad_log

        class Caller(torch.nn.Module):
            def forward(self, input):
                return MultiOut.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

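    # forward computes topk but returns only the values; the unused indices
    # output should not break export.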
    def test_partial_output(self):
        class PartialOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                values, indices = torch.topk(input, 3)
                return values

        class Caller(torch.nn.Module):
            def forward(self, input):
                return PartialOut.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

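    # A Function whose forward calls another custom Function; export should
    # handle the nesting.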
    def test_nested_autograd(self):
        class Child(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.log()
                result_log = result.log()
                ctx.save_for_backward(result_log)
                return result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Parent(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = Child.apply(result_exp)
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # Unpack both saved tensors; one gradient arrives per output.
                result_exp, result_log = ctx.saved_tensors
                # result_log = log(log(exp(i))) = log(i), so d(result_log)/di
                # = 1 / i = exp(-result_log); d(result_exp)/di = result_exp.
                return grad_exp * result_exp + grad_log * torch.exp(-result_log)

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Parent.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    # Run export in ONNX_FALLTHROUGH mode, as torch.special.erf() is not
    # supported; also check ONNX_ATEN_FALLBACK mode.
    def test_aten_unsupported(self):
        class Erf(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                erf_out = torch.special.erf(x)
                ctx.save_for_backward(erf_out)
                return erf_out

            @staticmethod
            def backward(ctx, grad_output):
                # saved_tensors is a tuple, so unpack it; forward takes a
                # single input, so return a single gradient.
                (result,) = ctx.saved_tensors
                return torch.special.erfinv(result)

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Erf.apply(input)

        model = Caller()
        input = torch.ones(1, 5)

        # With ONNX_FALLTHROUGH, the function is kept as an opaque
        # prim::PythonOp node.
        graph, _, _ = _model_to_graph(
            model,
            (input,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
        )
        node_iter = graph.nodes()
        self.assertEqual(next(node_iter).kind(), "prim::PythonOp")

        # With ONNX_ATEN_FALLBACK, the unsupported call is exported as an
        # aten::ATen node.
        graph, _, _ = _model_to_graph(
            model,
            (input,),
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        node_iter = graph.nodes()
        self.assertEqual(next(node_iter).kind(), "aten::ATen")

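    # Exp defines symbolic() and is expected to export as a single ONNX Exp
    # op; LogLog has no symbolic(), so its forward is inlined by the tracer.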
    def test_inline_and_symbolic(self):
        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save the input tensor for backward.
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        input = torch.ones(1)
        run_model_test(self, model, input_args=(input,))

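    # Same as test_inline_and_symbolic, but with module-scope tracing
    # enabled through torch.jit._trace._trace_module_map.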
    def test_inline_with_scoped_tracing(self):
        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        input = torch.ones(1)

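        # Populate the module map so the tracer records module scope names,
        # and clear it afterwards to avoid leaking global state.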
        torch.jit._trace._trace_module_map = {
            _m: torch.typename(type(_m)) for _m in model.modules()
        }
        run_model_test(self, model, input_args=(input,))
        torch.jit._trace._trace_module_map = None


if __name__ == "__main__":
    common_utils.run_tests()