"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_sym_bool)
"""

# Owner(s): ["oncall: export"]
import copy
import io
import tempfile
import unittest
import zipfile
from pathlib import Path

import torch
import torch._dynamo as torchdynamo
import torch.export._trace
import torch.utils._pytree as pytree
from torch._export.db.case import ExportCase, SupportLevel
from torch._export.db.examples import all_examples
from torch._export.serde.serialize import (
    canonicalize,
    deserialize,
    ExportedProgramDeserializer,
    ExportedProgramSerializer,
    serialize,
    SerializeError,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import Dim, export, load, save
from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TemporaryFileName,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations


def get_filtered_export_db_tests():
    """Return ``(name, case)`` pairs for every export-db example marked SUPPORTED."""
    return [
        (name, case)
        for name, case in all_examples().items()
        if case.support_level == SupportLevel.SUPPORTED
    ]


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerialize(TestCase):
    """Tests that inspect the serialized form of exported programs."""

    def test_export_with_extension_op_serialization(self):
        """Round-trip a graph whose ``add`` node targets a registered extension op."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        class FooExtensionOp:
            # All instances hash and compare equal so that a freshly
            # deserialized instance matches ``foo_custom_op`` below.
            def __hash__(self):
                return 0

            def __eq__(self, other):
                return type(other) is type(self)

            def __call__(self, *args, **kwargs):
                return torch.ops.aten.add.Tensor(*args, **kwargs)

            @property
            def __name__(self):
                return "foo.my_op"

        class ExtensionVerifier(torch._export.verifier.Verifier):
            dialect = "FOO"

            def allowed_op_types(self):
                return super().allowed_op_types() + (FooExtensionOp,)

        class FooExtensionHandler(torch._export.serde.serialize.ExtensionHandler):
            @classmethod
            def namespace(cls):
                return "foo"

            @classmethod
            def to_op_name(cls, op):
                return "my_op"

            @classmethod
            def from_op_name(cls, name: str):
                # ``self`` is the enclosing test case, captured by closure.
                self.assertEqual(name, "my_op")
                return FooExtensionOp()

            @classmethod
            def op_schema(cls, op):
                return torch.ops.aten.add.Tensor._schema

        inp = (torch.ones(10),)
        ep = export(TestModule(), inp)

        # Register the custom op handler.
        foo_custom_op = FooExtensionOp()
        torch._export.serde.serialize.register_extension(
            FooExtensionOp, FooExtensionHandler
        )

        new_gm = copy.deepcopy(ep.graph_module)
        # Inject the custom operator.
        for node in new_gm.graph.nodes:
            if node.name == "add":
                node.target = foo_custom_op

        new_ep = ep._update(new_gm, ep.graph_signature, verifiers=[ExtensionVerifier])
        serialized = serialize(new_ep)
        deserialized = deserialize(serialized)
        self.assertEqual(
            len(
                deserialized.graph.find_nodes(op="call_function", target=foo_custom_op)
            ),
            1,
        )

    def test_predispatch_export_with_autograd_op(self):
        """Save/load a pre-dispatch export containing ``torch.enable_grad``."""

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                with torch.enable_grad():
                    return x + x

        inp = (torch.ones(10),)
        with torch.no_grad():
            from torch.export._trace import _export

            ep = _export(Foo(), inp, pre_dispatch=True)

        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        buffer.seek(0)
        loaded_ep = torch.export.load(buffer)

        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)
        self.assertEqual(exp_out.requires_grad, actual_out.requires_grad)

    def test_export_example_inputs_preserved(self):
        """Example inputs survive a save/load round trip."""

        class MyModule(torch.nn.Module):
            """A test module that has multiple args and uses kwargs"""

            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(2, 3))

            def forward(self, x, y, use_p=False):
                out = x + y
                if use_p:
                    out += self.p
                return out

        model = MyModule().eval()
        random_inputs = (torch.rand([2, 3]), torch.rand([2, 3]))
        exp_program = torch.export.export(model, random_inputs, {"use_p": True})

        output_buffer = io.BytesIO()
        # Tests that example inputs are preserved when saving and loading module.
        torch.export.save(exp_program, output_buffer)
        # Rewind before loading, as the other save/load tests in this file do.
        output_buffer.seek(0)
        loaded_model = torch.export.load(output_buffer)
        # Extract the example inputs from before and after saving.
        orig_args, orig_kwargs = exp_program.example_inputs
        loaded_args, loaded_kwargs = loaded_model.example_inputs
        # Run both modules and confirm that outputs match.
        orig_out = exp_program.module()(*orig_args, **orig_kwargs)
        loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
        self.assertEqual(orig_out, loaded_out)

    def test_metadata_parsing_with_layer_split(self):
        # Tests that modules with more complicated layer patterns can be serialized
        # and deserialized correctly.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.SiLU(),
                    torch.nn.SiLU(),
                    torch.nn.SiLU(),
                )

            def forward(self, x):
                # Splitting layers of a sequential stack introduces commas and parens
                # into metadata trace.
                out_start, out_rest = self.layers[0], self.layers[1:]
                h = out_start(x)
                h = out_rest(h)
                return h

        inp = (torch.ones(10),)
        # Module will only be able to roundtrip if metadata
        # can be correctly parsed.
        ep = export(MyModule(), inp)
        buffer = io.BytesIO()
        save(ep, buffer)
        # Rewind before loading, as the other save/load tests in this file do.
        buffer.seek(0)
        loaded_ep = load(buffer)

        # Check that both modules run to confirm load was successful.
        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)

    def test_serialize_constant_outputs(self):
        """Programs returning ``None`` and int constants can be round-tripped."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                # Along with tensor output, return Nonetype
                # and constant. Although these outputs aren't
                # very useful, they do show up in graphs.
                return x + 1, None, 1024

        # Check that module can be roundtripped, thereby confirming proper deserialization.
        inp = (torch.ones(10),)
        ep = export(MyModule(), inp)
        buffer = io.BytesIO()
        save(ep, buffer)
        # Rewind before loading, as the other save/load tests in this file do.
        buffer.seek(0)
        loaded_ep = load(buffer)

        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)

    def test_serialize_multiple_returns_from_node(self) -> None:
        """A multi-output node is serialized with one uniquely named output each."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        exported_module = export(
            MyModule(),
            (
                torch.ones([512, 512], requires_grad=True),
                torch.ones([512]),
                torch.ones([512]),
            ),
        ).run_decompositions()

        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]
        self.assertEqual(node.target, "torch.ops.aten.native_layer_norm.default")
        # aten::native_layer_norm returns 3 tensors
        self.assertEqual(len(node.outputs), 3)

        # check the names are unique
        seen = set()
        for output in node.outputs:
            name = output.as_tensor.name
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_serialize_sym_int(self) -> None:
        """Serialized ``sym_size.int`` nodes name their inputs ``self`` and ``dim``."""

        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim0_ac = torch.export.Dim("dim0_ac")
        dim1_bc = torch.export.Dim("dim1_b")
        dynamic_shapes = {
            "a": {0: dim0_ac},
            "b": {1: dim1_bc},
            "c": {0: dim0_ac, 1: dim1_bc},
        }
        exported_module = export(
            DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
        ).run_decompositions()
        serialized = ExportedProgramSerializer().serialize(exported_module)
        sym_size_nodes = [
            node
            for node in serialized.exported_program.graph_module.graph.nodes
            if node.target == "torch.ops.aten.sym_size.int"
        ]
        for node in sym_size_nodes:
            self.assertEqual(node.inputs[0].name, "self")
            self.assertEqual(node.inputs[1].name, "dim")

    def test_serialize_list_returns(self) -> None:
        """A node returning a tensor list is serialized as one ``as_tensors`` output."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.split(x, 2)

        inp = torch.arange(10.0).reshape(5, 2)
        exported_module = export(MyModule(), (inp,)).run_decompositions()

        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]
        # split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table
        self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default")
        self.assertEqual(len(node.outputs), 1)
        # Input looks like:
        # tensor([[0, 1],
        #         [2, 3],
        #         [4, 5],
        #         [6, 7],
        #         [8, 9]])
        # Output looks like:
        # (tensor([[0, 1],
        #          [2, 3]]),
        #  tensor([[4, 5],
        #          [6, 7]]),
        #  tensor([[8, 9]]))
        self.assertEqual(len(node.outputs[0].as_tensors), 3)

        # check the names are unique
        seen = set()
        for output in node.outputs[0].as_tensors:
            name = output.name
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_multi_return_some_unused(self) -> None:
        """
        Make sure the serialized output matches the op schema, even if some of
        the arguments are never used in the graph.
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ops.aten.var_mean.correction(x, [1])[0]

        exported_module = export(
            MyModule(),
            (torch.ones([512, 512], requires_grad=True),),
        ).run_decompositions()

        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]
        self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
        self.assertEqual(len(node.outputs), 2)

        # check the names are unique
        seen = set()
        for output in node.outputs:
            name = output.as_tensor.name
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_rational_ranges(self) -> None:
        """Rational range bounds (5/3, 10/3) are serialized as the ints 2 and 3."""

        class M(torch.nn.Module):
            def forward(self, x):
                return x + x

        ep = torch.export.export(
            M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},)
        )

        range_constraints = list(ep.range_constraints.keys())
        # A unittest assertion (not a bare ``assert``) so the check survives
        # ``python -O`` and reports the actual length on failure.
        self.assertEqual(len(range_constraints), 1)
        symint = range_constraints[0]

        import sympy

        upper_range = sympy.Rational(10, 3)
        lower_range = sympy.Rational(10, 6)
        ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)

        serialized = ExportedProgramSerializer().serialize(ep)
        self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2)
        self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3)

    def test_kwargs_default(self) -> None:
        """
        Tests that the kwargs default values are serialized even if they are not
        specified
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                values = torch.randn(3, 2)
                return torch.searchsorted(x, values, side="right", right=True)

        f = Foo()

        x, _ = torch.sort(torch.randn(3, 4))
        exported_module = export(f, (x,)).run_decompositions()
        serialized = ExportedProgramSerializer().serialize(exported_module)

        node = serialized.exported_program.graph_module.graph.nodes[-1]
        self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor")
        self.assertEqual(len(node.inputs), 4)
        self.assertEqual(node.inputs[2].name, "right")
        self.assertEqual(node.inputs[2].arg.as_bool, True)
        self.assertEqual(node.inputs[3].name, "side")
        self.assertEqual(node.inputs[3].arg.as_string, "right")

    def test_canonicalize(self) -> None:
        """After canonicalization the first node's first input name sorts first."""

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                a = y + x
                b = x + y
                return b + a

        ep = torch.export.export(Module(), (torch.randn(3, 2), torch.randn(3, 2)))
        s = ExportedProgramSerializer().serialize(ep)
        c = canonicalize(s.exported_program)
        g = c.graph_module.graph
        self.assertLess(
            g.nodes[0].inputs[0].arg.as_tensor.name,
            g.nodes[1].inputs[0].arg.as_tensor.name,
        )

    def test_int_list(self) -> None:
        """An empty ``dim`` list for ``sum.dim_IntList`` is serialized as ``as_ints``."""

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.sum.dim_IntList(x, [])

        ep = torch.export.export(M(), (torch.randn(3, 2),))
        serialized = ExportedProgramSerializer().serialize(ep)
        for node in serialized.exported_program.graph_module.graph.nodes:
            if "aten.sum.dim_IntList" in node.target:
                self.assertEqual(node.inputs[1].arg.type, "as_ints")


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
    """Round-trip tests: export, serialize, deserialize, compare graphs and outputs."""

    def setUp(self):
        super().setUp()
        init_torchbind_implementations()

    def _check_graph_nodes(self, gm1, gm2, _check_meta=True):
        """Assert that two graph modules have matching nodes and node metadata.

        Compares node ops pairwise, then for ``call_function`` nodes the "val"
        and "stack_trace" metadata, recursing into the subgraphs of ``cond`` and
        ``map_impl`` nodes. When ``_check_meta`` is true, "nn_module_stack" and
        "source_fn_stack" are compared as well.
        """
        # TODO: The _check_meta flag bypasses checking for
        # source_fn/nn_module_stack as there is an issue with
        # roundtripping the source_fn value on torch.ops.map nodes
        # original source_fn: (value missing here; presumably the function object)
        # deserialized source_fn: 'functorch.experimental._map.map'

        self.assertEqual(len(gm1.graph.nodes), len(gm2.graph.nodes))

        for node1, node2 in zip(gm1.graph.nodes, gm2.graph.nodes):
            self.assertEqual(node1.op, node2.op)
            if node1.op == "call_function":
                # Check "val" metadata
                val1 = node1.meta.get("val", None)
                val2 = node2.meta.get("val", None)
                if val1 is None or val2 is None:
                    # Either both are None
                    self.assertEqual(val1, val2)
                elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
                    # Or both are fake tensors with the same shape/dtype
                    self.assertEqual(len(val1.shape), len(val2.shape))
                    for s1, s2 in zip(val1.shape, val2.shape):
                        if is_concrete_int(s1) and is_concrete_int(s2):
                            self.assertEqual(s1, s2)
                        else:
                            self.assertEqual(str(s1), str(s2))
                    self.assertEqual(val1.dtype, val2.dtype)
                elif isinstance(val1, (list, tuple)) and isinstance(
                    val2, (list, tuple)
                ):
                    # Or both are fake tensors lists with one element and with the
                    # same shape/dtype
                    for v1, v2 in zip(
                        pytree.tree_leaves(val1), pytree.tree_leaves(val2)
                    ):
                        if isinstance(v1, FakeTensor):
                            self.assertEqual(v1.shape, v2.shape)
                            self.assertEqual(v1.dtype, v2.dtype)
                else:
                    # For expressions like 's0 < 10' can only compare through string
                    self.assertEqual(str(val1), str(val2))

                # Check "stack_trace" metadata
                self.assertEqual(
                    node1.meta.get("stack_trace", None),
                    node2.meta.get("stack_trace", None),
                )

                if node1.target == torch.ops.higher_order.cond:
                    true_graph1 = getattr(gm1, node1.args[1].target)
                    true_graph2 = getattr(gm2, node2.args[1].target)
                    self._check_graph_nodes(true_graph1, true_graph2)

                    false_graph1 = getattr(gm1, node1.args[2].target)
                    false_graph2 = getattr(gm2, node2.args[2].target)
                    self._check_graph_nodes(false_graph1, false_graph2)
                elif node1.target == torch.ops.higher_order.map_impl:
                    map_graph1 = getattr(gm1, node1.args[0].target)
                    map_graph2 = getattr(gm2, node2.args[0].target)
                    self._check_graph_nodes(map_graph1, map_graph2, False)

            if _check_meta and node1.op not in ("get_attr", "placeholder", "output"):
                # Check "nn_module_stack" metadata
                self.assertEqual(
                    node1.meta.get("nn_module_stack", None),
                    node2.meta.get("nn_module_stack", None),
                )
                # Check "source_fn_stack" metadata
                self.assertEqual(
                    node1.meta.get("source_fn_stack", None),
                    node2.meta.get("source_fn_stack", None),
                )

    def check_graph(
        self,
        fn,
        inputs,
        dynamic_shapes=None,
        _check_meta=True,
        use_pre_dispatch=True,
        strict=True,
    ) -> None:
        """Export a graph, serialize it, deserialize it, and compare the results.

        Args:
            fn: the module to export.
            inputs: positional example inputs; each export/run gets a deep copy.
            dynamic_shapes: forwarded to the export call.
            _check_meta: forwarded to ``_check_graph_nodes``.
            use_pre_dispatch: when true, the check runs once with pre-dispatch
                export and once with regular export; otherwise only the latter.
            strict: forwarded to the export call.
        """

        def _deepcopy_inputs(inputs):
            # copy.deepcopy(deepcopy) can fail if tensor inputs have attribute (i.e. __dict__).
            # we remove __dict__ when deepcopying.
            saved_dicts = {}
            clones = []
            try:
                for idx, inp in enumerate(inputs):
                    # Inspect the element being copied (not always inputs[0]).
                    if isinstance(inp, torch.Tensor) and hasattr(inp, "__dict__"):
                        saved_dicts[idx] = inp.__dict__
                        inp.__dict__ = {}
                    clones.append(copy.deepcopy(inp))
            finally:
                # Add __dict__ back, even if a deepcopy above raised, so the
                # caller's inputs are never left stripped.
                for idx, attrs in saved_dicts.items():
                    inputs[idx].__dict__ = attrs
            for idx, attrs in saved_dicts.items():
                clones[idx].__dict__ = attrs
            return tuple(clones)

        def _check_graph(pre_dispatch):
            if pre_dispatch:
                ep = torch.export._trace._export(
                    fn,
                    _deepcopy_inputs(inputs),
                    {},
                    dynamic_shapes=dynamic_shapes,
                    pre_dispatch=True,
                    strict=strict,
                )
            else:
                ep = torch.export.export(
                    fn,
                    _deepcopy_inputs(inputs),
                    {},
                    dynamic_shapes=dynamic_shapes,
                    strict=strict,
                )
            ep.graph.eliminate_dead_code()

            serialized_artifact = serialize(ep, opset_version={"aten": 0})
            deserialized_ep = deserialize(
                serialized_artifact, expected_opset_version={"aten": 0}
            )
            deserialized_ep.graph.eliminate_dead_code()

            orig_outputs = ep.module()(*_deepcopy_inputs(inputs))
            loaded_outputs = deserialized_ep.module()(*_deepcopy_inputs(inputs))

            flat_orig_outputs = pytree.tree_leaves(orig_outputs)
            flat_loaded_outputs = pytree.tree_leaves(loaded_outputs)

            # zip() below stops at the shorter sequence, so check the lengths
            # first; otherwise a dropped output would go unnoticed.
            self.assertEqual(len(flat_orig_outputs), len(flat_loaded_outputs))
            for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
                self.assertEqual(type(orig), type(loaded))
                if isinstance(orig, torch.Tensor):
                    if orig.is_meta:
                        self.assertEqual(orig, loaded)
                    else:
                        self.assertTrue(torch.allclose(orig, loaded))
                else:
                    self.assertEqual(orig, loaded)
            self._check_graph_nodes(
                ep.graph_module, deserialized_ep.graph_module, _check_meta
            )

        if use_pre_dispatch:
            _check_graph(pre_dispatch=True)
            _check_graph(pre_dispatch=False)
        else:
            _check_graph(pre_dispatch=False)

    def test_optional_tuple(self):
        """Round-trip a custom op returning ``(Tensor, Tensor?)``."""
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, Tensor b, Tensor? c) -> (Tensor, Tensor?)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo")
            def foo_impl(a, b, c):
                res2 = None
                if c is not None:
                    res2 = c + a + b
                return a + b, res2

            class M(torch.nn.Module):
                def forward(self, a, b, c):
                    return torch.ops.mylib.foo(a, b, c)

            self.check_graph(M(), (torch.randn(3), torch.randn(3), torch.randn(3)))

    def test_auto_functionalize(self):
        """Round-trip mutating custom ops with tensor, tuple and empty returns."""
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo1",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::foo2",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::foo3",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo1", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo1")
            def foo1_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return n + n

            @torch.library.impl("mylib::foo2", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo2")
            def foo2_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return (n + n, n * n)

            @torch.library.impl("mylib::foo3", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo3")
            def foo3_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return

            class M(torch.nn.Module):
                def forward(self, x, y, z, n):
                    n = torch.ops.mylib.foo1(x, y, z, 2, n)
                    torch.ops.mylib.foo3(x, y, z, 2, n)
                    return torch.ops.mylib.foo2(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            # TODO Auto_functionalize is not supported on pre_dispatch IR
            self.check_graph(M(), orig_args, use_pre_dispatch=False)

    def test_multi_return(self) -> None:
        """
        Test multiple return from a single node (ex. native_layer_norm has 3
        outputs)
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        inputs = (
            torch.ones([512, 512], requires_grad=True),
            torch.ones([512]),
            torch.ones([512]),
        )
        self.check_graph(MyModule(), inputs)

    def test_basic(self) -> None:
        """Round-trip a module made of simple elementwise ops."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_graph(MyModule(), inputs)

    def test_dynamic(self) -> None:
        """Round-trip a module exported with a dynamic leading dimension."""

        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim0_ac = torch.export.Dim("dim0_ac")
        dynamic_shapes = {"a": {0: dim0_ac}, "b": None, "c": {0: dim0_ac}}
        self.check_graph(DynamicShapeSimpleModel(), inputs, dynamic_shapes)

    def test_sym_bool(self):
        """Round-trip a module whose forward contains a traced ``assert``."""

        class Module(torch.nn.Module):
            def forward(self, x, y):
                # This assert is part of the module under test; the module
                # docstring's PYTEST_DONT_REWRITE marker exists for it.
                assert x.size(0) in y
                return x + y

        f = Module()
        self.check_graph(f, (torch.ones(1), torch.ones(3)))

    def test_shape(self):
        """Round-trip a module that returns values derived from dynamic sizes."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                z, y = x.size()
                return z + y + x[0], z

        inputs = (torch.ones(2, 3),)
        dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x")
        dynamic_shapes = {"x": (dim0_x, dim1_x)}
        self.check_graph(Foo(), inputs, dynamic_shapes)

    def test_module(self):
        """Round-trip a module with submodules, one of which is called twice."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = torch.nn.functional.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        self.check_graph(M(), inputs)

    def test_module_meta(self):
        """Round-trip a module whose parameter and input live on the meta device."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3))

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        self.check_graph(mod, inputs)

    def test_cond(self):
        """Round-trip a module using the ``cond`` higher-order op."""
        from functorch.experimental.control_flow import cond

        inputs = torch.ones(4, 3), torch.zeros(4, 3)

        class M(torch.nn.Module):
            def forward(self, x, y):
                def t(x, y):
                    return x + y

                def f(x, y):
                    return x - y

                return cond(x[0][0] > 4, t, f, [x, y])

        self.check_graph(M(), inputs)

    def test_map(self):
        """Round-trip a module using ``control_flow.map`` (metadata check off)."""
        from functorch.experimental import control_flow

        def f(x, y):
            return x + y

        class Module(torch.nn.Module):
            def forward(self, xs, y):
                return control_flow.map(f, xs, y)

        g = Module()
        inputs = (torch.ones(3, 2, 2), torch.ones(2))
        self.check_graph(g, inputs, _check_meta=False)

    def test_tensor_tensor_list(self):
        """Round-trip a custom op returning ``(Tensor, Tensor[])``."""
        with torch.library._scoped_library("_export", "FRAGMENT") as lib:
            lib.define(
                "_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])",
                tags=torch.Tag.pt2_compliant_tag,
            )

            def _test_tensor_tensor_list_output(x, y):
                return y, [x]

            lib.impl(
                "_test_tensor_tensor_list_output",
                _test_tensor_tensor_list_output,
                "CPU",
            )
            lib.impl(
                "_test_tensor_tensor_list_output",
                _test_tensor_tensor_list_output,
                "Meta",
            )

            class M(torch.nn.Module):
                def forward(self, x, y):
                    a, b = torch.ops._export._test_tensor_tensor_list_output.default(
                        x, y
                    )
                    return a + b[0]

            self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))

    def test_list_of_optional_tensors(self) -> None:
        """Round-trip ``aten.index.Tensor`` called with a list containing ``None``."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y, z):
                indices = [None, None, torch.tensor([1, 3, 5, 7])]
                indexed = torch.ops.aten.index.Tensor(x + y, indices)
                return indexed + z

        inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4))
        self.check_graph(MyModule(), inputs)

    def test_sym_ite(self):
        """Round-trip a module using ``torch.sym_ite`` on dynamic sizes."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                b = x.shape[0] == 5
                ret = torch.sym_ite(b, x.shape[0], x.shape[1])
                return ret

        dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}}
        self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes)

    def test_multiple_getitem(self):
        """Duplicate getitem nodes are deduplicated by the round trip."""

        class M(torch.nn.Module):
            def forward(self, x):
                a, b = torch.topk(x, 2)
                a = a * 2
                return a, b

        ep = torch.export.export(M(), (torch.ones(3),))

        # insert another getitem node
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor:
                getitem_0 = node.args[0]
                with ep.graph.inserting_before(getitem_0):
                    getitem_copy = ep.graph.node_copy(getitem_0)
                    mul_node = ep.graph.call_function(
                        torch.ops.aten.mul.Tensor, (getitem_copy, 2)
                    )
                    mul_node.meta = copy.copy(getitem_copy.meta)
                    node.args = (getitem_0, mul_node)

        deserialized_ep = deserialize(serialize(ep))

        inp = (torch.randn(3),)
        orig_res = ep.module()(*inp)
        res = deserialized_ep.module()(*inp)
        self.assertTrue(torch.allclose(orig_res[0], res[0]))
        self.assertTrue(torch.allclose(orig_res[1], res[1]))

        # The deserialized graph should have deduped getitem calls
        self.assertExpectedInline(
            deserialized_ep.graph_module.code.strip("\n"),
            """\
def forward(self, x):
    topk_default = torch.ops.aten.topk.default(x, 2);  x = None
    getitem = topk_default[0]
    getitem_1 = topk_default[1];  topk_default = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2)
    mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor);  getitem = mul_tensor = None
    return (mul, getitem_1)
    """,
        )

    @parametrize(
        "name,case",
        get_filtered_export_db_tests(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        """Round-trip every SUPPORTED export-db example."""
        model = case.model
        # Metadata comparison is skipped for cases with "map" in the name
        # (see the TODO in _check_graph_nodes).
        _check_meta = "map" not in name
        self.check_graph(model, case.example_args, _check_meta=_check_meta)

    def test_constraints(self):
        """Round-trip a module with a data-dependent size from ``item()``."""

        class Module(torch.nn.Module):
            def forward(self, x, y):
                n = x.item()
                torch._check_is_size(n)
                return y.sum() + torch.ones(n, 5).sum()

        f = Module()
        self.check_graph(f, (torch.tensor(3), torch.randn(4, 5)))

    def test_get_attr(self) -> None:
        """Round-trip a module that creates a scalar tensor inside forward."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + torch.tensor(3)

        f = Module()
        self.check_graph(f, (torch.tensor(3),))

    def test_get_attr_list(self) -> None:
        """Round-trip a module that puts a tensor created in forward in a list arg."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.cat([x, torch.tensor([1, 1])])

        f = Module()
        self.check_graph(f, (torch.tensor([1, 1]),))

    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
    def test_device(self) -> None:
        """Round-trip a module and input placed on CUDA."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                conv = self.conv(x)
                relu = self.relu(conv)
                mul = relu * 0.5
                return mul

        inp = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
        model = MyModule().eval().cuda()
        self.check_graph(model, (inp,))

    def test_custom_obj_tuple_out(self):
        """Round-trip a torchbind-object op whose result is indexed like a tuple."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_custom_obj(self):
        """Round-trip a module that passes a torchbind object to a custom op."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_custom_obj_list_out(self):
        """Round-trip a torchbind-object op whose result is indexed like a list."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_export_no_inputs(self):
        """Round-trip a program that takes no inputs and has no example inputs."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.ones(3, 3)

            def forward(self):
                return self.p * self.p

        ep = torch.export.export(M(), ())
        ep._example_inputs = None
        roundtrip_ep = deserialize(serialize(ep))
        self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))


