# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx._equalize import (
    _InputEqualizationObserver,
    _WeightEqualizationObserver,
    calculate_equalization_scale,
    default_equalization_qconfig,
    _convert_equalization_ref,
    get_layer_sqnr_dict,
    get_equalization_qconfig_dict,
)

from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    LinearAddModel,
    SingleLayerFunctionalLinearModel,
    TwoLayerFunctionalLinearModel,
    FunctionalLinearAddModel,
    ConvModel,
    TwoLayerConvModel,
    SingleLayerFunctionalConvModel,
    TwoLayerFunctionalConvModel,
    skipIfNoFBGEMM,
    LinearReluModel,
    LinearReluLinearModel,
    LinearReluAddModel,
    FunctionalLinearReluModel,
    FunctionalLinearReluLinearModel,
    ConvReluModel,
    ConvReluConvModel,
    ConvReluAddModel,
    FunctionalConvReluModel,
    FunctionalConvReluConvModel,
)

# Standard Libraries
import copy
import numpy as np

# Testing utils
from hypothesis import given
from hypothesis import strategies as st


default_qconfig_dict = {"": default_qconfig}

specific_qconfig_dict = {
    "": None,
    "object_type": [(nn.Linear, default_qconfig),
                    (F.linear, default_qconfig),
                    (nn.ReLU, default_qconfig),
                    (F.relu, default_qconfig),
                    (nn.Conv2d, default_qconfig),
                    (F.conv2d, default_qconfig)]
}

default_equalization_qconfig_dict = {
    "": None,
    "object_type": [(nn.Linear, default_equalization_qconfig),
                    (F.linear, default_equalization_qconfig),
                    (nn.ReLU, default_equalization_qconfig),
                    (F.relu, default_equalization_qconfig),
                    (nn.Conv2d, default_equalization_qconfig),
                    (F.conv2d, default_equalization_qconfig)]
}
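
# In the qconfig dicts above, the "" key sets the global qconfig (None here,
# so nothing is quantized unless matched by a more specific entry), and
# "object_type" maps module classes and functionals to the qconfig applied to
# matching nodes.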


class TestEqualizeFx(QuantizationTestCase):
    def channel_minmax(self, input, axis=1):
        ''' Finds the min/max of inputs associated with a specific channel
        '''
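        # For example, a (N, C, H, W) input with the default axis=1 is reduced
        # over axes 3, 2, and 0 in turn, leaving one min and one max per channel.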
        size_of_tensor_dim = input.ndim
        axis_list = list(range(size_of_tensor_dim))
        axis_list.remove(axis)
        axis_list.sort(reverse=True)

        mins = input.copy()
        maxs = input.copy()
        for a in axis_list:
            mins = mins.min(a)
            maxs = maxs.max(a)

        return (mins, maxs)

    @given(ndim=st.sampled_from((2, 3, 4, 5)),
           input_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           input_qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           weight_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           weight_qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric,
                                           torch.per_channel_affine_float_qparams)))
    def test_input_weight_eq_observer(self, ndim, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
        sizes = []
        for _ in range((ndim - 1) * 2):
            sizes.append(np.random.randint(2, 10))

        channel = np.random.randint(1, 10)
        if ndim == 2:
            x = np.random.random(size=(sizes[0], channel))
            w = np.random.random(size=(sizes[1], channel))
        elif ndim == 3:
            x = np.random.random(size=(sizes[0], channel, sizes[1]))
            w = np.random.random(size=(sizes[2], channel, sizes[3]))
        elif ndim == 4:
            x = np.random.random(size=(sizes[0], channel, sizes[1], sizes[2]))
            w = np.random.random(size=(sizes[3], channel, sizes[4], sizes[5]))
        elif ndim == 5:
            x = np.random.random(size=(sizes[0], channel, sizes[1], sizes[2], sizes[3]))
            w = np.random.random(size=(sizes[4], channel, sizes[5], sizes[6], sizes[7]))

        x = (x * 10).round(decimals=2).astype(np.float32)
        w = (w * 10).round(decimals=2).astype(np.float32)

        input_eq_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
        weight_eq_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)

        ret_x = input_eq_obs(torch.tensor(x))
        ret_w = weight_eq_obs(torch.tensor(w))
        self.assertEqual((ret_x, ret_w), (x, w))

        # Check the min/max input columns are correct
        ref_min_inputs, ref_max_inputs = self.channel_minmax(x)
        min_inputs, max_inputs = input_eq_obs.get_input_minmax()
        self.assertEqual(min_inputs, torch.tensor(ref_min_inputs, dtype=torch.float32))
        self.assertEqual(max_inputs, torch.tensor(ref_max_inputs, dtype=torch.float32))

        # Check the min/max weight columns are correct
        ref_min_weights_col, ref_max_weights_col = self.channel_minmax(w)
        min_weights_col, max_weights_col = weight_eq_obs.get_weight_col_minmax()
        self.assertEqual(min_weights_col, torch.tensor(ref_min_weights_col, dtype=torch.float32))
        self.assertEqual(max_weights_col, torch.tensor(ref_max_weights_col, dtype=torch.float32))

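        # The equalization scale for channel c is
        #     S_c = sqrt((w_max_c - w_min_c) / (x_max_c - x_min_c)),
        # chosen so that the scaled input range (x * S_c) matches the scaled
        # weight range (w / S_c).  For example, an input range of (0, 2) and a
        # weight column range of (0, 8) give S_c = sqrt(8 / 2) = 2, and both
        # scaled ranges become (0, 4).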
        # Check the equalization scale is correct
        equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
        ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
                                         (ref_max_inputs - ref_min_inputs))
        self.assertEqual(equalization_scale, torch.tensor(ref_equalization_scale, dtype=torch.float32))

        input_eq_obs.set_equalization_scale(equalization_scale)
        weight_eq_obs.set_equalization_scale(equalization_scale)

        # Check the input scale/zero-point values
        min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
        input_quant_obs = MinMaxObserver(dtype=input_qdtype, qscheme=input_qscheme)
        input_quant_obs.min_val = min_input_scaled
        input_quant_obs.max_val = max_input_scaled
        input_qparams = input_quant_obs.calculate_qparams()

        ref_min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
        ref_min_input_scaled = min(0, ref_min_input_scaled)
        ref_max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
        ref_max_input_scaled = max(0, ref_max_input_scaled)
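        # The clamping above mirrors MinMaxObserver, which forces
        # min <= 0 <= max so that zero is exactly representable after
        # quantization.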

        if input_qscheme == torch.per_tensor_symmetric:
            ref_scale = 2 * max(abs(ref_min_input_scaled), ref_max_input_scaled) / 255
            ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
        else:
            ref_scale = (ref_max_input_scaled - ref_min_input_scaled) / 255
            quant_min = -128 if input_qdtype is torch.qint8 else 0
            quant_max = 127 if input_qdtype is torch.qint8 else 255
            ref_zero_point = quant_min - np.round(ref_min_input_scaled / ref_scale)
            ref_zero_point = np.clip(ref_zero_point, quant_min, quant_max)

        self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
        self.assertEqual(input_qparams[1].item(), ref_zero_point)

        # During input-weight equalization, we will scale the weights so that
        # the subsequent weight quantization observer will have the correct
        # scaled qparams.
        # Check the weight scale/zero-point values of the quantized observer
        weight_quant_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=weight_qdtype, qscheme=weight_qscheme)

        # Scale the weights for input-weight equalization
        new_shape = [1] * w.ndim
        new_shape[1] = w.shape[1]
        ref_w_scaled = w * np.reciprocal(ref_equalization_scale.reshape(tuple(new_shape)))

        w = torch.tensor(w)
        new_shape[1] = w.size(1)
        w_scaled = torch.mul(w, torch.reciprocal(equalization_scale.view(new_shape)))

        self.assertEqual(w_scaled, ref_w_scaled)
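        # new_shape lets the per-channel scale broadcast along the channel axis
        # (dim 1); e.g. for a 4-D weight of shape (out_ch, in_ch, kh, kw) the
        # scale is viewed as (1, in_ch, 1, 1).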

        # Call forward on the weight quantization observer
        weight_quant_obs(w_scaled)

        # Check the min/max weight rows are correct
        ref_min_weights_scaled, ref_max_weights_scaled = self.channel_minmax(ref_w_scaled)
        self.assertEqual(weight_quant_obs.min_val, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
        self.assertEqual(weight_quant_obs.max_val, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))

        weight_qparams = weight_quant_obs.calculate_qparams()

        if weight_qscheme == torch.per_channel_symmetric:
            ref_min_weights_scaled = np.minimum(np.zeros(ref_min_weights_scaled.shape), ref_min_weights_scaled)
            ref_max_weights_scaled = np.maximum(np.zeros(ref_max_weights_scaled.shape), ref_max_weights_scaled)

            ref_scales = 2 * np.maximum(np.abs(ref_min_weights_scaled), ref_max_weights_scaled) / 255
            ref_zero_points = np.zeros_like(
                ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
        elif weight_qscheme == torch.per_channel_affine_float_qparams:
            ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
            ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
            ref_zero_points = -1 * ref_min_weights_scaled / ref_scales
        else:
            ref_min_weights_scaled = np.minimum(np.zeros_like(ref_min_weights_scaled), ref_min_weights_scaled)
            ref_max_weights_scaled = np.maximum(np.zeros_like(ref_max_weights_scaled), ref_max_weights_scaled)

            ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
            ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
            ref_zero_points = ref_zero_points - np.round(ref_min_weights_scaled / ref_scales)

        self.assertEqual(weight_qparams[0], torch.tensor(
            ref_scales, dtype=weight_qparams[0].dtype), rtol=1e-5, atol=0.0001)
        self.assertEqual(weight_qparams[1], torch.tensor(
            ref_zero_points, dtype=weight_qparams[1].dtype), rtol=1e-5, atol=1)

    def test_input_weight_equalization_prepare(self):
        """ Tests that the graphs created after prepare_fx are as expected
        """

        single_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        two_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 3,
        }

        single_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(_WeightEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 3,
        }

        two_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 5,
        }

        fp_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 6,
        }
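
        # The functional variants see extra observers because their weights are
        # ordinary graph inputs: prepare_fx inserts an explicit
        # _WeightEqualizationObserver (plus a MinMaxObserver) for each
        # functional weight, while the weights of module layers are handled
        # inside the modules themselves and add no graph nodes.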

        tests = [(SingleLayerLinearModel, single_nn_layer_node_occurrence),
                 (TwoLayerLinearModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalLinearModel, two_F_layer_node_occurrence),
                 (FunctionalLinearAddModel, fp_F_layer_node_occurrence),
                 (LinearReluModel, single_nn_layer_node_occurrence),
                 (LinearReluLinearModel, two_nn_layer_node_occurrence),
                 (FunctionalLinearReluModel, single_F_layer_node_occurrence),
                 (FunctionalLinearReluLinearModel, two_F_layer_node_occurrence),
                 (ConvModel, single_nn_layer_node_occurrence),
                 (TwoLayerConvModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalConvModel, two_F_layer_node_occurrence),
                 (ConvReluModel, single_nn_layer_node_occurrence),
                 (ConvReluConvModel, two_nn_layer_node_occurrence),
                 (FunctionalConvReluModel, single_F_layer_node_occurrence),
                 (FunctionalConvReluConvModel, two_F_layer_node_occurrence)]

        for (M, node_occurrence) in tests:
            m = M().eval()
            example_inputs = m.get_example_inputs()
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)

    def test_input_weight_equalization_branching(self):
        """ Tests that graphs containing branches are prepared correctly.
        Specifically, equalization observers should not be inserted in front of
        a branch point when the initial layers of both branches are to be
        quantized.
        """

        # Tests that we do not add an equalization observer because both initial
        # nodes in the branch contain layers that need to be equalized.
        # Note that this should print out 2 warning messages for not being able
        # to equalize layers linear1 and linear2 because they are part of a branch
        class TestBranchingWithoutEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)
                self.linear2 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = self.linear2(x)
                return torch.add(y, z)

        no_eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 0,
            ns.call_module(MinMaxObserver): 3,
        }

        m = TestBranchingWithoutEqualizationModel().eval()
        example_inputs = (torch.rand(1, 5),)
        prepared = prepare_fx(
            m, specific_qconfig_dict, example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

        # Tests that we will add an equalization observer because there is only
        # one initial node in the branch that needs to be equalized
        class TestBranchingWithEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = torch.add(x, 5)
                return torch.add(y, z)

        eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        m = TestBranchingWithEqualizationModel().eval()
        example_inputs = (torch.randn(1, 5),)
        prepared = prepare_fx(
            m, specific_qconfig_dict, example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)

    @skipIfNoFBGEMM
    def test_input_weight_equalization_convert(self):
        """ Tests that the modified model for equalization (before quantization)
        returns the same output as the original model
        """

        tests = [(SingleLayerLinearModel, 2), (LinearAddModel, 2), (TwoLayerLinearModel, 2),
                 (SingleLayerFunctionalLinearModel, 2), (FunctionalLinearAddModel, 2),
                 (TwoLayerFunctionalLinearModel, 2),
                 (LinearReluModel, 2), (LinearReluLinearModel, 2), (LinearReluAddModel, 2),
                 (FunctionalLinearReluModel, 2), (FunctionalLinearReluLinearModel, 2),
                 (ConvModel, 4), (TwoLayerConvModel, 4), (SingleLayerFunctionalConvModel, 4),
                 (TwoLayerFunctionalConvModel, 4),
                 (ConvReluModel, 4), (ConvReluConvModel, 4), (ConvReluAddModel, 4),
                 (FunctionalConvReluModel, 4), (FunctionalConvReluConvModel, 4)]

        for (M, ndim) in tests:
            m = M().eval()

            if ndim == 2:
                x = torch.rand((5, 5))
            elif ndim == 4:
                x = torch.rand((16, 3, 224, 224))

            example_inputs = (x,)
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict
            )
            output = prepared(x)

            convert_ref = _convert_equalization_ref(prepared)
            convert_ref_output = convert_ref(x)

            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_fx(prepared)  # Check that conversion succeeds
            self.assertEqual(output, convert_ref_output)

    def calculate_equalization_scale_ref(self, x, w):
        """ Calculates the equalization scale based on the input and weight
        """
        min_inputs = x.min(axis=0)
        max_inputs = x.max(axis=0)

        min_weights_col = w.min(axis=0)
        max_weights_col = w.max(axis=0)

        equalization_scale = np.sqrt((max_weights_col - min_weights_col) /
                                     (max_inputs - min_inputs))
        return equalization_scale

    def get_expected_eq_scales(self, model, x):
        """ For each module in the graph, we want to calculate the equalization
        scale at that point. This only works for models containing single or
        connected linear layers.
        """
        exp_eq_scales = []
        for _, module in model.named_children():
            weight = module.weight.detach().numpy()
            bias = module.bias.detach().numpy()

            eq_scale = self.calculate_equalization_scale_ref(x, weight)
            exp_eq_scales.append(eq_scale)

            x = x @ weight.T + bias

        return exp_eq_scales

    def test_input_weight_equalization_equalization_scales(self):
        """ After applying the equalization functions, check if the equalization
        scales are the expected values
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel,
                 SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

            example_inputs = (x,)
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(*example_inputs)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            counter = 0
            for node in convert_ref.graph.nodes:
                if 'equalization_scale' in node.name and node.op == 'get_attr':
                    self.assertEqual(convert_ref.get_buffer(str(node.target)).reshape(-1), exp_eq_scales[counter])
                    counter += 1

    def get_expected_weights_bias(self, model, x, exp_eq_scales):
        """ For each module in the graph, we want to calculate the expected
        scaled weight and bias values. This only works for models containing
        single or connected linear layers.
        """
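        # A sketch of the algebra: equalization rewrites y = W x + b as
        # y = (W / S_i) (S_i x) + b, so each weight column is divided by this
        # layer's scale.  When the next layer is also equalized, its input
        # scaling S_{i+1} is folded into this layer rather than emitted as a
        # separate mul, so the weight rows and the bias are additionally
        # multiplied by S_{i+1}.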
        exp_weights = []
        exp_bias = []
        for i, (_, module) in enumerate(model.named_children()):
            weight = module.weight.detach().numpy()
            bias = module.bias.detach().numpy()

            scaled_weight = weight * np.reciprocal(exp_eq_scales[i])
            scaled_bias = bias
            if i + 1 < len(exp_eq_scales):
                scaled_weight = (scaled_weight.T * exp_eq_scales[i + 1]).T
                scaled_bias = (scaled_bias.T * exp_eq_scales[i + 1]).T

            exp_weights.append(scaled_weight)
            exp_bias.append(scaled_bias)

            x = x @ weight.T + bias

        return exp_weights, exp_bias

    def test_input_weight_equalization_weights_bias(self):
        """ After applying the equalization functions, check if the weights and
        biases are as expected
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel,
                 SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
            exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)

            example_inputs = (x,)
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            modules = dict(convert_ref.named_modules(remove_duplicate=False))
            counter = 0
            for node in convert_ref.graph.nodes:
                if node.op == 'call_module' and isinstance(modules[str(node.target)], nn.Linear):
                    self.assertEqual(modules[str(node.target)].weight, exp_weights[counter])
                    self.assertEqual(modules[str(node.target)].bias, exp_bias[counter])
                    counter += 1

    def get_expected_inp_act_vals(self, model, x, exp_eq_scales, exp_weights, exp_bias):
        """ For each module in the graph, we want to calculate the expected
        min/max values for every input activation node. This only works for
        models containing single or connected linear layers.
        """
        x = x * exp_eq_scales[0]

        exp_inp_activation_vals = []
        for i, _ in enumerate(model.named_children()):
            exp_inp_activation_vals.append((x.min(), x.max()))
            x = x @ exp_weights[i].T + exp_bias[i]

        exp_inp_activation_vals.append((x.min(), x.max()))
        return exp_inp_activation_vals

    def get_expected_weight_act_vals(self, exp_weights):
        """ For each module in the graph, we want to calculate the expected
        min/max values for every weight activation node. This assumes that
        the weight observers are all MinMaxObservers.
        """

        exp_weight_activation_vals = []
        for w in exp_weights:
            exp_weight_activation_vals.append((w.min(), w.max()))

        return exp_weight_activation_vals

    def test_input_weight_equalization_activation_values(self):
        """ After applying the equalization functions, check if the input
        observer's min/max values are as expected
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel, SingleLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        torch.manual_seed(0)
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
            exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)
            exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
            exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)

            example_inputs = (x,)
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            modules = dict(convert_ref.named_modules(remove_duplicate=False))
            inp_counter = 0
            weight_counter = 0
            for node in convert_ref.graph.nodes:
                users = list(node.users)
                if node.op == 'call_module' and isinstance(modules[str(node.target)], MinMaxObserver):
                    if len(users) == 1 and users[0].target == torch.nn.functional.linear and users[0].args[1] == node:
                        # Check min/max values of weight activation layers
                        exp_min_val, exp_max_val = exp_weight_act_vals[weight_counter]
                        self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                        self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
                        weight_counter += 1
                    else:
                        # Check min/max values of input activation layers
                        exp_min_val, exp_max_val = exp_inp_act_vals[inp_counter]
                        self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                        self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
                        inp_counter += 1


    def check_orig_and_eq_graphs(self, orig_model, eq_model):
        """ Given a non-equalized model and an equalized model, check that the
        graphs are structured in the same way, except the equalized model has
        additional 'equalization_scale' and 'mul' nodes.
        """
        orig_idx = 0
        orig_nodes = list(orig_model.graph.nodes)
        orig_modules = dict(orig_model.named_modules(remove_duplicate=False))

        eq_idx = 0
        eq_nodes = list(eq_model.graph.nodes)
        eq_modules = dict(eq_model.named_modules(remove_duplicate=False))

        while orig_idx < len(orig_nodes) and eq_idx < len(eq_nodes):
            if 'equalization_scale' in eq_nodes[eq_idx].name and 'mul' in eq_nodes[eq_idx + 1].name:
                # Skip the equalization and mul nodes
                eq_idx += 2
                continue
            elif orig_nodes[orig_idx].op != eq_nodes[eq_idx].op:
                return False
            elif orig_nodes[orig_idx].op == 'call_module':
                # Check that the types of the call_modules are the same (ex. nn.Linear, MinMaxObserver)
                orig_node = orig_nodes[orig_idx]
                eq_node = eq_nodes[eq_idx]
                if type(orig_modules[orig_node.target]) is not type(eq_modules[eq_node.target]):
                    return False
            elif orig_nodes[orig_idx].op == 'call_function':
                # Check that the call_functions are the same (ex. F.linear)
                orig_node = orig_nodes[orig_idx]
                eq_node = eq_nodes[eq_idx]
                if orig_node.target != eq_node.target:
                    return False

            eq_idx += 1
            orig_idx += 1

        return True

    @skipIfNoFBGEMM
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

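        # Each equalized pattern below is expected to begin with a torch.mul
        # that applies the input equalization scale before quantize_per_tensor.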
        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]
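
        # Note that back-to-back layers share a single input mul: the second
        # layer's equalization scale is folded into the first layer's weights
        # (as sketched in get_expected_weights_bias above), so no intermediate
        # mul node appears between them.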

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [(SingleLayerLinearModel, linear_node_list),
                 (LinearAddModel, linearAdd_node_list),
                 (TwoLayerLinearModel, linear2_node_list),
                 (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
                 (FunctionalLinearAddModel, functionalLinearAdd_node_list),
                 (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
                 (LinearReluModel, linearRelu_node_list),
                 (LinearReluLinearModel, linearReluLinear_node_list),
                 (FunctionalLinearReluModel, functionalLinearRelu_node_list),
                 (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
                 (ConvModel, conv_node_list),
                 (TwoLayerConvModel, conv2_node_list),
                 (SingleLayerFunctionalConvModel, functionalConv_node_list),
                 (TwoLayerFunctionalConvModel, functionalConv2_node_list),
                 (ConvReluModel, convRelu_node_list),
                 (ConvReluConvModel, convReluConv_node_list),
                 (FunctionalConvReluModel, functionalConvRelu_node_list),
                 (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]

        for (M, node_list) in tests:
            m = M().eval()
            example_inputs = m.get_example_inputs()
            prepared = prepare_fx(
                m, specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=node_list)

    @skipIfNoFBGEMM
    def test_input_weight_equalization_results(self):
        """ Tests that for small models, the results of quantized models that
        have been equalized are very close to those of models that have not
        been equalized.
        """

        tests = [SingleLayerLinearModel, TwoLayerLinearModel, LinearAddModel,
                 SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()

            # No equalization
            example_inputs = (x,)
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config={})
            prepared(x)
            quantized = convert_fx(prepared)  # Check that conversion succeeds
            quantized_output = quantized(x)

            # With equalization
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict
            )
            prepared(x)
            equalized_and_quantized = convert_fx(prepared)  # Check that conversion succeeds
            equalized_and_quantized_output = equalized_and_quantized(x)
            self.assertEqual(quantized_output, equalized_and_quantized_output, rtol=1e-5, atol=0.1)

    @skipIfNoFBGEMM
    def test_selective_equalization(self):
        """ Tests that we are able to run the numeric suite on the equalized
        model and construct a valid equalization_config that equalizes only
        the layer with the highest quantization error.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        example_inputs = (x,)
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            example_inputs=example_inputs
        )
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)
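        # Lower SQNR means a larger quantization error, so passing 1 here
        # selects the single lowest-SQNR layer (the top layer) for equalization.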

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
897