1# Owner(s): ["oncall: quantization"] 2import copy 3import unittest 4 5import torch 6import torch._dynamo as torchdynamo 7from torch.ao.quantization.pt2e.graph_utils import ( 8 find_sequential_partitions, 9 get_equivalent_types, 10 update_equivalent_types_dict, 11) 12from torch.testing._internal.common_utils import IS_WINDOWS, TestCase 13 14 15class TestGraphUtils(TestCase): 16 @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") 17 def test_conv_bn_conv_relu(self): 18 class M(torch.nn.Module): 19 def __init__(self) -> None: 20 super().__init__() 21 self.conv1 = torch.nn.Conv2d(3, 3, 3) 22 self.bn1 = torch.nn.BatchNorm2d(3) 23 self.conv2 = torch.nn.Conv2d(3, 3, 3) 24 self.relu2 = torch.nn.ReLU() 25 26 def forward(self, x): 27 bn_out = self.bn1(self.conv1(x)) 28 relu_out = torch.nn.functional.relu(bn_out) 29 return self.relu2(self.conv2(relu_out)) 30 31 m = M().eval() 32 example_inputs = (torch.randn(1, 3, 5, 5),) 33 34 # program capture 35 m, guards = torchdynamo.export( 36 m, 37 *copy.deepcopy(example_inputs), 38 aten_graph=True, 39 ) 40 fused_partitions = find_sequential_partitions( 41 m, [torch.nn.Conv2d, torch.nn.BatchNorm2d] 42 ) 43 self.assertEqual(len(fused_partitions), 1) 44 fused_partitions = find_sequential_partitions( 45 m, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU] 46 ) 47 self.assertEqual(len(fused_partitions), 1) 48 49 def x(): 50 find_sequential_partitions( 51 m, 52 [ 53 torch.nn.Conv2d, 54 torch.nn.BatchNorm2d, 55 torch.nn.ReLU, 56 torch.nn.functional.conv2d, 57 ], 58 ) 59 60 self.assertRaises(ValueError, x) 61 62 @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") 63 def test_conv_bn_relu(self): 64 class M(torch.nn.Module): 65 def __init__(self) -> None: 66 super().__init__() 67 self.bn1 = torch.nn.BatchNorm2d(3) 68 self.conv2 = torch.nn.Conv2d(3, 3, 3) 69 self.relu2 = torch.nn.ReLU() 70 71 def forward(self, x): 72 bn_out = self.bn1(x) 73 return self.relu2(self.conv2(bn_out)) 74 75 m = M().eval() 76 example_inputs = (torch.randn(1, 3, 5, 5),) 77 78 # program capture 79 m, guards = torchdynamo.export( 80 m, 81 *copy.deepcopy(example_inputs), 82 aten_graph=True, 83 ) 84 fused_partitions = find_sequential_partitions( 85 m, [torch.nn.Conv2d, torch.nn.BatchNorm2d] 86 ) 87 self.assertEqual(len(fused_partitions), 0) 88 fused_partitions = find_sequential_partitions( 89 m, [torch.nn.BatchNorm2d, torch.nn.Conv2d] 90 ) 91 self.assertEqual(len(fused_partitions), 1) 92 fused_partitions = find_sequential_partitions( 93 m, [torch.nn.BatchNorm2d, torch.nn.ReLU] 94 ) 95 self.assertEqual(len(fused_partitions), 0) 96 97 @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") 98 def test_customized_equivalet_types_dict(self): 99 class M(torch.nn.Module): 100 def __init__(self) -> None: 101 super().__init__() 102 self.conv = torch.nn.Conv2d(3, 3, 3) 103 104 def forward(self, x): 105 return torch.nn.functional.relu6(self.conv(x)) 106 107 m = M().eval() 108 example_inputs = (torch.randn(1, 3, 5, 5),) 109 110 # program capture 111 m, guards = torchdynamo.export( 112 m, 113 *copy.deepcopy(example_inputs), 114 aten_graph=True, 115 ) 116 customized_equivalent_types = get_equivalent_types() 117 customized_equivalent_types.append({torch.nn.ReLU6, torch.nn.functional.relu6}) 118 update_equivalent_types_dict(customized_equivalent_types) 119 fused_partitions = find_sequential_partitions( 120 m, 121 [torch.nn.Conv2d, torch.nn.ReLU6], 122 ) 123 self.assertEqual(len(fused_partitions), 1) 124