instantiate_parametrized_tests(TestDeserialize)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSchemaVersioning(TestCase):
    """Tests for schema-version checking during deserialization."""

    def test_error(self):
        """Deserializing with a mismatched major schema version raises SerializeError."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Module()
        ep = export(f, (torch.randn(1, 3),))

        serialized_program = ExportedProgramSerializer().serialize(ep)
        serialized_program.exported_program.schema_version.major = -1
        with self.assertRaisesRegex(
            SerializeError, r"Serialized schema version .* does not match our current"
        ):
            ExportedProgramDeserializer().deserialize(
                serialized_program.exported_program,
                serialized_program.state_dict,
                serialized_program.constants,
                serialized_program.example_inputs,
            )


# We didn't set up kwargs input yet
unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_fn_with_kwargs)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSaveLoad(TestCase):
    """Tests for ``torch.export.save`` / ``torch.export.load``."""

    def test_save_buffer(self):
        """Save to and load from an in-memory buffer."""
        inp = (torch.tensor([0.1, 0.1]),)

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = x + 1
                y = x.t()
                y = y.relu()
                y = self.linear(y)
                return y

        ep = export(Module(), inp)

        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_file(self):
        """Save to and load from an open temporary file object."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x

        f = Foo()

        inp = (torch.randn(2, 2),)
        ep = export(f, inp)

        # A distinct name for the file handle so it does not shadow the module.
        with tempfile.NamedTemporaryFile() as tmp:
            save(ep, tmp)
            tmp.seek(0)
            loaded_ep = load(tmp)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_path(self):
        """Save to and load from a ``pathlib.Path``."""

        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        f = Foo()

        inp = (torch.tensor([6]), torch.tensor([7]))
        ep = export(f, inp)

        with TemporaryFileName() as fname:
            path = Path(fname)
            save(ep, path)
            loaded_ep = load(path)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_extra(self):
        """Extra files passed to ``save`` are returned through ``load``."""
        inp = (torch.tensor([0.1, 0.1]),)

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x + x

        f = Foo()

        ep = export(f, inp)

        buffer = io.BytesIO()
        save(ep, buffer, extra_files={"extra.txt": "moo"})
        buffer.seek(0)
        extra_files = {"extra.txt": ""}
        loaded_ep = load(buffer, extra_files=extra_files)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))
        self.assertEqual(extra_files["extra.txt"], "moo")

    def test_version_error(self):
        """Loading an archive whose ``version`` entry was overwritten raises."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        ep = export(f, (torch.randn(1, 3),))

        # A distinct name for the file handle so it does not shadow the module.
        with tempfile.NamedTemporaryFile() as tmp:
            save(ep, tmp)
            tmp.seek(0)

            # Modify the version
            with zipfile.ZipFile(tmp, "a") as zipf:
                zipf.writestr("version", "-1.1")

            with self.assertRaisesRegex(
                RuntimeError, r"Serialized version .* does not match our current"
            ):
                tmp.seek(0)
                load(tmp)

    def test_save_constants(self):
        """Tensor constants (attribute and inline) survive save/load."""

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor(3)

            def forward(self, x):
                list_tensor = [torch.tensor(3), torch.tensor(4)]
                return x + self.a + list_tensor[0] + list_tensor[1]

        ep = export(Foo(), (torch.tensor(1),))
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        inp = (torch.tensor(1),)
        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerializeCustomClass(TestCase):
    """Tests for serializing torchbind custom objects and custom node metadata."""

    def setUp(self):
        super().setUp()
        init_torchbind_implementations()

    def test_custom_class(self):
        """A torchbind object used as a node argument survives the round trip."""
        custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        # Replace one of the values with an instance of our custom class
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                with ep.graph.inserting_before(node):
                    custom_node = ep.graph.call_function(
                        torch.ops._TorchScriptTesting.take_an_instance.default,
                        (custom_obj,),
                    )
                    custom_node.meta["val"] = torch.ones(4, 4)
                    custom_node.meta["torch_fn"] = (
                        "take_an_instance",
                        "take_an_instance",
                    )
                    arg0, _ = node.args
                    node.args = (arg0, custom_node)

        serialized_vals = serialize(ep)

        ep_str = serialized_vals.exported_program.decode("utf-8")
        # unittest assertions (not bare ``assert``) so the checks survive
        # ``python -O`` and report the offending value on failure.
        self.assertIn("class_fqn", ep_str)
        self.assertIn(custom_obj._type().qualified_name(), ep_str)

        deserialized_ep = deserialize(serialized_vals)

        for node in deserialized_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target
                == torch.ops._TorchScriptTesting.take_an_instance.default
            ):
                arg = node.args[0]
                self.assertIsInstance(arg, torch._C.ScriptObject)
                self.assertEqual(arg._type(), custom_obj._type())
                self.assertEqual(arg.__getstate__(), custom_obj.__getstate__())
                self.assertEqual(arg.top(), 7)

    def test_custom_class_containing_fake_tensor(self):
        """A torchbind object holding a FakeTensor still holds one after the round trip."""

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor(
                    torch.rand(2, 3)
                )

            def forward(self, x):
                return x + self.custom_obj.get()

        with FakeTensorMode():
            f = Foo()

        inputs = (torch.zeros(2, 3),)
        with enable_torchbind_tracing():
            ep = export(f, inputs, strict=False)

        serialized_vals = serialize(ep)
        ep = deserialize(serialized_vals)
        self.assertIsInstance(ep.constants["custom_obj"].get(), FakeTensor)

    def test_custom_tag_metadata_serialization(self):
        """``meta["custom"]`` on the graph module and on nodes survives the round trip."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_ep = ep._update(new_gm, ep.graph_signature)
        serialized_vals = serialize(new_ep)
        new_ep = deserialize(serialized_vals)

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertEqual(counter, 1)

    def test_custom_tag_metadata_decomp(self):
        """``meta["custom"]`` is carried onto the nodes produced by decomposition."""

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        f = Foo()

        inputs = (torch.ones(2, 2),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        counter = 0
        for node in new_gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.linear.default
            ):
                counter += 1
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"
        self.assertEqual(counter, 1)

        new_ep = ep._update(new_gm, ep.graph_signature)
        new_ep = new_ep.run_decompositions()

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function":
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertGreater(counter, 1)

    # TODO For some reason, this doesn't work on Windows ONLY.
    # def test_custom_tag_metadata_reexport(self):
    #     class Foo(torch.nn.Module):
    #         def forward(self, x):
    #             return x + x
    #
    #     f = Foo()
    #
    #     inputs = (torch.zeros(4, 4),)
    #     ep = export(f, inputs)
    #
    #     new_gm = copy.deepcopy(ep.graph_module)
    #     new_gm.meta["custom"] = {}
    #     new_gm.meta["custom"]["f"] = "bar"
    #
    #     for node in new_gm.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
    #             node.meta["custom"] = {}
    #             node.meta["custom"]["quantization_tag"] = "foo"
    #
    #     new_ep = ep._update(new_gm, ep.graph_signature)
    #     new_ep = torch.export.export(new_ep.module(), inputs)
    #
    #     self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
    #     counter = 0
    #     for node in new_ep.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
    #             counter += 1
    #             self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo")
    #     self.assertEqual(counter, 1)

    def test_custom_tag_metadata_copy(self):
        """``meta["custom"]`` survives ``copy.deepcopy`` of the graph module."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_gm = copy.deepcopy(new_gm)

        self.assertEqual(new_gm.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertEqual(counter, 1)


if __name__ == "__main__":
    run_tests()