1# Owner(s): ["oncall: export"] 2# flake8: noqa 3import copy 4import io 5import unittest 6 7import torch 8import torch._dynamo as torchdynamo 9import torch.utils._pytree as pytree 10from torch._dynamo.test_case import TestCase 11from torch.export import export, load, save 12from torch.export._trace import _export 13from torch.testing._internal.common_device_type import ( 14 instantiate_device_type_tests, 15 ops, 16) 17from torch.testing._internal.common_utils import ( 18 IS_WINDOWS, 19 run_tests, 20 TestCase as TorchTestCase, 21) 22from torch.testing._internal.hop_db import ( 23 hop_db, 24 hop_that_doesnt_have_opinfo_test_allowlist, 25) 26 27 28hop_tests = [] 29 30for op_info in hop_db: 31 op_info_hop_name = op_info.name 32 if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist: 33 continue 34 hop_tests.append(op_info) 35 36 37class TestHOPGeneric(TestCase): 38 def test_all_hops_have_op_info(self): 39 from torch._ops import _higher_order_ops 40 41 hops_that_have_op_info = set([k.name for k in hop_db]) 42 all_hops = _higher_order_ops.keys() 43 44 missing_ops = [] 45 46 for op in all_hops: 47 if ( 48 op not in hops_that_have_op_info 49 and op not in hop_that_doesnt_have_opinfo_test_allowlist 50 ): 51 missing_ops.append(op) 52 53 self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}") 54 55 56@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 57@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") 58class TestHOP(TestCase): 59 def _compare(self, eager_model, export, args, kwargs): 60 eager_args = copy.deepcopy(args) 61 eager_kwargs = copy.deepcopy(kwargs) 62 export_args = copy.deepcopy(args) 63 export_kwargs = copy.deepcopy(kwargs) 64 65 flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs)) 66 flat_loaded_outputs = pytree.tree_leaves( 67 export.module()(*export_args, **export_kwargs) 68 ) 69 70 for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs): 71 self.assertEqual(type(orig), type(loaded)) 72 self.assertEqual(orig, loaded) 73 74 @ops(hop_tests, allowed_dtypes=(torch.float,)) 75 def test_aot_export(self, device, dtype, op): 76 class Foo(torch.nn.Module): 77 def forward(self, *args): 78 return op.op(*args) 79 80 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) 81 for inp in sample_inputs_itr: 82 model = Foo() 83 input = inp.input if isinstance(inp.input, tuple) else (inp.input,) 84 args = (*input, *inp.args) 85 kwargs = inp.kwargs 86 ep = export(model, args, kwargs) 87 self._compare(model, ep, args, kwargs) 88 89 @ops(hop_tests, allowed_dtypes=(torch.float,)) 90 def test_pre_dispatch_export(self, device, dtype, op): 91 class Foo(torch.nn.Module): 92 def forward(self, *args): 93 return op.op(*args) 94 95 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) 96 for inp in sample_inputs_itr: 97 model = Foo() 98 input = inp.input if isinstance(inp.input, tuple) else (inp.input,) 99 args = (*input, *inp.args) 100 kwargs = inp.kwargs 101 ep = _export(model, args, kwargs, pre_dispatch=True) 102 self._compare(model, ep, args, kwargs) 103 104 @ops(hop_tests, allowed_dtypes=(torch.float,)) 105 def test_retrace_export(self, device, dtype, op): 106 class Foo(torch.nn.Module): 107 def forward(self, *args): 108 return op.op(*args) 109 110 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) 111 for inp in sample_inputs_itr: 112 model = Foo() 113 input = inp.input if isinstance(inp.input, tuple) else (inp.input,) 114 args = (*input, *inp.args) 115 kwargs = inp.kwargs 116 ep = _export(model, args, kwargs, pre_dispatch=True) 117 ep = ep.run_decompositions() 118 self._compare(model, ep, args, kwargs) 119 120 @ops(hop_tests, allowed_dtypes=(torch.float,)) 121 def test_serialize_export(self, device, dtype, op): 122 class Foo(torch.nn.Module): 123 def forward(self, *args): 124 return op.op(*args) 125 126 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) 127 for inp in sample_inputs_itr: 128 model = Foo() 129 input = inp.input if isinstance(inp.input, tuple) else (inp.input,) 130 args = (*input, *inp.args) 131 kwargs = inp.kwargs 132 ep = _export(model, args, kwargs, pre_dispatch=True) 133 ep = ep.run_decompositions() 134 buffer = io.BytesIO() 135 save(ep, buffer) 136 buffer.seek(0) 137 ep = load(buffer) 138 if "while_loop" in str(op): 139 # while_loop's arguments are cast into list after deserailize 140 # but while_loop expects it to still be tuple 141 with self.assertRaisesRegex( 142 RuntimeError, "carried_inputs must be a tuple" 143 ): 144 self._compare(model, ep, args, kwargs) 145 else: 146 self._compare(model, ep, args, kwargs) 147 148 149instantiate_device_type_tests(TestHOP, globals()) 150 151if __name__ == "__main__": 152 run_tests() 153