# Owner(s): ["module: fx"] import os import sys from typing import Callable import torch import torch.nn.functional as F from torch.fx import symbolic_trace from torch.fx.experimental.proxy_tensor import make_fx pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) import unittest from torch.fx.passes.utils.matcher_utils import SubgraphMatcher from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( SubgraphMatcherWithNameNodeMap, ) from torch.testing._internal.common_utils import IS_WINDOWS, run_tests from torch.testing._internal.jit_utils import JitTestCase class WrapperModule(torch.nn.Module): def __init__(self, fn: Callable): super().__init__() self.fn = fn def forward(self, *args, **kwargs): return self.fn(*args, **kwargs) class TestMatcher(JitTestCase): def test_subgraph_matcher_with_attributes(self): class LargeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self._weight = torch.nn.Parameter(torch.ones(3, 3)) self._bias = torch.nn.Parameter(torch.ones(3, 3)) def forward(self, x): return torch.ops.aten.addmm.default(self._bias, x, self._weight) # Large Model graph: # opcode name target args kwargs # ------------- ------------- ------------------ ------------------- -------- # placeholder x x () {} # get_attr _bias _bias () {} # get_attr _weight _weight () {} # call_function addmm_default aten.addmm.default (_bias, x, _weight) {} # output output output (addmm_default,) {} large_model_graph = symbolic_trace(LargeModel()).graph class PatternModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self._weight_1 = torch.nn.Parameter(torch.ones(5, 5)) self._bias_1 = torch.nn.Parameter(torch.ones(5, 5)) def forward(self, x): return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1) pattern_graph = torch.fx.symbolic_trace(PatternModel()).graph subgraph_matcher = SubgraphMatcher(pattern_graph) match_result = subgraph_matcher.match(large_model_graph) self.assertEqual(len(match_result), 1) def test_subgraph_matcher_with_list(self): def original(x, y): return torch.ops.aten.view(x, [5, y.shape[0]]) original_graph = torch.fx.symbolic_trace(original).graph def pattern(x, y, z): return torch.ops.aten.view(x, [z, y.shape[0]]) pattern_graph = torch.fx.symbolic_trace(pattern).graph subgraph_matcher = SubgraphMatcher(pattern_graph) match_result = subgraph_matcher.match(original_graph) self.assertEqual(len(match_result), 1) def test_subgraph_matcher_with_list_bad(self): def original(x, y): return torch.ops.aten._reshape_alias_copy.default( x, [1, y.shape[0]], [y.shape[1], y.shape[1]] ) original_graph = torch.fx.symbolic_trace(original).graph def pattern(x, y, b): return torch.ops.aten._reshape_alias_copy.default( x, [b, y.shape[0], y.shape[1]], [y.shape[1]] ) pattern_graph = torch.fx.symbolic_trace(pattern).graph subgraph_matcher = SubgraphMatcher(pattern_graph) match_result = subgraph_matcher.match(original_graph) self.assertEqual(len(match_result), 0) def test_subgraph_matcher_ignore_literals(self): def original(x): return x + 1 original_graph = make_fx(original)(torch.ones(3, 3)).graph original_graph.eliminate_dead_code() def pattern(x): return x + 2 pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph pattern_graph.eliminate_dead_code() subgraph_matcher = SubgraphMatcher(pattern_graph) match_result = subgraph_matcher.match(original_graph) self.assertEqual(len(match_result), 0) subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True) match_result = subgraph_matcher.match(original_graph) self.assertEqual(len(match_result), 1) def test_variatic_arg_matching(self): inputs = (torch.randn(20, 16, 50, 32),) def maxpool(x, kernel_size, stride, padding, dilation): return torch.ops.aten.max_pool2d_with_indices.default( x, kernel_size, stride, padding, dilation ) maxpool_graph = torch.fx.symbolic_trace(maxpool).graph maxpool_matcher = SubgraphMatcher(maxpool_graph) match_result = maxpool_matcher.match(maxpool_graph) self.assertEqual(len(match_result), 1) # Graph only contains "stride" argument maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval() maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph match_s_result = maxpool_matcher.match(maxpool_s_graph) self.assertEqual(len(match_s_result), 1) # Graph only contains "padding" argument maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1) maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph match_p_result = maxpool_matcher.match(maxpool_p_graph) self.assertEqual(len(match_p_result), 1) # Graph only contains "stride, padding" argument maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1) maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph match_sp_result = maxpool_matcher.match(maxpool_sp_graph) self.assertEqual(len(match_sp_result), 1) @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") def test_split_to_graph_and_name_node_map(self): """Testing the internal helper function for splitting the pattern graph""" from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( _split_to_graph_and_name_node_map, ) def pattern(x, weight): conv = F.conv2d(x, weight) relu = F.relu(conv) relu_mul_by_two = relu * 2 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} from torch._export import capture_pre_autograd_graph example_inputs = ( torch.randn(1, 3, 3, 3) * 10, torch.randn(3, 3, 3, 3), ) pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) before_split_res = pattern_gm(*example_inputs) pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) after_split_res = pattern_gm(*example_inputs) self.assertEqual(before_split_res[0], after_split_res[0]) self.assertEqual(before_split_res[1], after_split_res[1]) @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") def test_matcher_with_name_node_map_function(self): """Testing SubgraphMatcherWithNameNodeMap with function pattern""" def target_graph(x, weight): x = x * 2 weight = weight * 3 conv = F.conv2d(x, weight) relu = F.relu(conv) relu2 = relu * 2 return relu + relu2 def pattern(x, weight): conv = F.conv2d(x, weight) relu = F.relu(conv) relu_mul_by_two = relu * 2 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} from torch._export import capture_pre_autograd_graph example_inputs = ( torch.randn(1, 3, 3, 3) * 10, torch.randn(3, 3, 3, 3), ) pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) target_gm = capture_pre_autograd_graph( WrapperModule(target_graph), example_inputs ) internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map assert "conv" in name_node_map assert "relu" in name_node_map name_node_map["conv"].meta["custom_annotation"] = "annotation" # check if we correctly annotated the target graph module for n in target_gm.graph.nodes: if n == name_node_map["conv"]: assert ( "custom_annotation" in n.meta and n.meta["custom_annotation"] == "annotation" ) @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") def test_matcher_with_name_node_map_module(self): """Testing SubgraphMatcherWithNameNodeMap with module pattern""" class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) def forward(self, x): return self.linear(x) class Pattern(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) def forward(self, x): linear = self.linear(x) # Note: we can't put "weight": self.linear.weight in dictionary since # nn.Parameter is not an allowed output type in dynamo return linear, {"linear": linear, "x": x} from torch._export import capture_pre_autograd_graph example_inputs = (torch.randn(3, 5),) pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs) matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) target_gm = capture_pre_autograd_graph(M(), example_inputs) internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map assert "linear" in name_node_map assert "x" in name_node_map name_node_map["linear"].meta["custom_annotation"] = "annotation" # check if we correctly annotated the target graph module for n in target_gm.graph.nodes: if n == name_node_map["linear"]: assert ( "custom_annotation" in n.meta and n.meta["custom_annotation"] == "annotation" ) if __name__ == "__main__": run_tests()