# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx._equalize import (
    _InputEqualizationObserver,
    _WeightEqualizationObserver,
    calculate_equalization_scale,
    default_equalization_qconfig,
    _convert_equalization_ref,
    get_layer_sqnr_dict,
    get_equalization_qconfig_dict,
)

from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    LinearAddModel,
    SingleLayerFunctionalLinearModel,
    TwoLayerFunctionalLinearModel,
    FunctionalLinearAddModel,
    ConvModel,
    TwoLayerConvModel,
    SingleLayerFunctionalConvModel,
    TwoLayerFunctionalConvModel,
    skipIfNoFBGEMM,
    LinearReluModel,
    LinearReluLinearModel,
    LinearReluAddModel,
    FunctionalLinearReluModel,
    FunctionalLinearReluLinearModel,
    ConvReluModel,
    ConvReluConvModel,
    ConvReluAddModel,
    FunctionalConvReluModel,
    FunctionalConvReluConvModel,
)

# Standard Libraries
import copy
import numpy as np

# Testing utils
from hypothesis import given
from hypothesis import strategies as st


# Quantize everything with the default qconfig.
default_qconfig_dict = {"": default_qconfig}

# Quantize only linear / relu / conv2d (module and functional forms);
# everything else is left unquantized ("": None).
specific_qconfig_dict = {
    "": None,
    "object_type": [(nn.Linear, default_qconfig),
                    (F.linear, default_qconfig),
                    (nn.ReLU, default_qconfig),
                    (F.relu, default_qconfig),
                    (nn.Conv2d, default_qconfig),
                    (F.conv2d, default_qconfig)]
}

# Equalize the same set of ops that `specific_qconfig_dict` quantizes.
default_equalization_qconfig_dict = {
    "": None,
    "object_type": [(nn.Linear, default_equalization_qconfig),
                    (F.linear, default_equalization_qconfig),
                    (nn.ReLU, default_equalization_qconfig),
                    (F.relu, default_equalization_qconfig),
                    (nn.Conv2d, default_equalization_qconfig),
                    (F.conv2d, default_equalization_qconfig)]
}


