# Owner(s): ["module: onnx"]

import pytorch_test_common
from onnx_test_common import run_model_test

import torch
from torch.onnx import OperatorExportTypes
from torch.onnx._globals import GLOBALS
from torch.onnx.utils import _model_to_graph
from torch.testing._internal import common_utils


class TestAutogradFuns(pytorch_test_common.ExportTestCase):
    """ONNX export tests for modules that call custom ``torch.autograd.Function``s.

    Each test defines one or more autograd functions, wraps them in a small
    ``torch.nn.Module`` and either hands the module to ``run_model_test`` or
    inspects the graph returned by ``_model_to_graph``.
    """

    # Export settings; presumably read by ``run_model_test`` through ``self``
    # (NOTE(review): confirm against onnx_test_common).
    opset_version = GLOBALS.export_onnx_opset_version
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_single_output(self):
        """Run the model test on a function with one input and one output."""

        class SingleOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.exp()
                result = result.log()
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                result = input + 5
                return SingleOut.apply(result) + 3

        model = Caller()
        sample_input = torch.ones(1)
        run_model_test(self, model, input_args=(sample_input,))

    def test_multi_output(self):
        """Run the model test on a function that returns two tensors."""

        class MultiOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = result_exp.log()
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # forward() returns two tensors and saves two tensors, so
                # backward receives one gradient per output and must unpack
                # both saved values.  The formula keeps the simplified
                # "grad * saved result" form used by the other fixtures in
                # this file.
                result_exp, result_log = ctx.saved_tensors
                return grad_exp * result_exp + grad_log * result_log

        class Caller(torch.nn.Module):
            def forward(self, input):
                return MultiOut.apply(input)

        model = Caller()
        sample_input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(sample_input,))

    def test_partial_output(self):
        """Run the model test on a function that returns only part of an op's result."""

        class PartialOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                # topk yields (values, indices); only the values are returned.
                values, _ = torch.topk(input, 3)
                return values

        class Caller(torch.nn.Module):
            def forward(self, input):
                return PartialOut.apply(input)

        model = Caller()
        sample_input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(sample_input,))

    def test_nested_autograd(self):
        """Run the model test on a function whose forward calls another function."""

        class Child(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.log()
                result_log = result.log()
                ctx.save_for_backward(result_log)
                return result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Parent(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = Child.apply(result_exp)
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # Two outputs and two saved tensors: take one gradient per
                # output and unpack both saved values.  Same simplified
                # "grad * saved result" form as the other fixtures here.
                result_exp, result_log = ctx.saved_tensors
                return grad_exp * result_exp + grad_log * result_log

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Parent.apply(input)

        model = Caller()
        sample_input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(sample_input,))

    # Run export in ONNX_FALLTHROUGH mode as torch.erf() is not supported
    def test_aten_unsupported(self):
        """Check the first graph node kind under two operator export types.

        ``ONNX_FALLTHROUGH`` is expected to yield a ``prim::PythonOp`` node and
        ``ONNX_ATEN_FALLBACK`` an ``aten::ATen`` node.
        """

        class Erf(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                erf_out = torch.special.erf(x)
                ctx.save_for_backward(erf_out)
                return erf_out

            @staticmethod
            def backward(ctx, grad_output):
                # saved_tensors is a tuple; unpack the single saved tensor
                # before handing it to erfinv.
                (result,) = ctx.saved_tensors
                return torch.special.erfinv(result), None

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Erf.apply(input)

        model = Caller()
        sample_input = torch.ones(1, 5)

        # Test ONNX_FALLTHROUGH_MODE
        graph, _, _ = _model_to_graph(
            model,
            (sample_input,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::PythonOp")

        # Test ATEN_FALLBACK_MODE
        graph, _, _ = _model_to_graph(
            model,
            (sample_input,),
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "aten::ATen")

    def test_inline_and_symbolic(self):
        """Run the model test on a mix of functions with and without ``symbolic``.

        ``Exp`` defines a ``symbolic`` static method; ``LogLog`` does not.
        """

        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save this call's own argument, not a tensor captured from
                # the enclosing test scope.
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        sample_input = torch.ones(1)
        run_model_test(self, model, input_args=(sample_input,))

    def test_inline_with_scoped_tracing(self):
        """Same model as ``test_inline_and_symbolic`` with a trace module map set.

        ``torch.jit._trace._trace_module_map`` is populated with every
        submodule's type name for the duration of the model test.
        """

        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save this call's own argument, not a tensor captured from
                # the enclosing test scope.
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        sample_input = torch.ones(1)

        # The module map is process-global state: always put the previous
        # value back, even if the model test raises, so a failure here cannot
        # leak into later tests.
        previous_map = torch.jit._trace._trace_module_map
        torch.jit._trace._trace_module_map = {
            _m: torch.typename(type(_m)) for _m in model.modules()
        }
        try:
            run_model_test(self, model, input_args=(sample_input,))
        finally:
            torch.jit._trace._trace_module_map = previous_map


if __name__ == "__main__":
    common_utils.run_tests()