1# Owner(s): ["oncall: export"] 2# flake8: noqa 3import unittest 4from typing import Dict, List, Tuple 5 6import torch 7import torch._dynamo 8from torch._dynamo.test_case import run_tests, TestCase 9from torch._export.wrappers import _mark_strict_experimental 10from torch._functorch.aot_autograd import aot_export_module 11from torch.export._trace import _convert_ts_to_export_experimental 12from torch.export.experimental import _export_forward_backward 13from torch.testing import FileCheck 14 15 16@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") 17class TestExperiment(TestCase): 18 def test_with_buffer_as_submodule(self): 19 @_mark_strict_experimental 20 class B(torch.nn.Module): 21 def __init__(self) -> None: 22 super().__init__() 23 self.buffer1 = torch.nn.Buffer(torch.ones(3)) 24 25 def forward(self, x): 26 y = x + 2 27 y.add_(4) 28 # this doesnt' work today with HOO 29 # self.buffer1.add_(6) 30 buffer_updated = self.buffer1 + 6 31 return x.sum() + y.sum() + buffer_updated.sum() 32 33 class M(torch.nn.Module): 34 def __init__(self) -> None: 35 super().__init__() 36 self.submodule = B() 37 38 def forward(self, x): 39 x_v2 = x.sin() 40 return (self.submodule(x_v2), x + 3) 41 42 inp = torch.randn(3) 43 ep = torch.export.export(M(), (inp,), strict=False) 44 self.assertExpectedInline( 45 str(ep.graph_module.code.strip()), 46 """\ 47def forward(self, b_submodule_buffer1, x): 48 sin = torch.ops.aten.sin.default(x) 49 strict_graph_0 = self.strict_graph_0 50 strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None 51 getitem_2 = strict_mode[0]; strict_mode = None 52 add = torch.ops.aten.add.Tensor(x, 3); x = None 53 return (getitem_2, add)""", 54 ) 55 56 self.assertExpectedInline( 57 str(ep.graph_module.strict_graph_0.code.strip()), 58 """\ 59def forward(self, arg0_1, arg1_1): 60 add = torch.ops.aten.add.Tensor(arg0_1, 2) 61 add_1 = torch.ops.aten.add.Tensor(add, 4); add = None 62 add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None 63 sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None 64 sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None 65 add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None 66 sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None 67 add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None 68 return (add_4,)""", 69 ) 70 71 eager_mod = M() 72 ep = torch.export.export(eager_mod, (inp,), strict=True) 73 74 graph_res_1, graph_res_2 = ep.module()(inp) 75 eager_res_1, eager_res_2 = eager_mod(inp) 76 77 self.assertTrue(torch.allclose(graph_res_2, eager_res_2)) 78 self.assertTrue(torch.allclose(graph_res_1, eager_res_1)) 79 80 graph_res_1, graph_res_2 = ep.module()(inp) 81 eager_res_1, eager_res_2 = eager_mod(inp) 82 83 self.assertTrue(torch.allclose(graph_res_2, eager_res_2)) 84 self.assertTrue(torch.allclose(graph_res_1, eager_res_1)) 85 86 def test_mark_strict_with_container_type(self): 87 @_mark_strict_experimental 88 class B(torch.nn.Module): 89 def __init__(self) -> None: 90 super().__init__() 91 92 def forward(self, x): 93 x0 = x[0][0] 94 return x0.sum() 95 96 class M(torch.nn.Module): 97 def __init__(self) -> None: 98 super().__init__() 99 self.submodule = B() 100 101 def forward(self, x): 102 return self.submodule(x) 103 104 inp = ((torch.randn(3),),) 105 with self.assertRaisesRegex( 106 RuntimeError, "strict_mode HOO doesn't work unless" 107 ): 108 ep = torch.export.export(M(), inp, strict=False) 109 110 def test_torchscript_module_export(self): 111 class M(torch.nn.Module): 112 def forward(self, x): 113 return x.cos() + x.sin() 114 115 model_to_trace = M() 116 inps = (torch.randn(4, 4),) 117 traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) 118 119 exported_module = _convert_ts_to_export_experimental( 120 traced_module_by_torchscript, inps 121 ) 122 123 self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps))) 124 125 def test_torchscript_module_export_single_input(self): 126 class M(torch.nn.Module): 127 def forward(self, x): 128 return x.cos() + x.sin() 129 130 model_to_trace = M() 131 inps = torch.randn(4, 4) 132 traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) 133 134 exported_module = _convert_ts_to_export_experimental( 135 traced_module_by_torchscript, inps 136 ) 137 138 self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps))) 139 140 def test_torchscript_module_export_various_inputs_with_annotated_input_names(self): 141 def _check_equality_and_annotations(m_func, inps): 142 # Original module. 143 model_to_trace = m_func() 144 145 # ExportedProgram from TorchScript module. 146 traced_module_by_torchscript = torch.jit.trace( 147 m_func(), example_inputs=inps 148 ) 149 exported_module = _convert_ts_to_export_experimental( 150 traced_module_by_torchscript, inps 151 ) 152 153 # ExportedProgram from original module. 154 original_exported_module = torch.export.export(m_func(), inps) 155 156 # Check whether input annotations are the same as tracing the original module. 157 orig_ph_name_list = [ 158 n.name 159 for n in original_exported_module.graph.nodes 160 if n.op == "placeholder" 161 ] 162 ph_name_list = [ 163 n.name for n in exported_module.graph.nodes if n.op == "placeholder" 164 ] 165 self.assertEqual(orig_ph_name_list, ph_name_list) 166 167 # Check results equality. 168 self.assertTrue( 169 torch.allclose(exported_module(*inps), model_to_trace(*inps)) 170 ) 171 172 # Tuple 173 class MTuple(torch.nn.Module): 174 def forward(self, x: Tuple[torch.Tensor]): 175 return x[0] + x[1] 176 177 _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),)) 178 179 # List 180 class MList(torch.nn.Module): 181 def forward(self, x: List[torch.Tensor]): 182 return x[0] + x[1] 183 184 _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],)) 185 186 # Dict 187 class MDict(torch.nn.Module): 188 def forward(self, x: Dict[str, torch.Tensor]): 189 return x["0"] + x["1"] 190 191 _check_equality_and_annotations( 192 MDict, ({"0": torch.randn(4), "1": torch.randn(4)},) 193 ) 194 195 def test_joint_basic(self) -> None: 196 class Module(torch.nn.Module): 197 def __init__(self) -> None: 198 super().__init__() 199 self.linear = torch.nn.Linear(3, 3) 200 self.loss = torch.nn.CrossEntropyLoss() 201 202 def forward(self, x): 203 return self.loss( 204 self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0]) 205 ) 206 207 m = Module() 208 example_inputs = (torch.randn(3),) 209 m(*example_inputs) 210 ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True) 211 joint_ep = _export_forward_backward(ep) 212 self.assertExpectedInline( 213 str(joint_ep.graph_module.code).strip(), 214 """\ 215def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): 216 view = torch.ops.aten.view.default(x, [1, 3]); x = None 217 permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None 218 addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None 219 view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None 220 _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None 221 alias = torch.ops.aten.alias.default(_softmax) 222 alias_1 = torch.ops.aten.alias.default(alias); alias = None 223 clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None 224 alias_2 = torch.ops.aten.alias.default(clone); clone = None 225 alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None 226 alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None 227 _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None 228 alias_5 = torch.ops.aten.alias.default(_log_softmax) 229 alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None 230 mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None 231 sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None 232 neg = torch.ops.aten.neg.default(sum_1); sum_1 = None 233 div = torch.ops.aten.div.Scalar(neg, 1); neg = None 234 full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format) 235 div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None 236 neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None 237 expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None 238 mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None 239 alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None 240 alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None 241 exp = torch.ops.aten.exp.default(alias_8); alias_8 = None 242 sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) 243 mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None 244 sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None 245 alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None 246 alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None 247 mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None 248 sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) 249 mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None 250 sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None 251 view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None 252 permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) 253 mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None 254 permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None 255 sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None 256 view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None 257 permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None 258 return (div, permute_3, view_3)""", 259 ) 260 ep = joint_ep.run_decompositions() 261 self.assertExpectedInline( 262 str(ep.graph_module.code).strip(), 263 """\ 264def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): 265 view = torch.ops.aten.view.default(x, [1, 3]); x = None 266 permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None 267 addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None 268 view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None 269 _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None 270 alias = torch.ops.aten.alias.default(_softmax) 271 alias_1 = torch.ops.aten.alias.default(alias); alias = None 272 clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None 273 alias_2 = torch.ops.aten.alias.default(clone); clone = None 274 alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None 275 alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None 276 _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None 277 alias_5 = torch.ops.aten.alias.default(_log_softmax) 278 alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None 279 mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None 280 sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None 281 neg = torch.ops.aten.neg.default(sum_1); sum_1 = None 282 div = torch.ops.aten.div.Scalar(neg, 1); neg = None 283 full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format) 284 div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None 285 neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None 286 expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None 287 mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None 288 alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None 289 alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None 290 exp = torch.ops.aten.exp.default(alias_8); alias_8 = None 291 sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) 292 mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None 293 sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None 294 alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None 295 alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None 296 mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None 297 sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) 298 mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None 299 sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None 300 view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None 301 permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) 302 mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None 303 permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None 304 sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None 305 view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None 306 permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None 307 return (div, permute_3, view_3)""", 308 ) 309 310 def test_joint_dynamic(self) -> None: 311 from torch.export import Dim 312 313 class Module(torch.nn.Module): 314 def __init__(self) -> None: 315 super().__init__() 316 self.y = torch.nn.Parameter(torch.randn(3)) 317 318 def forward(self, x): 319 x = torch.ones(x.shape[0], 3) 320 return (self.y + x).sum() 321 322 m = Module() 323 example_inputs = (torch.randn(3),) 324 m(*example_inputs) 325 ep = torch.export._trace._export( 326 m, example_inputs, pre_dispatch=True, dynamic_shapes={"x": {0: Dim("x0")}} 327 ) 328 joint_ep = _export_forward_backward(ep) 329 330 331if __name__ == "__main__": 332 run_tests() 333