class TestEqualizeFx(QuantizationTestCase):
    """ Tests for input-weight equalization in FX graph mode quantization:
    the equalization observers, the graphs produced by prepare_fx/convert_fx
    when an equalization config is supplied, and the resulting numerics.
    """

    def channel_minmax(self, input, axis=1):
        ''' Finds the min/max of inputs associated with a specific channel

        Args:
            input: array-like exposing ``ndim``, ``copy()``, ``min(axis)`` and
                ``max(axis)`` (callers in this file pass NumPy arrays)
            axis: the channel axis to keep; every other axis is reduced

        Returns:
            Tuple ``(mins, maxs)`` holding one value per entry along ``axis``
        '''
        size_of_tensor_dim = input.ndim
        axis_list = list(range(size_of_tensor_dim))
        axis_list.remove(axis)
        # Reduce the highest axes first so that the indices of the axes that
        # are still to be reduced stay valid after each reduction.
        axis_list.sort(reverse=True)

        mins = input.copy()
        maxs = input.copy()
        for a in axis_list:
            mins = mins.min(a)
            maxs = maxs.max(a)

        return (mins, maxs)

    @given(ndim=st.sampled_from((2, 3, 4, 5)),
           input_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           input_qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           weight_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           weight_qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric,
                                           torch.per_channel_affine_float_qparams)))
    def test_input_weight_eq_observer(self, ndim, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
        """ Checks the input/weight equalization observers against NumPy
        reference computations: the per-channel min/max values they record,
        the equalization scale derived from them, and the input and weight
        qparams obtained after applying that scale.
        """
        sizes = [np.random.randint(2, 10) for _ in range((ndim - 1) * 2)]

        # Inputs and weights share the channel dimension (axis 1).
        channel = np.random.randint(1, 10)
        if ndim == 2:
            x = np.random.random(size=(sizes[0], channel))
            w = np.random.random(size=(sizes[1], channel))
        elif ndim == 3:
            x = np.random.random(size=(sizes[0], channel, sizes[1]))
            w = np.random.random(size=(sizes[2], channel, sizes[3]))
        elif ndim == 4:
            x = np.random.random(size=(sizes[0], channel, sizes[1], sizes[2]))
            w = np.random.random(size=(sizes[3], channel, sizes[4], sizes[5]))
        elif ndim == 5:
            x = np.random.random(size=(sizes[0], channel, sizes[1], sizes[2], sizes[3]))
            w = np.random.random(size=(sizes[4], channel, sizes[5], sizes[6], sizes[7]))

        x = (x * 10).round(decimals=2).astype(np.float32)
        w = (w * 10).round(decimals=2).astype(np.float32)

        input_eq_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
        weight_eq_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)

        # The observers must pass their inputs through unchanged
        ret_x = input_eq_obs(torch.tensor(x))
        ret_w = weight_eq_obs(torch.tensor(w))
        self.assertEqual((ret_x, ret_w), (x, w))

        # Check the min/max input columns are correct
        ref_min_inputs, ref_max_inputs = self.channel_minmax(x)
        min_inputs, max_inputs = input_eq_obs.get_input_minmax()
        self.assertEqual(min_inputs, torch.tensor(ref_min_inputs, dtype=torch.float32))
        self.assertEqual(max_inputs, torch.tensor(ref_max_inputs, dtype=torch.float32))

        # Check the min/max weight columns are correct
        ref_min_weights_col, ref_max_weights_col = self.channel_minmax(w)
        min_weights_col, max_weights_col = weight_eq_obs.get_weight_col_minmax()
        self.assertEqual(min_weights_col, torch.tensor(ref_min_weights_col, dtype=torch.float32))
        self.assertEqual(max_weights_col, torch.tensor(ref_max_weights_col, dtype=torch.float32))

        # Check the equalization scale is correct
        equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
        ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
                                         (ref_max_inputs - ref_min_inputs))
        self.assertEqual(equalization_scale, torch.tensor(ref_equalization_scale, dtype=torch.float32))

        input_eq_obs.set_equalization_scale(equalization_scale)
        weight_eq_obs.set_equalization_scale(equalization_scale)

        # Check the input scale/zero-point values
        min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
        input_quant_obs = MinMaxObserver(dtype=input_qdtype, qscheme=input_qscheme)
        input_quant_obs.min_val = min_input_scaled
        input_quant_obs.max_val = max_input_scaled
        input_qparams = input_quant_obs.calculate_qparams()

        # The reference range is extended to include zero
        ref_min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
        ref_min_input_scaled = min(0, ref_min_input_scaled)
        ref_max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
        ref_max_input_scaled = max(0, ref_max_input_scaled)

        if input_qscheme == torch.per_tensor_symmetric:
            ref_scale = 2 * max(abs(ref_min_input_scaled), ref_max_input_scaled) / 255
            ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
        else:
            ref_scale = (ref_max_input_scaled - ref_min_input_scaled) / 255
            quant_min = -128 if input_qdtype is torch.qint8 else 0
            quant_max = 127 if input_qdtype is torch.qint8 else 255
            ref_zero_point = quant_min - np.round(ref_min_input_scaled / ref_scale)
            # np.clip returns a new value; it must be assigned for the
            # clamping to take effect.
            ref_zero_point = np.clip(ref_zero_point, quant_min, quant_max)

        self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
        self.assertEqual(input_qparams[1].item(), ref_zero_point)

        # During input-weight equalization, we will scale the weights so that
        # the following weight quantized observer will have the correct scaled qparams
        # Check the weight scale/zero-point values of the quantized observer
        weight_quant_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=weight_qdtype, qscheme=weight_qscheme)

        # Scale the weights for input-weight equalization
        # (the scale is reshaped so that it broadcasts along the channel axis)
        new_shape = [1] * w.ndim
        new_shape[1] = w.shape[1]
        ref_w_scaled = w * np.reciprocal(ref_equalization_scale.reshape(tuple(new_shape)))

        w = torch.tensor(w)
        new_shape[1] = w.size(1)
        w_scaled = torch.mul(w, torch.reciprocal(equalization_scale.view(new_shape)))

        self.assertEqual(w_scaled, ref_w_scaled)

        # Call forward on the weight quantization observer
        weight_quant_obs(w_scaled)

        # Check the min/max weight rows are correct
        ref_min_weights_scaled, ref_max_weights_scaled = self.channel_minmax(ref_w_scaled)
        self.assertEqual(weight_quant_obs.min_val, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
        self.assertEqual(weight_quant_obs.max_val, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))

        weight_qparams = weight_quant_obs.calculate_qparams()

        if weight_qscheme == torch.per_channel_symmetric:
            ref_min_weights_scaled = np.minimum(np.zeros(ref_min_weights_scaled.shape), ref_min_weights_scaled)
            ref_max_weights_scaled = np.maximum(np.zeros(ref_max_weights_scaled.shape), ref_max_weights_scaled)

            ref_scales = 2 * np.maximum(np.abs(ref_min_weights_scaled), ref_max_weights_scaled) / 255
            ref_zero_points = np.zeros_like(
                ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
        elif weight_qscheme == torch.per_channel_affine_float_qparams:
            ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
            # Fall back to a scale of 1 where the range is (almost) empty
            ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
            ref_zero_points = -1 * ref_min_weights_scaled / ref_scales
        else:
            ref_min_weights_scaled = np.minimum(np.zeros_like(ref_min_weights_scaled), ref_min_weights_scaled)
            ref_max_weights_scaled = np.maximum(np.zeros_like(ref_max_weights_scaled), ref_max_weights_scaled)

            ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
            ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
            ref_zero_points = ref_zero_points - np.round(ref_min_weights_scaled / ref_scales)

        self.assertEqual(weight_qparams[0], torch.tensor(
            ref_scales, dtype=weight_qparams[0].dtype), rtol=1e-5, atol=0.0001)
        self.assertEqual(weight_qparams[1], torch.tensor(
            ref_zero_points, dtype=weight_qparams[1].dtype), rtol=1e-5, atol=1)

    def test_input_weight_equalization_prepare(self):
        """ Tests that graphs created after prepare_fx is as expected
        """

        # Expected observer counts, keyed by model family
        single_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        two_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 3,
        }

        single_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(_WeightEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 3,
        }

        two_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 5,
        }

        fp_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 6,
        }

        tests = [(SingleLayerLinearModel, single_nn_layer_node_occurrence),
                 (TwoLayerLinearModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalLinearModel, two_F_layer_node_occurrence),
                 (FunctionalLinearAddModel, fp_F_layer_node_occurrence),
                 (LinearReluModel, single_nn_layer_node_occurrence),
                 (LinearReluLinearModel, two_nn_layer_node_occurrence),
                 (FunctionalLinearReluModel, single_F_layer_node_occurrence),
                 (FunctionalLinearReluLinearModel, two_F_layer_node_occurrence),
                 (ConvModel, single_nn_layer_node_occurrence),
                 (TwoLayerConvModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalConvModel, two_F_layer_node_occurrence),
                 (ConvReluModel, single_nn_layer_node_occurrence),
                 (ConvReluConvModel, two_nn_layer_node_occurrence),
                 (FunctionalConvReluModel, single_F_layer_node_occurrence),
                 (FunctionalConvReluConvModel, two_F_layer_node_occurrence)]

        for (M, node_occurrence) in tests:
            m = M().eval()
            example_inputs = m.get_example_inputs()
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)

    def test_input_weight_equalization_branching(self):
        """ Tests that graphs containing branches are prepared correctly.
        Specifically, equalization observers should not be inserted in front of
        branches in which both initial layers in the branches plan to be
        quantized.
        """

        # Tests that we do not add an equalization observer due to both initial
        # nodes in the branch containing layers that need to be equalized.
        # Note that this should print out 2 warning messages for not being able
        # to equalize layers linear1 and linear2 because they are part of a branch
        class TestBranchingWithoutEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)
                self.linear2 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = self.linear2(x)
                return torch.add(y, z)

        no_eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 0,
            ns.call_module(MinMaxObserver): 3,
        }

        m = TestBranchingWithoutEqualizationModel().eval()
        example_inputs = (torch.rand(1, 5),)
        prepared = prepare_fx(
            m, specific_qconfig_dict, example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

        # Tests that we will add an equalization observer because there is only
        # one initial node in the branch that needs to be equalized
        class TestBranchingWithEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = torch.add(x, 5)
                return torch.add(y, z)

        eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        m = TestBranchingWithEqualizationModel().eval()
        example_inputs = (torch.randn(1, 5),)
        prepared = prepare_fx(
            m, specific_qconfig_dict, example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)

    @skipIfNoFBGEMM
    def test_input_weight_equalization_convert(self):
        """ Tests that the modified model for equalization (before quantization)
        returns the same output as the original model
        """

        # (model class, number of input dimensions)
        tests = [(SingleLayerLinearModel, 2), (LinearAddModel, 2), (TwoLayerLinearModel, 2),
                 (SingleLayerFunctionalLinearModel, 2), (FunctionalLinearAddModel, 2),
                 (TwoLayerFunctionalLinearModel, 2),
                 (LinearReluModel, 2), (LinearReluLinearModel, 2), (LinearReluAddModel, 2),
                 (FunctionalLinearReluModel, 2), (FunctionalLinearReluLinearModel, 2),
                 (ConvModel, 4), (TwoLayerConvModel, 4), (SingleLayerFunctionalConvModel, 4),
                 (TwoLayerFunctionalConvModel, 4),
                 (ConvReluModel, 4), (ConvReluConvModel, 4), (ConvReluAddModel, 4),
                 (FunctionalConvReluModel, 4), (FunctionalConvReluConvModel, 4)]

        for (M, ndim) in tests:
            m = M().eval()

            if ndim == 2:
                x = torch.rand((5, 5))
            elif ndim == 4:
                x = torch.rand((16, 3, 224, 224))

            example_inputs = (x,)
            # Prepare a copy so that `m` can be prepared again below
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict
            )
            output = prepared(x)

            convert_ref = _convert_equalization_ref(prepared)
            convert_ref_output = convert_ref(x)

            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_fx(prepared)  # Check if compile
            self.assertEqual(output, convert_ref_output)

    def calculate_equalization_scale_ref(self, x, w):
        """ Calculates the equalization scale based on the input and weight

        Both ``x`` and ``w`` are reduced over axis 0, so the result has one
        entry per column:
        ``sqrt((max(w) - min(w)) / (max(x) - min(x)))``.
        """
        min_inputs = x.min(axis=0)
        max_inputs = x.max(axis=0)

        min_weights_col = w.min(axis=0)
        max_weights_col = w.max(axis=0)

        equalization_scale = np.sqrt((max_weights_col - min_weights_col) /
                                     (max_inputs - min_inputs))
        return equalization_scale

    def get_expected_eq_scales(self, model, x):
        """ For each module in the graph, we want to calculate the equalization
        scale at that point. This only works for models containing single or
        connected linear layers.

        Returns a list with one equalization scale per child module of
        ``model``, in ``named_children()`` order.
        """
        exp_eq_scales = []
        for _, module in model.named_children():
            weight = module.weight.detach().numpy()
            bias = module.bias.detach().numpy()

            eq_scale = self.calculate_equalization_scale_ref(x, weight)
            exp_eq_scales.append(eq_scale)

            # Propagate the activation through this layer so that the next
            # layer's scale is computed from its actual input.
            x = x @ weight.T + bias

        return exp_eq_scales

    def test_input_weight_equalization_equalization_scales(self):
        """ After applying the equalization functions, check if the equalization
        scales are the expected values
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel,
                 SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

            example_inputs = (x,)
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(*example_inputs)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            # Compare each equalization scale stored on the converted model
            # with the reference value, in graph order.
            counter = 0
            for node in convert_ref.graph.nodes:
                if 'equalization_scale' in node.name and node.op == 'get_attr':
                    self.assertEqual(convert_ref.get_buffer(str(node.target)).reshape(-1), exp_eq_scales[counter])
                    counter += 1

    def get_expected_weights_bias(self, model, x, exp_eq_scales):
        """ For each module in the graph, we want to calculate the expected
        scaled weight and bias values. This only works for models containing
        single or connected linear layers.

        ``x`` is not used by the computation; the parameter is kept so that
        existing callers keep working.

        Returns a tuple ``(exp_weights, exp_bias)`` of lists with one entry
        per child module of ``model``.
        """
        exp_weights = []
        exp_bias = []
        for i, (_, module) in enumerate(model.named_children()):
            weight = module.weight.detach().numpy()
            bias = module.bias.detach().numpy()

            # Undo the scaling that was applied to this layer's input ...
            scaled_weight = weight * np.reciprocal(exp_eq_scales[i])
            scaled_bias = bias
            if i + 1 < len(exp_eq_scales):
                # ... and fold the next layer's equalization scale into this
                # layer's output (rows of the weight, and the bias).
                scaled_weight = (scaled_weight.T * exp_eq_scales[i + 1]).T
                scaled_bias = (scaled_bias.T * exp_eq_scales[i + 1]).T

            exp_weights.append(scaled_weight)
            exp_bias.append(scaled_bias)

        return exp_weights, exp_bias

    def test_input_weight_equalization_weights_bias(self):
        """ After applying the equalization functions check if the weights and
        biases are as expected
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel,
                 SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
            exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)

            example_inputs = (x,)
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            modules = dict(convert_ref.named_modules(remove_duplicate=False))
            counter = 0
            for node in convert_ref.graph.nodes:
                if node.op == 'call_module' and isinstance(modules[str(node.target)], nn.Linear):
                    self.assertEqual(modules[str(node.target)].weight, exp_weights[counter])
                    self.assertEqual(modules[str(node.target)].bias, exp_bias[counter])
                    counter += 1

    def get_expected_inp_act_vals(self, model, x, exp_eq_scales, exp_weights, exp_bias):
        """ For each module in the graph, we want to calculate the expected
        min/max values for every input activation node. This only works for
        models containing only single or connected linear layers.

        Returns a list of ``(min, max)`` pairs: one for the input of every
        child module of ``model``, plus one for the final output.
        """
        x = x * exp_eq_scales[0]

        exp_inp_activation_vals = []
        for i, _ in enumerate(model.named_children()):
            exp_inp_activation_vals.append((x.min(), x.max()))
            x = x @ exp_weights[i].T + exp_bias[i]

        exp_inp_activation_vals.append((x.min(), x.max()))
        return exp_inp_activation_vals

    def get_expected_weight_act_vals(self, exp_weights):
        """ For each module in the graph, we want to calculate the expected
        min/max values for every weight activation node. This is assuming that
        the weight observers are all MinMaxObservers.
        """

        exp_weight_activation_vals = []
        for w in exp_weights:
            exp_weight_activation_vals.append((w.min(), w.max()))

        return exp_weight_activation_vals

    def test_input_weight_equalization_activation_values(self):
        """ After applying the equalization functions check if the input
        observer's min/max values are as expected
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel, SingleLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        # NOTE(review): the seed is set after `x` has been drawn, so it only
        # affects what is created below (e.g. model initialisation); `x`
        # itself is not reproducible across runs. Confirm whether the seed
        # was meant to precede the `torch.rand` call.
        torch.manual_seed(0)
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
            exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)
            exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
            exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)

            example_inputs = (x,)
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            modules = dict(convert_ref.named_modules(remove_duplicate=False))
            inp_counter = 0
            weight_counter = 0
            for node in convert_ref.graph.nodes:
                users = list(node.users)
                if node.op == 'call_module' and isinstance(modules[str(node.target)], MinMaxObserver):
                    # An observer whose only user is F.linear and which feeds
                    # its second argument is observing a weight.
                    if len(users) == 1 and users[0].target == torch.nn.functional.linear and users[0].args[1] == node:
                        # Check min/max values of weight activation layers
                        exp_min_val, exp_max_val = exp_weight_act_vals[weight_counter]
                        self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                        self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
                        weight_counter += 1
                    else:
                        # Check min/max values of input activation layers
                        exp_min_val, exp_max_val = exp_inp_act_vals[inp_counter]
                        self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                        self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
                        inp_counter += 1

    def check_orig_and_eq_graphs(self, orig_model, eq_model):
        """ Given a non-equalized model and an equalized model, check that the
        graphs are structured in the same way, except the equalized model has
        additional 'equalization_scale' and 'mul' nodes.

        Returns True when every compared pair of nodes matches, False on the
        first mismatch. The walk stops as soon as either graph is exhausted.
        """
        orig_idx = 0
        orig_nodes = list(orig_model.graph.nodes)
        orig_modules = dict(orig_model.named_modules(remove_duplicate=False))

        eq_idx = 0
        eq_nodes = list(eq_model.graph.nodes)
        eq_modules = dict(eq_model.named_modules(remove_duplicate=False))

        while orig_idx < len(orig_nodes) and eq_idx < len(eq_nodes):
            # Only look ahead when a following node exists, so that a trailing
            # 'equalization_scale' node cannot index past the end of the list.
            if ('equalization_scale' in eq_nodes[eq_idx].name
                    and eq_idx + 1 < len(eq_nodes)
                    and 'mul' in eq_nodes[eq_idx + 1].name):
                # Skip the equalization and mul nodes
                eq_idx += 2
                continue
            elif orig_nodes[orig_idx].op != eq_nodes[eq_idx].op:
                return False
            elif orig_nodes[orig_idx].op == 'call_module':
                # Check that the type of call_modules are the same (ex. nn.Linear, MinMaxObserver)
                orig_node = orig_nodes[orig_idx]
                eq_node = eq_nodes[eq_idx]
                if type(orig_modules[orig_node.target]) is not type(eq_modules[eq_node.target]):
                    return False
            elif orig_nodes[orig_idx].op == 'call_function':
                # Check that the call_functions are the same (ex. F.linear)
                orig_node = orig_nodes[orig_idx]
                eq_node = eq_nodes[eq_idx]
                if orig_node.target != eq_node.target:
                    return False

            eq_idx += 1
            orig_idx += 1

        return True

    @skipIfNoFBGEMM
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        # Expected node order of the equalized + quantized graph per model
        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [(SingleLayerLinearModel, linear_node_list),
                 (LinearAddModel, linearAdd_node_list),
                 (TwoLayerLinearModel, linear2_node_list),
                 (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
                 (FunctionalLinearAddModel, functionalLinearAdd_node_list),
                 (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
                 (LinearReluModel, linearRelu_node_list),
                 (LinearReluLinearModel, linearReluLinear_node_list),
                 (FunctionalLinearReluModel, functionalLinearRelu_node_list),
                 (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
                 (ConvModel, conv_node_list),
                 (TwoLayerConvModel, conv2_node_list),
                 (SingleLayerFunctionalConvModel, functionalConv_node_list),
                 (TwoLayerFunctionalConvModel, functionalConv2_node_list),
                 (ConvReluModel, convRelu_node_list),
                 (ConvReluConvModel, convReluConv_node_list),
                 (FunctionalConvReluModel, functionalConvRelu_node_list),
                 (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]

        for (M, node_list) in tests:
            m = M().eval()
            example_inputs = m.get_example_inputs()
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=node_list)

    @skipIfNoFBGEMM
    def test_input_weight_equalization_results(self):
        """ Tests that for small models, the results of quantized models that
        have been equalized are very close to models that have not been equalized.
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel, LinearAddModel,
                 SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()

            # No equalization
            example_inputs = (x,)
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config={})
            prepared(x)
            quantized = convert_fx(prepared)  # Check if compile
            quantized_output = quantized(x)

            # With equalization
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict
            )
            prepared(x)
            equalized_and_quantized = convert_fx(prepared)  # Check if compile
            equalized_and_quantized_output = equalized_and_quantized(x)
            self.assertEqual(quantized_output, equalized_and_quantized_output, rtol=1e-5, atol=0.1)

    @skipIfNoFBGEMM
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_config equalizing only the layers
        with the highest quantization errors (a single layer in this test).
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        example_inputs = (x,)
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            example_inputs=example_inputs
        )
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        # Only the second linear layer is expected to be preceded by the
        # equalization `mul`.
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)