# Owner(s): ["module: onnx"] import pytorch_test_common import torch import torch._dynamo import torch.fx from torch.onnx._internal.fx.passes import _utils as pass_utils from torch.testing._internal import common_utils class TestFxPasses(common_utils.TestCase): def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self): def func(x, y, z): return x + y + z x = torch.randn(3) y = torch.randn(3) z = torch.randn(3) gm, _ = torch._dynamo.export(func)(x, y, z) torch._dynamo.reset() # Purposely name the nodes in a way that will cause a recursive collision later. # See :func:`set_node_name` for name collision renaming logic. base_name = "tensor" nodes = list(gm.graph.nodes) for i, node in enumerate(nodes[1:]): if i == 0: node.name = base_name else: node.name = f"{base_name}.{i}" # Run `set_node_name` and verify that the names are correct. name_to_node = {node.name: node for node in gm.graph.nodes} pass_utils.set_node_name(nodes[0], base_name, name_to_node) assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}" assert len({node.name for node in nodes}) == len( nodes ), f"Expected all names to be unique, got {nodes}" def test_set_node_name_succeeds_when_no_name_collisions(self): def func(x, y, z): return x + y + z x = torch.randn(3) y = torch.randn(3) z = torch.randn(3) gm, _ = torch._dynamo.export(func)(x, y, z) torch._dynamo.reset() # Run `set_node_name` and verify that the names are correct. new_name = "some_tensor" nodes = list(gm.graph.nodes) name_to_node = {node.name: node for node in nodes} pass_utils.set_node_name(nodes[1], new_name, name_to_node) assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}" assert len({node.name for node in nodes}) == len( nodes ), f"Expected all names to be unique, got {nodes}" def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self): @torch.library.custom_op( "mylibrary::foo_op", device_types="cpu", mutates_args=() ) def foo_op(x: torch.Tensor) -> torch.Tensor: return x + 1 @torch.library.custom_op( "mylibrary::bar_op", device_types="cpu", mutates_args=() ) def bar_op(x: torch.Tensor) -> torch.Tensor: return x + 2 @foo_op.register_fake def _(x): return torch.empty_like(x) @bar_op.register_fake def _(x): return torch.empty_like(x) def func(x, y, z): return foo_op(x) + bar_op(y) + z x = torch.randn(3) y = torch.randn(3) z = torch.randn(3) with self.assertRaises(torch.onnx.OnnxExporterError) as ctx: torch.onnx.dynamo_export(func, x, y, z) inner_exception = ctx.exception.__cause__ self.assertRegex( str(inner_exception), r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op", ) torch._dynamo.reset() @common_utils.instantiate_parametrized_tests class TestModularizePass(common_utils.TestCase): @pytorch_test_common.xfail( error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found", reason="optimizer", ) @common_utils.parametrize( "is_exported_program", [ common_utils.subtest( True, name="exported_program", ), common_utils.subtest( False, name="nn_module", ), ], ) def test_modularize_pass_succeeds_when_submodule_output_is_unused( self, is_exported_program ): # This is an ill-formed model, but exporter must not crash. # It is illegal for submodule to have zero output. For modularization pass it can happen # when the submodule output is unused, so no inner node is connected to any outer # nodes. # However, this also means the entire submodule should be erased by DCE. Hence # it should never occur. # # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768 class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.unused_relu = torch.nn.ReLU() self.used_gelu = torch.nn.GELU() def forward(self, x, y): result = self.used_gelu(x + y) unused_relu_result = self.unused_relu(x) return result if is_exported_program: model = torch.export.export( TestModule(), args=(torch.randn(3), torch.randn(3)) ) else: model = TestModule() onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3)) model_proto = onnx_program.model_proto function_proto_names = [function.name for function in model_proto.functions] self.assertIn( "torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names ) self.assertFalse(any("ReLU" in name for name in function_proto_names)) @pytorch_test_common.xfail( error_message="'torch_nn_modules_activation_ReLU_relu_1' not found", reason="optimizer", ) @common_utils.parametrize( "is_exported_program", [ common_utils.subtest( True, name="exported_program", ), common_utils.subtest( False, name="nn_module", ), ], ) def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times( self, is_exported_program ): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.relu = torch.nn.ReLU() def forward(self, x, y): out = x + y out = self.relu(out) out = out + x out = self.relu(out) return out if is_exported_program: model = torch.export.export( TestModule(), args=(torch.randn(3), torch.randn(3)) ) else: model = TestModule() onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3)) model_proto = onnx_program.model_proto function_proto_names = [function.name for function in model_proto.functions] self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names) self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names) @pytorch_test_common.xfail( error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found", reason="optimizer", ) @common_utils.parametrize( "is_exported_program", [ common_utils.subtest( True, name="exported_program", ), common_utils.subtest( False, name="nn_module", ), ], ) def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers( self, is_exported_program ): # Minified repro from basic_gnn_edgecnn. class InnerModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(x) class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.inner_module = InnerModule() def forward(self, x, y): out = x + y out = self.inner_module(out) out = out + x out = self.inner_module.relu(out) return out if is_exported_program: model = torch.export.export( TestModule(), args=(torch.randn(3), torch.randn(3)) ) else: model = TestModule() onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3)) model_proto = onnx_program.model_proto function_proto_names = [function.name for function in model_proto.functions] self.assertIn( "torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names ) self.assertIn( "torch_nn_modules_activation_ReLU_inner_module_relu_2", function_proto_names ) # local module qualified name is unstable in test environment depending on different test # invocation methods. self.assertTrue( any("InnerModule_inner_module_1" in name for name in function_proto_names) ) if __name__ == "__main__": common_utils.run_tests()