1# Owner(s): ["oncall: quantization"] 2import copy 3from typing import Any, Dict, Tuple 4 5import torch 6from torch._export import capture_pre_autograd_graph 7from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 8from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e 9from torch.ao.quantization.quantizer import Quantizer 10from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 11 get_symmetric_quantization_config, 12 XNNPACKQuantizer, 13) 14from torch.testing._internal.common_quantization import ( 15 NodeSpec as ns, 16 QuantizationTestCase, 17 skipIfNoQNNPACK, 18 TestHelperModules, 19) 20 21 22@skipIfNoQNNPACK 23class TestPT2ERepresentation(QuantizationTestCase): 24 def _test_representation( 25 self, 26 model: torch.nn.Module, 27 example_inputs: Tuple[Any, ...], 28 quantizer: Quantizer, 29 ref_node_occurrence: Dict[ns, int], 30 non_ref_node_occurrence: Dict[ns, int], 31 fixed_output_tol: float = None, 32 output_scale_idx: int = 2, 33 ) -> torch.nn.Module: 34 # resetting dynamo cache 35 torch._dynamo.reset() 36 model = capture_pre_autograd_graph( 37 model, 38 example_inputs, 39 ) 40 model_copy = copy.deepcopy(model) 41 42 model = prepare_pt2e(model, quantizer) 43 # Calibrate 44 model(*example_inputs) 45 model = convert_pt2e(model, use_reference_representation=True) 46 self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) 47 # make sure it runs 48 pt2e_quant_output = model(*example_inputs) 49 50 # TODO: torchdynamo times out when we do this, we can enable numerical checking 51 # after that is fixed 52 model_copy = prepare_pt2e(model_copy, quantizer) 53 # Calibrate 54 model_copy(*example_inputs) 55 model_copy = convert_pt2e(model_copy, use_reference_representation=False) 56 self.checkGraphModuleNodes( 57 model_copy, expected_node_occurrence=non_ref_node_occurrence 58 ) 59 pt2e_quant_output_copy = model_copy(*example_inputs) 60 61 output_tol = None 62 if fixed_output_tol is not None: 63 output_tol = fixed_output_tol 64 else: 65 idx = 0 66 for n in model_copy.graph.nodes: 67 if ( 68 n.target 69 == torch.ops.quantized_decomposed.quantize_per_tensor.default 70 ): 71 idx += 1 72 if idx == output_scale_idx: 73 output_tol = n.args[1] 74 assert output_tol is not None 75 76 # make sure the result is off by one at most in the quantized integer representation 77 self.assertTrue( 78 torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) 79 <= (2 * output_tol + 1e-5) 80 ) 81 82 def test_static_linear(self): 83 class M(torch.nn.Module): 84 def __init__(self) -> None: 85 super().__init__() 86 self.linear = torch.nn.Linear(5, 5) 87 88 def forward(self, x): 89 return self.linear(x) 90 91 quantizer = XNNPACKQuantizer() 92 operator_config = get_symmetric_quantization_config(is_per_channel=False) 93 quantizer.set_global(operator_config) 94 example_inputs = (torch.randn(2, 5),) 95 96 self._test_representation( 97 M().eval(), 98 example_inputs, 99 quantizer, 100 ref_node_occurrence={}, 101 non_ref_node_occurrence={}, 102 ) 103 104 def test_dynamic_linear(self): 105 class M(torch.nn.Module): 106 def __init__(self) -> None: 107 super().__init__() 108 self.linear = torch.nn.Linear(5, 5) 109 110 def forward(self, x): 111 return self.linear(x) 112 113 quantizer = XNNPACKQuantizer() 114 operator_config = get_symmetric_quantization_config( 115 is_per_channel=False, is_dynamic=True 116 ) 117 quantizer.set_global(operator_config) 118 example_inputs = (torch.randn(2, 5),) 119 120 self._test_representation( 121 M().eval(), 122 example_inputs, 123 quantizer, 124 ref_node_occurrence={}, 125 non_ref_node_occurrence={}, 126 fixed_output_tol=1e-4, 127 ) 128 129 def test_conv2d(self): 130 class M(torch.nn.Module): 131 def __init__(self) -> None: 132 super().__init__() 133 self.conv2d = torch.nn.Conv2d(3, 3, 3) 134 135 def forward(self, x): 136 return self.conv2d(x) 137 138 quantizer = XNNPACKQuantizer() 139 operator_config = get_symmetric_quantization_config(is_per_channel=False) 140 quantizer.set_global(operator_config) 141 example_inputs = (torch.randn(1, 3, 3, 3),) 142 143 self._test_representation( 144 M().eval(), 145 example_inputs, 146 quantizer, 147 ref_node_occurrence={}, 148 non_ref_node_occurrence={}, 149 ) 150 151 def test_add(self): 152 class M(torch.nn.Module): 153 def __init__(self) -> None: 154 super().__init__() 155 156 def forward(self, x, y): 157 return x + y 158 159 quantizer = XNNPACKQuantizer() 160 quantization_config = get_symmetric_quantization_config(is_per_channel=True) 161 quantizer.set_global(quantization_config) 162 m_eager = M().eval() 163 164 example_inputs = ( 165 torch.randn(1, 3, 3, 3), 166 torch.randn(1, 3, 3, 3), 167 ) 168 169 self._test_representation( 170 M().eval(), 171 example_inputs, 172 quantizer, 173 ref_node_occurrence={}, 174 non_ref_node_occurrence={}, 175 ) 176 177 def test_add_relu(self): 178 class M(torch.nn.Module): 179 def __init__(self) -> None: 180 super().__init__() 181 182 def forward(self, x, y): 183 out = x + y 184 out = torch.nn.functional.relu(out) 185 return out 186 187 quantizer = XNNPACKQuantizer() 188 operator_config = get_symmetric_quantization_config(is_per_channel=True) 189 quantizer.set_global(operator_config) 190 191 example_inputs = ( 192 torch.randn(1, 3, 3, 3), 193 torch.randn(1, 3, 3, 3), 194 ) 195 ref_node_occurrence = { 196 ns.call_function(out_dtype): 2, 197 } 198 199 self._test_representation( 200 M().eval(), 201 example_inputs, 202 quantizer, 203 ref_node_occurrence=ref_node_occurrence, 204 non_ref_node_occurrence={}, 205 ) 206 207 def test_maxpool2d(self): 208 quantizer = XNNPACKQuantizer() 209 operator_config = get_symmetric_quantization_config(is_per_channel=True) 210 quantizer.set_global(operator_config) 211 m_eager = TestHelperModules.ConvMaxPool2d().eval() 212 213 example_inputs = (torch.randn(1, 2, 2, 2),) 214 215 self._test_representation( 216 m_eager, 217 example_inputs, 218 quantizer, 219 ref_node_occurrence={}, 220 non_ref_node_occurrence={}, 221 ) 222 223 def test_qdq_per_channel(self): 224 """Test representation for quantize_per_channel and dequantize_per_channel op""" 225 226 class M(torch.nn.Module): 227 def __init__(self) -> None: 228 super().__init__() 229 self.linear = torch.nn.Linear(5, 5) 230 231 def forward(self, x): 232 return self.linear(x) 233 234 quantizer = XNNPACKQuantizer() 235 # use per channel quantization for weight 236 operator_config = get_symmetric_quantization_config(is_per_channel=True) 237 quantizer.set_global(operator_config) 238 m_eager = M().eval() 239 240 inputs = [ 241 (torch.randn(1, 5),), 242 (torch.randn(1, 3, 5),), 243 (torch.randn(1, 3, 3, 5),), 244 (torch.randn(1, 3, 3, 3, 5),), 245 ] 246 for example_inputs in inputs: 247 ref_node_occurrence = { 248 ns.call_function( 249 torch.ops.quantized_decomposed.quantize_per_channel.default 250 ): 0, 251 ns.call_function( 252 torch.ops.quantized_decomposed.dequantize_per_channel.default 253 ): 0, 254 } 255 non_ref_node_occurrence = { 256 # quantize_per_channel is folded 257 ns.call_function( 258 torch.ops.quantized_decomposed.quantize_per_channel.default 259 ): 0, 260 ns.call_function( 261 torch.ops.quantized_decomposed.dequantize_per_channel.default 262 ): 1, 263 } 264 265 self._test_representation( 266 M().eval(), 267 example_inputs, 268 quantizer, 269 ref_node_occurrence, 270 non_ref_node_occurrence, 271 output_scale_idx=2, 272 ) 273 274 def test_qdq(self): 275 """Test representation for quantize and dequantize op""" 276 277 class M(torch.nn.Module): 278 def __init__(self) -> None: 279 super().__init__() 280 281 def forward(self, x, y): 282 return x + y 283 284 quantizer = XNNPACKQuantizer() 285 quantization_config = get_symmetric_quantization_config(is_per_channel=True) 286 quantizer.set_global(quantization_config) 287 m_eager = M().eval() 288 289 example_inputs = ( 290 torch.randn(1, 3, 3, 3), 291 torch.randn(1, 3, 3, 3), 292 ) 293 ref_node_occurrence = { 294 ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0, 295 ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0, 296 } 297 non_ref_node_occurrence = { 298 ns.call_function( 299 torch.ops.quantized_decomposed.quantize_per_tensor.default 300 ): 3, 301 ns.call_function( 302 torch.ops.quantized_decomposed.dequantize_per_tensor.default 303 ): 3, 304 } 305 self._test_representation( 306 M().eval(), 307 example_inputs, 308 quantizer, 309 ref_node_occurrence, 310 non_ref_node_occurrence, 311 ) 312