# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.quantization
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd

from torch.ao.quantization import (
    get_default_static_quant_module_mappings,
    default_float_qparams_observer,
    PerChannelMinMaxObserver,
)
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    prepare_dynamic,
    _make_conv_test_input,
    skipIfNoFBGEMM,
    lengths_to_offsets,
    skipIfNoONEDNN,
    _make_conv_add_extra_input_tensor,
)
from torch.testing._internal.common_quantized import (
    _calculate_dynamic_qparams,
    override_quantized_engine,
    override_qengines,
    qengine_is_qnnpack,
    qengine_is_onednn,
)
import torch.fx
from hypothesis import assume, given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
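# Hypothesis deadlines are disabled for this file: the quantized kernels exercised
# by the @given tests below can be slow enough to trip the default per-example deadline.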
hu.assert_deadline_disabled()

import copy
import io
import numpy as np
import itertools

45 """
46 Note that tests in this file are just API test, to make sure we wrapped the
47 quantized operator implementations correctly in the user facing APIs, these are
48 not correctness test for the underlying quantized operators. For correctness
49 test please see `test/quantization/test_quantized_op.py`.
50 """

class TestStaticQuantizedModule(QuantizationTestCase):
    def test_relu(self):
        relu_module = nn.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.modules.ReLU6()(x)

        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.qint32)
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref, qy.dequantize(),
                         msg="ReLU module API failed")
        self.assertEqual(y6_ref, qy6.dequantize(),
                         msg="ReLU6 module API failed")

    @override_qengines
    def test_linear(self):
        """test API functionality for nn.quantized.Linear"""
        options = itertools.product(
            [1, 5],
            [16, 32],
            [4, 8],
            [True, False],
            [True, False])
        for (batch_size, in_features, out_features, use_bias, per_channel) in options:
            self._test_linear_api_impl(
                nnq.Linear, 'QuantizedLinear', torch.ops.quantized.linear, batch_size,
                in_features, out_features, use_bias, per_channel)

    @override_qengines
    def test_linear_relu(self):
        """test API functionality for nn.intrinsic.quantized.LinearReLU"""
        options = itertools.product(
            [1, 5],
            [16, 32],
            [4, 8],
            [True, False],
            [True, False])
        for (batch_size, in_features, out_features, use_bias, per_channel) in options:
            self._test_linear_api_impl(
                nniq.LinearReLU, 'QuantizedLinearReLU', torch.ops.quantized.linear_relu,
                batch_size, in_features, out_features, use_bias, per_channel)

    def _test_linear_api_impl(self, qlinear_module, module_name, qlinear_op,
                              batch_size, in_features, out_features, use_bias,
                              per_channel, **post_ops_kwargs):
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False

        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            # ONEDNN only supports symmetric quantization of weight
            W_zp = 0 if qengine_is_onednn() else 4
            W_q = torch.quantize_per_tensor(W, 0.1, W_zp, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        qlinear = qlinear_module(in_features, out_features, **post_ops_kwargs)

        qlinear_copy = copy.deepcopy(qlinear)
        # Set random quantized weight and bias before testing scriptability
        qlinear_copy.set_weight_bias(W_q, B)
        self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

        # testing packed param implementation
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)

        # Check if the module implementation matches calling the
        # ops directly
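        # (_packed_params._packed_params is the backend-specific prepacked
        # weight/bias object that the quantized linear op consumes directly.)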
        W_pack = qlinear._packed_params._packed_params
        Z_ref = qlinear_op(X_q, W_pack, scale, zero_point, **post_ops_kwargs)

        self.assertEqual(Z_ref, Z_q)
        self.assertTrue(module_name in str(qlinear))

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        for weights_only in [True, False]:
            b.seek(0)
            loaded_dict = torch.load(b, weights_only=weights_only)
            for key in model_dict:
                if isinstance(model_dict[key], torch._C.ScriptObject):
                    assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                    w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                    w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                    self.assertEqual(w_model, w_loaded)
                    self.assertEqual(b_model, b_loaded)
                else:
                    self.assertEqual(model_dict[key], loaded_dict[key])

            loaded_qlinear = qlinear_module(
                in_features, out_features, **post_ops_kwargs)
            loaded_qlinear.load_state_dict(loaded_dict)
            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                             linear_unpack(loaded_qlinear._packed_params._packed_params))
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            # scripting will add __overloads__ to __dict__, which is why we script a copy
            # to be able to do the check in the next line
            self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
            self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

        # Test serialization
        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded = torch.load(b, weights_only=False)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test torch.package
        buffer = io.BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_pickle("module", "qlinear.pkl", qlinear)
        buffer.seek(0)

        importer = PackageImporter(buffer)
        loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
        self.assertEqual(qlinear.weight(), loaded_from_package.weight())
        self.assertEqual(qlinear.scale, loaded_from_package.scale)
        self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)

        for name, module in loaded_from_package.named_modules():
            # noop, just make sure attribute "_modules" is restored correctly during torch.package import
            assert(name is not None)  # noqa: E275

        # Test copy and deepcopy
        copied_linear = copy.copy(qlinear)
        self.assertEqual(copied_linear.bias(), qlinear.bias())
        self.assertEqual(copied_linear.scale, qlinear.scale)
        self.assertEqual(copied_linear.zero_point,
                         qlinear.zero_point)
        Y_copied = copied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_linear = copy.deepcopy(qlinear)
        self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
        self.assertEqual(deepcopied_linear.scale, qlinear.scale)
        self.assertEqual(deepcopied_linear.zero_point,
                         qlinear.zero_point)
        Y_deepcopied = deepcopied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Make sure `from_float` works for all linear variants
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

        for mut in modules_under_test:
            # Test from_float.
            float_linear = mut(in_features, out_features).float()
            float_linear.qconfig = torch.ao.quantization.default_qconfig
            torch.ao.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.ao.quantization.convert(quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
        # testing Quantize API
        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)

    def _test_conv_api_impl(
            self, module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size, out_channels_per_group,
            groups, kernel_size, stride, padding, padding_mode, dilation,
            X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
            use_bias, post_op, use_channelwise, X2_scale=1.0, X2_zero_point=0):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)
        example_input = [X, ]
        example_input_q = [X_q, ]

        if post_op in ["add", "add_relu"]:
            X2, X2_q = _make_conv_add_extra_input_tensor(X2_scale, X2_zero_point, conv_module[0](X).size())
            example_input = [X, X2]
            example_input_q = [X_q, X2_q]

        # Make sure the weight shape is correct
        self.assertTrue(qconv_module.weight().shape == W_q.shape)

        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

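        # For fused reference modules (ConvReLU*/ConvAdd*), the float conv is the
        # first child of the fused container, so index into it when seeding weights.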
        raw_conv_module = conv_module[0] if post_op in ["relu", "add", "add_relu"] else conv_module
        raw_conv_module.weight.data = W
        if use_bias:
            raw_conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name())
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(*example_input)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_module(*example_input_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
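        # (Worked example under round-half-to-even: 2.5 rounds down to 2, so
        #  round(2.5) + 1 == 3, whereas 2.5 + 1 == 3.5 rounds up to 4.)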
        # skip numerics checking for reference module
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        for weights_only in [True, False]:
            bytes_io.seek(0)
            loaded_dict = torch.load(bytes_io, weights_only=weights_only)
            for key in loaded_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            loaded_qconv_module = type(qconv_module)(
                in_channels, out_channels, kernel_size, stride, padding, dilation,
                groups, use_bias, padding_mode=padding_mode)
            loaded_qconv_module.load_state_dict(loaded_dict)

            self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
            self.assertTrue(module_name == loaded_qconv_module._get_name())
            self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
            self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

            self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
            if use_bias:
                self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
            self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
            self.assertEqual(qconv_module.zero_point,
                             loaded_qconv_module.zero_point)
            Y_loaded = loaded_qconv_module(*example_input_q)
            np.testing.assert_array_almost_equal(
                Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)

        # Test serialization
        b = io.BytesIO()
        torch.save(qconv_module, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded_conv = torch.load(b, weights_only=False)

        self.assertEqual(loaded_conv.bias(), qconv_module.bias())
        self.assertEqual(loaded_conv.scale, qconv_module.scale)
        self.assertEqual(loaded_conv.zero_point,
                         qconv_module.zero_point)

        # Test copy and deepcopy
        copied_conv = copy.copy(qconv_module)
        self.assertEqual(copied_conv.bias(), qconv_module.bias())
        self.assertEqual(copied_conv.scale, qconv_module.scale)
        self.assertEqual(copied_conv.zero_point,
                         qconv_module.zero_point)
        Y_copied = copied_conv(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_conv = copy.deepcopy(qconv_module)
        self.assertEqual(deepcopied_conv.bias(), qconv_module.bias())
        self.assertEqual(deepcopied_conv.scale, qconv_module.scale)
        self.assertEqual(deepcopied_conv.zero_point,
                         qconv_module.zero_point)
        Y_deepcopied = deepcopied_conv(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # JIT testing
        self.checkScriptable(
            qconv_module, [example_input_q],
            check_save_load=True)

        class _FusedModule_two_input_args(torch.ao.nn.intrinsic._FusedModule):
            # Helper module for ConvAdd2d, since torch.ao.nn.intrinsic._FusedModule only supports one input arg
            def forward(self, x1, x2):
                input = self[0](x1, x2)
                return input

        # Test from_float
        fused_conv_module = _FusedModule_two_input_args(conv_module) \
            if post_op in ["add", "add_relu"] else torch.ao.nn.intrinsic._FusedModule(conv_module)

        fused_conv_module.qconfig = torch.ao.quantization.default_qconfig
        torch.ao.quantization.prepare(fused_conv_module, inplace=True)
        example_input[0] = example_input[0].float()
        fused_conv_module(*example_input)
        converted_qconv_module = fused_conv_module
        reference_mapping = get_default_static_quant_module_mappings()
        reference_mapping[type(conv_module)] = type(qconv_module)
        torch.ao.quantization.convert(converted_qconv_module, mapping=reference_mapping, inplace=True)

        # Smoke test to make sure the module actually runs
        if use_bias:
            self.assertEqual(conv_module[0].bias if (post_op in ["relu", "add", "add_relu"]) else conv_module.bias,
                             converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name == converted_qconv_module[0]._get_name())

    @override_qengines
    def test_conv1d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            batch_size = 2
            in_channels_per_group = 2
            length = 8
            out_channels_per_group = 2
            groups = 3
            kernel = 3
            stride = 2
            pad = 1
            dilation = 1
            # Tests the correctness of the conv1d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (length,)
            kernel_size = (kernel, )
            stride = (stride, )
            pad = (pad, )
            dilation = (dilation, )
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            if torch.backends.quantized.engine == 'qnnpack':
                use_channelwise = False
            qconv_cls = nnq.Conv1d
            module_name = "QuantizedConv1d"
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, "none", use_channelwise)

    @override_qengines
    def test_conv1d_relu_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        length = 8
        out_channels_per_group = 2
        groups = 3
        kernel = 3
        stride = 2
        pad = 1
        dilation = 1
        # Tests the correctness of the conv1d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (length,)
        kernel_size = (kernel, )
        stride = (stride, )
        pad = (pad, )
        dilation = (dilation, )
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU1d
        module_name = "QuantizedConvReLU1d"
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == 'qnnpack':
                use_channelwise = False
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU1d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, "relu", use_channelwise)

    @override_qengines
    def test_conv2d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nnq.Conv2d
            module_name = "QuantizedConv2d"
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, "none", use_channelwise)

    @override_qengines
    def test_conv2d_relu_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        stride_h = 2
        stride_w = 2
        pad_h = 1
        pad_w = 1
        dilation = 1
        # Tests the correctness of the conv2d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU2d
        module_name = "QuantizedConvReLU2d"
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU2d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, "relu", use_channelwise)

    @skipIfNoFBGEMM
    def test_conv3d_api(self):
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        D = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        kernel_d = 3
        stride_h = 2
        stride_w = 2
        stride_d = 2
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        pad_h = 1
        pad_w = 1
        pad_d = 1
        dilation = 1
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nnq.Conv3d
        module_name = "QuantizedConv3d"
        for use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            with override_quantized_engine('fbgemm'):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, "none",
                    use_channelwise)

    @skipIfNoFBGEMM
    def test_conv3d_relu_api(self):
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        D = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        kernel_d = 3
        stride_h = 2
        stride_w = 2
        stride_d = 2
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        pad_h = 1
        pad_w = 1
        pad_d = 1
        dilation = 1
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU3d
        module_name = "QuantizedConvReLU3d"
        for use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            with override_quantized_engine('fbgemm'):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU3d(conv_module, relu_module)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, "relu",
                    use_channelwise)

    @skipIfNoONEDNN
    def test_conv2d_add(self):
        """test API functionality for nn.intrinsic.quantized.ConvAdd2d"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                ["zeros", "reflect"],  # pad_mode
                [True, False],  # use_bias
                [True, False],  # use_channelwise
            )
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            X2_scale = 1.2
            X2_zero_point = 1
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nniq.ConvAdd2d
            module_name = "QuantizedConvAdd2d"
            for pad_mode, use_bias, use_channelwise in options:
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = torch.ao.nn.intrinsic.ConvAdd2d(conv_module, torch.add)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                    Y_scale, Y_zero_point, use_bias, "add", use_channelwise, X2_scale, X2_zero_point)

    @skipIfNoONEDNN
    def test_conv2d_add_relu(self):
        """test API functionality for nn.intrinsic.quantized.ConvAddReLU2d"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                ["zeros", "reflect"],  # pad_mode
                [True, False],  # use_bias
                [True, False],  # use_channelwise
            )
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            X2_scale = 1.2
            X2_zero_point = 1
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nniq.ConvAddReLU2d
            module_name = "QuantizedConvAddReLU2d"
            for pad_mode, use_bias, use_channelwise in options:
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = torch.ao.nn.intrinsic.ConvAddReLU2d(conv_module, torch.add, nn.ReLU())
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                    Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, X2_scale, X2_zero_point)

    def test_pool_api(self):
        """Tests the correctness of the pool module.
        The correctness is defined against the functional implementation.
        """
        N, C, H, W = 10, 10, 10, 3
        kwargs = {
            'kernel_size': 2,
            'stride': None,
            'padding': 0,
            'dilation': 1
        }

        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, C, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)
        qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs)

        pool_under_test = torch.ao.nn.quantized.MaxPool2d(**kwargs)
        qX_hat = pool_under_test(qX)
        self.assertEqual(qX_expect, qX_hat)

        # JIT Testing
        self.checkScriptable(pool_under_test, [[X]])

    def test_dropout(self):
        """Tests the correctness of the dropout module.
        The correctness is defined against the functional implementation.
        """
        x = torch.randn((2, 4, 6, 8), dtype=torch.float)
        float_mod = torch.nn.Dropout(p=0.5)
        float_mod.training = False

        y_ref = float_mod(x)
        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)

        quant_mod = nnq.Dropout(p=0.5)
        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
        qy = quant_mod(qx)

        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         msg="Dropout module API failed")

    def _test_dropout_serialization(self, get_model, data1, data2):
        m1 = get_model()
        m1.qconfig = torch.ao.quantization.default_qconfig
        mp1 = torch.ao.quantization.prepare(m1)
        mp1(data1)
        mq1 = torch.ao.quantization.convert(mp1)
        ref1 = mq1(data2)

        m2 = get_model()
        m2.qconfig = torch.ao.quantization.default_qconfig
        mp2 = torch.ao.quantization.prepare(m2)
        mq2 = torch.ao.quantization.convert(mp2)

        mq2.load_state_dict(mq1.state_dict())
        ref2 = mq2(data2)

        self.assertTrue(torch.allclose(ref1, ref2))

    def test_dropout_serialization(self):
        data1 = torch.randn(2, 4, 6, 8)
        data2 = torch.randn(2, 4, 6, 8)

        def _get_model():
            return nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.Dropout(p=0.5),
                torch.ao.quantization.DeQuantStub()
            ).eval()

        self._test_dropout_serialization(_get_model, data1, data2)

    def test_batch_norm2d(self):
        """Tests the correctness of the batchnorm2d module.
        The correctness is defined against the functional implementation.
        """
        x = torch.randn((2, 4, 6, 8), dtype=torch.float)
        float_mod = torch.nn.BatchNorm2d(4)
        float_mod.training = False

        y_ref = float_mod(x)
        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)

        quant_mod = nnq.BatchNorm2d(4)
        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
        qy = quant_mod(qx)

        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         msg="BatchNorm2d module API failed")

    def test_batch_norm3d(self):
        """Tests the correctness of the batchnorm3d module.
        The correctness is defined against the functional implementation.
        """
        x = torch.randn((2, 4, 6, 8, 10), dtype=torch.float)
        float_mod = torch.nn.BatchNorm3d(4)
        float_mod.training = False

        y_ref = float_mod(x)
        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)

        quant_mod = nnq.BatchNorm3d(4)
        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
        qy = quant_mod(qx)

        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         msg="BatchNorm3d module API failed")

    def _test_batch_norm_serialization(self, get_model, data1, data2):
        m1 = get_model()
        m1.qconfig = torch.ao.quantization.default_qconfig
        mp1 = torch.ao.quantization.prepare(m1)
        mp1(data1)
        mq1 = torch.ao.quantization.convert(mp1)
        ref1 = mq1(data2)

        m2 = get_model()
        m2.qconfig = torch.ao.quantization.default_qconfig
        mp2 = torch.ao.quantization.prepare(m2)
        mq2 = torch.ao.quantization.convert(mp2)

        mq2.load_state_dict(mq1.state_dict())
        ref2 = mq2(data2)

        self.assertTrue(torch.allclose(ref1, ref2))

    def test_batch_norm2d_serialization(self):
        data1 = torch.randn(2, 4, 6, 8)
        data2 = torch.randn(2, 4, 6, 8)

        def _get_model():
            return nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.BatchNorm2d(4),
                torch.ao.quantization.DeQuantStub()
            ).eval()

        self._test_batch_norm_serialization(_get_model, data1, data2)

    def test_batch_norm3d_serialization(self):
        data1 = torch.randn(2, 4, 6, 8, 1)
        data2 = torch.randn(2, 4, 6, 8, 1)

        def _get_model():
            return nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.BatchNorm3d(4),
                torch.ao.quantization.DeQuantStub()
            ).eval()

        self._test_batch_norm_serialization(_get_model, data1, data2)

    def test_layer_norm(self):
        """Tests the correctness of the layernorm module.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims = (1, 4, 8)

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = torch.nn.LayerNorm(dqX.size()[1:]).float()
        float_mod.weight = torch.nn.Parameter(torch.rand(*dims[1:]))
        float_mod.bias = torch.nn.Parameter(torch.rand(*dims[1:]))

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = nnq.LayerNorm(
            qX.size()[1:], float_mod.weight, float_mod.bias, y_scale, y_zero_point)
        qY = quant_mod(qX)

        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"LayerNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def test_group_norm(self):
        """Tests the correctness of the groupnorm module.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims = (1, 4, 8)

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = torch.nn.GroupNorm(2, 4).float()
        float_mod.weight = torch.nn.Parameter(torch.rand(dims[1]))
        float_mod.bias = torch.nn.Parameter(torch.rand(dims[1]))

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = nnq.GroupNorm(
            2, 4, float_mod.weight, float_mod.bias, y_scale, y_zero_point)
        qY = quant_mod(qX)

        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"GroupNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def test_instance_norm(self):
        """Tests the correctness of the instancenorm{n}d modules.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims_to_modules = [
            ((1, 4, 8), torch.nn.InstanceNorm1d, nnq.InstanceNorm1d),
            ((1, 4, 8, 1), torch.nn.InstanceNorm2d, nnq.InstanceNorm2d),
            ((1, 4, 8, 1, 1), torch.nn.InstanceNorm3d, nnq.InstanceNorm3d),
        ]

        for dim_to_modules in dims_to_modules:
            dims, float_cls, q_cls = dim_to_modules

            X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
            qX = torch.quantize_per_tensor(
                X, x_scale, x_zero_point, dtype=torch.quint8)
            dqX = qX.dequantize()

            float_mod = float_cls(dims[1]).float()
            float_mod.weight = torch.nn.Parameter(torch.rand(dims[1]))
            float_mod.bias = torch.nn.Parameter(torch.rand(dims[1]))

            dqY_ref = float_mod(dqX)
            qY_ref = torch.quantize_per_tensor(
                dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

            quant_mod = q_cls(
                dims[1], float_mod.weight, float_mod.bias, y_scale,
                y_zero_point)
            qY = quant_mod(qX)

            self.assertEqual(
                qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                msg=f"InstanceNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def _test_activation_module_impl(self, name, float_module_class, quantized_module_class, extra_kwargs):
        """Tests the correctness of the given activation module.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127
        alpha = 1.5

        dims = (1, 4, 8)

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = float_module_class(**extra_kwargs).float()

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = quantized_module_class(y_scale, y_zero_point, **extra_kwargs)
        qY = quant_mod(qX)
        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"{name} module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def _test_leaky_relu_serialization(self):
        scale_original = 10.0 / 256
        zero_point_original = 1.0

        quant_mod_original = nnq.LeakyReLU(scale_original, zero_point_original)
        state_dict = quant_mod_original.state_dict()

        scale_new = 5.0 / 256
        zero_point_new = 2.0
        quant_mod_new = nnq.LeakyReLU(scale_new, zero_point_new)
        quant_mod_new.load_state_dict(state_dict)

        self.assertEqual(quant_mod_original.scale, quant_mod_new.scale)
        self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point)

    def test_elu(self):
        """Tests the correctness of the ELU module.
        The correctness is defined against the functional implementation.
        """
        self._test_activation_module_impl("ELU", nn.ELU, nnq.ELU, {"alpha": 1.5})

    def test_leaky_relu(self):
        self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})
        self._test_leaky_relu_serialization()

    def test_sigmoid(self):
        self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})

    def _test_hard_swish_serialization(self):
        scale_original = 10.0 / 256
        zero_point_original = 1.0

        quant_mod_original = nnq.Hardswish(scale_original, zero_point_original)
        state_dict = quant_mod_original.state_dict()

        scale_new = 5.0 / 256
        zero_point_new = 2.0
        quant_mod_new = nnq.Hardswish(scale_new, zero_point_new)
        quant_mod_new.load_state_dict(state_dict)

        self.assertEqual(quant_mod_original.scale, quant_mod_new.scale)
        self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point)

    def test_hard_swish(self):
        self._test_activation_module_impl("Hardswish", nn.Hardswish, nnq.Hardswish, {})
        self._test_hard_swish_serialization()

    @given(
        num_embeddings=st.integers(10, 50),
        embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
        set_qconfig=st.booleans(),
    )
    @skipIfNoFBGEMM
    def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        qparams = obs.calculate_qparams()
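        # (calculate_qparams() returns (scales, zero_points); with the float-qparams
        # observer these are per-row values used for the per-channel quantization below.)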
1242 
1243         dtypes = [torch.quint4x2, torch.quint8]
1244         embedding_funcs = [torch.ops.quantized.embedding_4bit, torch.ops.quantized.embedding_byte]
1245 
1246         for dtype, embedding_func in zip(dtypes, embedding_funcs):
1247             # Quantize the weights
1248             qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
1249             qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
1250             qemb.set_weight(qweight)
1251             qemb(indices)
1252 
1253             # Ensure the module has the correct weights
1254             self.assertEqual(qweight, qemb.weight())
1255             w_packed = qemb._packed_params._packed_weight
1256             module_out = qemb(indices)
1257 
1258             # Call the quantized embedding operator (4-bit or byte) directly
1259             ref = embedding_func(w_packed, indices, pruned_weights=False)
1260             self.assertEqual(module_out, ref)
1261             self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False,
1262                                              is_emb_bag=False, dtype=dtype)
1263 
1264     @given(
1265         num_embeddings=st.integers(10, 50),
1266         embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
1267         num_offsets=st.integers(1, 20),
1268         set_qconfig=st.booleans(),
1269     )
1270     @skipIfNoFBGEMM
1271     def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
1272         r"""Test execution and serialization for quantized embedding_bag modules on int8 and int4
1273         """
1274 
1275         num_lengths = np.random.randint(1, 6)
1276         lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
1277         num_indices = np.sum(lengths)
1278         indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
1279 
1280         offsets = lengths_to_offsets(lengths)
1281         # include the last offset
1282         offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
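        # Illustrative example: lengths [2, 3] would give begin offsets [0, 2]; appending the
        # total number of indices yields [0, 2, 5], the layout expected with include_last_offset=True.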
1283         weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))
1284 
1285         for qdtype in [torch.quint8, torch.quint4x2]:
1286             obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
1287             obs(weights)
1288             # Get the scale and zero point for the weight tensor
1289             qparams = obs.calculate_qparams()
1290             # Quantize the weights to 8 bits or 4 bits, depending on qdtype
1291             qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
1292             qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
1293                                     include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype)
1294             qemb(indices, offsets)
1295 
1296             # Ensure the module has the correct weights
1297             self.assertEqual(qweight, qemb.weight())
1298 
1299             w_packed = qemb._packed_params._packed_weight
1300             module_out = qemb(indices, offsets)
1301 
1302             # Call the qembedding_bag operator directly
1303             if qdtype == torch.quint8:
1304                 ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
1305                                                              per_sample_weights=None,
1306                                                              include_last_offset=True)
1307             else:
1308                 ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0,
1309                                                              per_sample_weights=None,
1310                                                              include_last_offset=True)
1311 
1312             self.assertEqual(module_out, ref)
1313             self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
1314                                              offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)
1315 
1316     def test_prelu(self):
1317         for num_parameters in range(1, 10):
1318             x = torch.randn(4, num_parameters, 4)
1319             qx = torch.quantize_per_tensor_dynamic(x, dtype=torch.quint8, reduce_range=False)
1320 
1321 
1322             f_prelu = torch.nn.PReLU(num_parameters=num_parameters)
1323             f_prelu.weight = torch.nn.Parameter(torch.randn(num_parameters).abs())
1324             f_prelu.qconfig = torch.ao.quantization.QConfig(
1325                 activation=torch.ao.quantization.default_observer,
1326                 weight=torch.ao.quantization.default_observer,)
1327             f_prelu.activation_post_process = f_prelu.qconfig.activation()
1328             f_prelu.activation_post_process(f_prelu(x))
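            # run a forward pass through the activation observer so that from_float below
            # can derive the output scale/zero_point for the quantized module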
1329             q_prelu = nnq.PReLU.from_float(f_prelu)
1330             w_obs = f_prelu.qconfig.weight()
1331             w_obs(f_prelu.weight)
1332             w_scale, w_zp = w_obs.calculate_qparams()
1333             q_prelu_weight = torch.quantize_per_tensor(
1334                 f_prelu.weight,
1335                 dtype=torch.quint8,
1336                 scale=w_scale,
1337                 zero_point=w_zp
1338             ).dequantize()
1339 
1340             # check that the module weight matches the fake-quantized reference weight
1341             self.assertEqual(q_prelu.weight.dequantize(), q_prelu_weight)
1342             f_prelu.weight = torch.nn.Parameter(q_prelu.weight.dequantize())
1343             qy = q_prelu(qx)
1344             qy_ref = torch.quantize_per_tensor(
1345                 f_prelu(qx.dequantize()), q_prelu.scale, q_prelu.zero_point, dtype=torch.quint8
1346             )
1347             # check that the module output matches the reference within quantization tolerance
1348             self.assertEqual(qy, qy_ref, atol=.1, rtol=.1)
1349 
1350     def test_channel_shuffle(self):
1351         """Tests the correctness of the ChannelShuffle module.
1352         """
1353         x_scale = 10.0 / 256
1354         x_zero_point = 1
1355         y_scale = x_scale
1356         y_zero_point = x_zero_point
1357 
1358         dims = (1, 4, 4, 8)
1359         groups = 2
1360 
1361         X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
1362         qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
1363         dqX = qX.dequantize()
1364 
1365         float_mod = torch.nn.ChannelShuffle(groups).float()
1366         dqY_ref = float_mod(dqX)
1367         qY_ref = torch.quantize_per_tensor(
1368             dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)
1369 
1370         quant_mod = torch.nn.ChannelShuffle(groups)
1371         qY = quant_mod(qX)
1372 
1373         self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
1374                          msg=f"ChannelShuffle module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")
1375 
1376     @skipIfNoONEDNN
1377     def test_linear_leaky_relu(self):
1378         """test API functionality for nn.intrinsic.quantized.linear_leaky_relu"""
1379         with override_quantized_engine('onednn'):
1380             options = itertools.product(
1381                 [1, 5],  # batch size
1382                 [16, 32],  # in_features
1383                 [4, 8],  # out_features
1384                 [True, False],  # use_bias
1385                 [True, False],  # per_channel
1386                 [0.01, 0.05])  # negative slope
1387             for (batch_size, in_features, out_features, use_bias,
1388                  per_channel, neg_slope) in options:
1389                 self._test_linear_api_impl(
1390                     nniq.LinearLeakyReLU, 'QuantizedLinearLeakyReLU',
1391                     torch.ops.quantized.linear_leaky_relu,
1392                     batch_size, in_features, out_features, use_bias,
1393                     per_channel, negative_slope=neg_slope)
1394 
1395     @skipIfNoONEDNN
1396     def test_linear_tanh(self):
1397         """test API functionality for nn.intrinsic.quantized.linear_tanh"""
1398         with override_quantized_engine('onednn'):
1399             options = itertools.product(
1400                 [1, 5],  # batch size
1401                 [16, 32],  # in_features
1402                 [4, 8],  # out_features
1403                 [True, False],  # use_bias
1404                 [True, False])  # per_channel
1405             for (batch_size, in_features, out_features, use_bias,
1406                  per_channel) in options:
1407                 self._test_linear_api_impl(
1408                     nniq.LinearTanh, 'QuantizedLinearTanh',
1409                     torch.ops.quantized.linear_tanh,
1410                     batch_size, in_features, out_features, use_bias,
1411                     per_channel)
1412 
1413 class TestDynamicQuantizedModule(QuantizationTestCase):
1414     def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias):
1415         in_channels = 3
1416         out_channels = 10
1417         kernel_size = 2
1418         stride = 1
1419         padding = 0
1420         dilation = 1
1421         groups = 1
1422         padding_mode = 'zeros'
1423 
1424         if qengine_is_qnnpack():
1425             reduce_range = False
1426         else:
1427             reduce_range = True
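        # fbgemm/onednn typically run the dynamic op with reduce_range=True, which restricts
        # activations to a 7-bit range to guard against accumulator overflow in the int8 GEMM;
        # qnnpack uses the full 8-bit range, so reduce_range is disabled there.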
1428 
1429         X_fp32 = torch.randn(*([in_channels] * dim))
1430         s, z = _calculate_dynamic_qparams(X_fp32, dtype, reduce_range)
1431         X_q = torch.quantize_per_tensor(X_fp32, s, z, dtype)
1432         X_dq = torch.dequantize(X_q)
1433 
1434         quantized_module = q_mod(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
1435                                  dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
1436         dynamic_module = dq_mod(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
1437                                 dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
1438 
1439         quantized_module.scale, quantized_module.zero_point = s, z
1440         dynamic_module.set_weight_bias(*quantized_module._weight_bias())
1441 
1442         Y_q_ref = quantized_module(X_q)
1443         Y_ref = torch.dequantize(Y_q_ref)
1444 
1445         Y = dynamic_module(X_dq, reduce_range)
1446 
1447         self.assertEqual(Y, Y_ref)
1448 
1449         # Test serialization of quantized Conv Module using state_dict
1450         W_q, b = dynamic_module._weight_bias()
1451         model_dict = dynamic_module.state_dict()
1452         self.assertEqual(model_dict['weight'], W_q)
1453         self.assertEqual(model_dict['bias'], b)
1454         bytes_io = io.BytesIO()
1455         torch.save(model_dict, bytes_io)
1456         for weights_only in [True, False]:
1457             bytes_io.seek(0)
1458             loaded_dict = torch.load(bytes_io, weights_only=weights_only)
1459             for key in loaded_dict:
1460                 self.assertEqual(model_dict[key], loaded_dict[key])
1461             loaded_qconv_module = type(dynamic_module)(
1462                 in_channels, out_channels, kernel_size, stride=stride, padding=padding,
1463                 dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
1464             loaded_qconv_module.load_state_dict(loaded_dict)
1465 
1466             self.assertTrue(dir(loaded_qconv_module) == dir(dynamic_module))
1467             self.assertTrue(dynamic_module._get_name() == loaded_qconv_module._get_name())
1468             self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
1469             self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))
1470 
1471             self.assertEqual(dynamic_module.weight(), loaded_qconv_module.weight())
1472             if bias:
1473                 self.assertEqual(dynamic_module.bias(), loaded_qconv_module.bias())
1474             self.assertEqual(dynamic_module.scale, loaded_qconv_module.scale)
1475             self.assertEqual(dynamic_module.zero_point,
1476                              loaded_qconv_module.zero_point)
1477             Y_loaded = loaded_qconv_module(X_fp32, reduce_range)
1478             np.testing.assert_array_almost_equal(
1479                 Y.numpy(), Y_loaded.numpy(), decimal=0)
1480 
1481         # Test serialization
1482         b = io.BytesIO()
1483         torch.save(dynamic_module, b)
1484         b.seek(0)
1485         # weights_only=False because the whole module (not just a state_dict) was pickled above
1486         loaded_conv = torch.load(b, weights_only=False)
1487 
1488         self.assertEqual(loaded_conv.bias(), dynamic_module.bias())
1489         self.assertEqual(loaded_conv.scale, dynamic_module.scale)
1490         self.assertEqual(loaded_conv.zero_point,
1491                          dynamic_module.zero_point)
1492 
1493         # Test copy and deepcopy
1494         copied_conv = copy.copy(dynamic_module)
1495         self.assertEqual(copied_conv.bias(), dynamic_module.bias())
1496         self.assertEqual(copied_conv.scale, dynamic_module.scale)
1497         self.assertEqual(copied_conv.zero_point,
1498                          dynamic_module.zero_point)
1499         Y_copied = copied_conv(X_fp32, reduce_range)
1500         np.testing.assert_array_almost_equal(
1501             Y.numpy(), Y_copied.numpy(), decimal=0)
1502 
1503         deepcopied_conv = copy.deepcopy(dynamic_module)
1504         self.assertEqual(deepcopied_conv.bias(), dynamic_module.bias())
1505         self.assertEqual(deepcopied_conv.scale, dynamic_module.scale)
1506         self.assertEqual(deepcopied_conv.zero_point,
1507                          dynamic_module.zero_point)
1508         Y_deepcopied = deepcopied_conv(X_fp32, reduce_range)
1509         np.testing.assert_array_almost_equal(
1510             Y.numpy(), Y_deepcopied.numpy(), decimal=0)
1511 
1512         # TODO: the scriptability check below still needs to be cleaned up
1513         # JIT testing
1514         self.checkScriptable(
1515             dynamic_module, [[X_dq]],
1516             check_save_load=True)
1517 
1518         # Test from_float
1519         conv_module = dynamic_module._FLOAT_MODULE(in_channels, out_channels, kernel_size)
1520         conv_module.qconfig = torch.ao.quantization.default_dynamic_qconfig  # type: ignore[assignment]
1521         prepare_dynamic(conv_module)
1522         conv_module(X_dq)
1523         quantized_conv_module = dq_mod.from_float(conv_module)
1524 
1525         # Smoke test to make sure the module actually runs
1526         quantized_conv_module(X_dq)
1527 
1528         # Smoke test: from_float should produce a module with the same name as the dynamic module
1529         self.assertEqual(dynamic_module._get_name(), quantized_conv_module._get_name())
1530 
1531     @override_qengines
1532     def test_dynamic_conv1d(self):
1533         q_mod = torch.ao.nn.quantized.Conv1d
1534         dq_mod = torch.ao.nn.quantized.dynamic.Conv1d
1535         dim = 3
1536         dtype = torch.quint8
1537 
1538         for bias in [True, False]:
1539             self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1540 
1541     @override_qengines
1542     def test_dynamic_conv2d(self):
1543         q_mod = torch.ao.nn.quantized.Conv2d
1544         dq_mod = torch.ao.nn.quantized.dynamic.Conv2d
1545         dim = 4
1546         dtype = torch.quint8
1547 
1548         for bias in [True, False]:
1549             self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1550 
1551     @override_qengines
1552     def test_dynamic_conv3d(self):
1553         q_mod = torch.ao.nn.quantized.Conv3d
1554         dq_mod = torch.ao.nn.quantized.dynamic.Conv3d
1555         dim = 5
1556         dtype = torch.quint8
1557 
1558         if qengine_is_qnnpack():
1559             return  # qnnpack doesn't support unpacking conv3d
1560         for bias in [True, False]:
1561             self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1562 
1563     @override_qengines
1564     def test_dynamic_convtranspose1d(self):
1565         q_mod = torch.ao.nn.quantized.ConvTranspose1d
1566         dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose1d
1567         dim = 3
1568         dtype = torch.quint8
1569 
1570         for bias in [True, False]:
1571             self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1572 
1573     @override_qengines
1574     def test_dynamic_convtranspose2d(self):
1575         q_mod = torch.ao.nn.quantized.ConvTranspose2d
1576         dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose2d
1577         dim = 4
1578         dtype = torch.quint8
1579 
1580         for bias in [True, False]:
1581             self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1582 
1583     @override_qengines
1584     def test_dynamic_convtranspose3d(self):
1585         q_mod = torch.ao.nn.quantized.ConvTranspose3d
1586         dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose3d
1587         dim = 5
1588         dtype = torch.quint8
1589 
1590         if qengine_is_qnnpack():
1591             return  # qnnpack doesn't support unpacking conv3d
1592         for bias in [True, False]:
1593             self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1594 
1595     @given(
1596         batch_size=st.integers(1, 5),
1597         in_features=st.integers(16, 32),
1598         out_features=st.integers(4, 8),
1599         use_bias=st.booleans(),
1600         use_default_observer=st.booleans(),
1601     )
1602     @override_qengines
1603     def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
1604         """test API functionality for nn.quantized.dynamic.Linear"""
1605         W = torch.rand(out_features, in_features).float()
1606         qscheme = torch.per_tensor_symmetric if qengine_is_onednn() else torch.per_tensor_affine
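        # the onednn qengine expects symmetric per-tensor weight quantization, hence the qscheme
        # switch above; the other engines accept affine weight qparams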
1607         W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8, qscheme=qscheme)
1608         W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
1609         X = torch.rand(batch_size, in_features).float()
1610         B = torch.rand(out_features).float() if use_bias else None
1611         qlinear = nnqd.Linear(in_features, out_features)
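        # nnqd.Linear stores an int8 weight but takes and returns fp32 tensors; activation
        # qparams are computed on the fly from each input batch (dynamic quantization)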
1612         # Run module with default-initialized parameters.
1613         # This tests that the constructor is correct.
1614         qlinear.set_weight_bias(W_q, B)
1615         qlinear(X)
1616 
1617         # Simple round-trip test to ensure weight()/set_weight() API
1618         self.assertEqual(qlinear.weight(), W_q)
1619         W_pack = qlinear._packed_params._packed_params
1620         Z_dq = qlinear(X)
1621 
1622         # Check if the module implementation matches calling the
1623         # ops directly
1624         Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack, reduce_range=True)
1625         self.assertEqual(Z_ref, Z_dq)
1626 
1627         # Test serialization of dynamic quantized Linear Module using state_dict
1628         model_dict = qlinear.state_dict()
1629         b = io.BytesIO()
1630         torch.save(model_dict, b)
1631         for weights_only in [True, False]:
1632             b.seek(0)
1633             loaded_dict = torch.load(b, weights_only=weights_only)
1634             for key in model_dict:
1635                 if isinstance(model_dict[key], torch._C.ScriptObject):
1636                     assert isinstance(loaded_dict[key], torch._C.ScriptObject)
1637                     w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
1638                     w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
1639                     self.assertEqual(w_model, w_loaded)
1640                     self.assertEqual(b_model, b_loaded)
1641                 else:
1642                     self.assertEqual(model_dict[key], loaded_dict[key])
1643             loaded_qlinear = nnqd.Linear(in_features, out_features)
1644             loaded_qlinear.load_state_dict(loaded_dict)
1645 
1646             linear_unpack = torch.ops.quantized.linear_unpack
1647             self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
1648                              linear_unpack(loaded_qlinear._packed_params._packed_params))
1649             if use_bias:
1650                 self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
1651             self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
1652             self.assertTrue(hasattr(qlinear, '_packed_params'))
1653             self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
1654             self.assertTrue(hasattr(qlinear, '_weight_bias'))
1655             self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
1656 
1657             self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
1658             self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
1659             Z_dq2 = qlinear(X)
1660             self.assertEqual(Z_dq, Z_dq2)
1661 
1662         b = io.BytesIO()
1663         torch.save(qlinear, b)
1664         b.seek(0)
1665         # weights_only=False because the whole module (not just a state_dict) was pickled above
1666         loaded = torch.load(b, weights_only=False)
1667         self.assertEqual(qlinear.weight(), loaded.weight())
1668         self.assertEqual(qlinear.zero_point, loaded.zero_point)
1669 
1670         # Test JIT
1671         self.checkScriptable(qlinear, [[X]], check_save_load=True)
1672 
1673         modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
1674         for mut in modules_under_test:
1675             # Test from_float
1676             float_linear = mut(in_features, out_features).float()
1677             if use_default_observer:
1678                 float_linear.qconfig = torch.ao.quantization.default_dynamic_qconfig
1679             prepare_dynamic(float_linear)
1680             float_linear(X.float())
1681             quantized_float_linear = nnqd.Linear.from_float(float_linear)
1682 
1683             # Smoke test to make sure the module actually runs
1684             quantized_float_linear(X)
1685 
1686         # Smoke test extra_repr
1687         self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
1688 
1689     @given(
1690         dtype=st.sampled_from([torch.qint8, torch.float16]),
1691         bidirectional=st.booleans(),
1692     )
1693     @override_qengines
1694     def test_lstm_api(self, dtype, bidirectional):
1695         r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
1696         """
1697         # Check that module matches the numerics of the op and ensure that module can be
1698         # instantiated for all engines and dtypes
1699         seq_len = 4
1700         batch = 2
1701         input_size = 3
1702         hidden_size = 7
1703         num_layers = 2
1704         bias = True
1705         weight_keys = []
1706         bias_keys = []
1707         num_directions = 2 if bidirectional else 1
1708         for layer in range(num_layers):
1709             for direction in range(num_directions):
1710                 suffix = '_reverse' if direction == 1 else ''
1711                 key_name1 = f'weight_ih_l{layer}{suffix}'
1712                 key_name2 = f'weight_hh_l{layer}{suffix}'
1713                 weight_keys.append(key_name1)
1714                 weight_keys.append(key_name2)
1715                 key_name1 = f'bias_ih_l{layer}{suffix}'
1716                 key_name2 = f'bias_hh_l{layer}{suffix}'
1717                 bias_keys.append(key_name1)
1718                 bias_keys.append(key_name2)
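        # Illustrative example: with num_layers=2 and bidirectional=True this collects keys such
        # as 'weight_ih_l0', 'weight_hh_l0', 'weight_ih_l0_reverse', ..., 'bias_hh_l1_reverse'.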
1719 
1720         if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
1721             # fp16 dynamic quant is not supported for qnnpack or onednn
1722             x = torch.randn(seq_len, batch, input_size)
1723             h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1724             c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1725             cell_dq = torch.ao.nn.quantized.dynamic.LSTM(input_size=input_size,
1726                                                          hidden_size=hidden_size,
1727                                                          num_layers=num_layers,
1728                                                          bias=bias,
1729                                                          batch_first=False,
1730                                                          dropout=0.0,
1731                                                          bidirectional=bidirectional,
1732                                                          dtype=dtype)
1733             ref_dq = torch.ao.nn.quantized.dynamic.LSTM(input_size=input_size,
1734                                                         hidden_size=hidden_size,
1735                                                         num_layers=num_layers,
1736                                                         bias=bias,
1737                                                         batch_first=False,
1738                                                         dropout=0.0,
1739                                                         bidirectional=bidirectional,
1740                                                         dtype=dtype)
1741 
1742             _all_params = [m.param for m in cell_dq._all_weight_values]
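            # call the underlying op directly; the trailing positional booleans correspond
            # (as used here) to train=False, bidirectional, batch_first=False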
1743             result = torch.quantized_lstm(x, (h, c),
1744                                           _all_params,
1745                                           cell_dq.bias,
1746                                           cell_dq.num_layers,
1747                                           float(cell_dq.dropout),
1748                                           False,
1749                                           bidirectional,
1750                                           False,
1751                                           dtype=dtype,
1752                                           use_dynamic=True)
1753 
1754 
1755             y, (h, c) = cell_dq(x, (h, c))
1756             self.assertEqual(result[0], y)
1757             self.assertEqual(result[1], h)
1758             self.assertEqual(result[2], c)
1759             x = torch.randn(10, 20, 3)
1760             self.check_eager_serialization(cell_dq, ref_dq, [x])
1761             self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)
1762 
1763     @override_qengines
1764     def test_gru_api(self):
1765         r"""Test execution and serialization for dynamic quantized GRU modules on int8 and fp16
1766         """
1767         # Check that module matches the numerics of the op and ensure that module can be
1768         # instantiated for all engines and dtypes
1769 
1770         for dtype in [torch.qint8, torch.float16]:
1771             if dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn"):
1772                 # fp16 dynamic quant is not supported for qnnpack or onednn
1773                 continue
1774             # Test default instantiation
1775             seq_len = 4
1776             batch = 2
1777             input_size = 3
1778             hidden_size = 7
1779             num_layers = 2
1780             bias = True
1781             bidirectional = False
1782 
1783             x = torch.rand(seq_len, batch, input_size)
1784             h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)
1785 
1786 
1787             cell_dq = torch.ao.nn.quantized.dynamic.GRU(input_size=input_size,
1788                                                         hidden_size=hidden_size,
1789                                                         num_layers=num_layers,
1790                                                         bias=bias,
1791                                                         batch_first=False,
1792                                                         dropout=0.0,
1793                                                         bidirectional=bidirectional,
1794                                                         dtype=dtype)
1795 
1796             _all_params = [m.param for m in cell_dq._all_weight_values]
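            # as in the LSTM test above, the trailing positional booleans passed to the op
            # correspond (as used here) to train=False, bidirectional, batch_first=False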
1797             result = torch.quantized_gru(x,
1798                                          h,
1799                                          _all_params,
1800                                          cell_dq.bias,
1801                                          cell_dq.num_layers,
1802                                          float(cell_dq.dropout),
1803                                          False,
1804                                          bidirectional,
1805                                          False)
1806 
1807 
1808             y, h = cell_dq(x, h)
1809             self.assertEqual(result[0], y, msg="GRU module API failed")
1810             self.assertEqual(result[1], h, msg="GRU module API failed")
1811 
1812     @given(
1813         dtype=st.sampled_from([torch.qint8, torch.float16]),
1814     )
1815     @override_qengines
1816     def test_cell_api(self, dtype):
1817         r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
1818         r"""Test execution and serialization for dynamic quantized RNN cell modules (LSTMCell, GRUCell, RNNCell) on int8 and fp16
1819         # Check that module matches the numerics of the op and ensure that module can be
1820         # instantiated for all engines and dtypes
1821         batch = 7
1822         input_size = 3
1823         hidden_size = 7
1824         bias = True
1825 
1826         x = torch.rand(batch, input_size)
1827         h = torch.rand(batch, hidden_size)
1828         cell_dict = {'LSTMCell': torch.ao.nn.quantized.dynamic.LSTMCell,
1829                      'GRUCell': torch.ao.nn.quantized.dynamic.GRUCell,
1830                      'RNNTanh': torch.ao.nn.quantized.dynamic.RNNCell,
1831                      'RNNReLU': torch.ao.nn.quantized.dynamic.RNNCell
1832                      }
1833         state = {'LSTMCell': (h, h),
1834                  'GRUCell': h,
1835                  'RNNTanh': h,
1836                  'RNNReLU': h}
1837 
1838         qfn_dict = {'LSTMCell': torch.ops.quantized.quantized_lstm_cell_dynamic,
1839                     'GRUCell': torch.ops.quantized.quantized_gru_cell_dynamic,
1840                     'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic,
1841                     'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic}
1842 
1843         for rnn_type in cell_dict.keys():
1844             if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
1845                 # fp16 dynamic quant is not supported for qnnpack or onednn
1846                 kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype}
1847                 if rnn_type == 'RNNReLU':
1848                     kwargs['nonlinearity'] = "relu"
1849                 elif rnn_type == 'RNNTanh':
1850                     kwargs['nonlinearity'] = "tanh"
1851 
1852                 cell_dq = cell_dict[rnn_type](**kwargs)
1853                 result = qfn_dict[rnn_type](x, state[rnn_type],
1854                                             cell_dq._packed_weight_ih, cell_dq._packed_weight_hh,
1855                                             cell_dq.bias_ih, cell_dq.bias_hh)
1856                 result_module = cell_dq(x, state[rnn_type])
1857                 self.assertEqual(result[0], result_module[0], msg="RNNCell module API failed")
1858                 self.assertEqual(result[1], result_module[1], msg="RNNCell module API failed")
1859                 weight_keys = ['weight_ih', 'weight_hh']
1860                 bias_keys = ['bias_ih', 'bias_hh']
1861                 self.check_eager_serialization(cell_dq, cell_dict[rnn_type](**kwargs), [x])
1862                 self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)
1863 
1864 class TestReferenceQuantizedModule(QuantizationTestCase):
1865     def _quant_dequant_weight(self, weight, weight_qparams):
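        # "fake quantize" the weight: quantize it to the target dtype and immediately
        # dequantize, so a float module can reproduce the numerics of the reference module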
1866         qscheme = weight_qparams["qscheme"]
1867         scale = weight_qparams["scale"]
1868         zero_point = weight_qparams["zero_point"]
1869         dtype = weight_qparams["dtype"]
1870         if qscheme == torch.per_tensor_affine:
1871             weight = torch.quantize_per_tensor(weight, scale, zero_point, dtype)
1872         else:
1873             # per channel affine
1874             axis = weight_qparams["axis"]
1875             weight = torch.quantize_per_channel(weight, scale, zero_point, axis, dtype)
1876         weight = weight.dequantize()
1877         return weight
1878 
1879     # TODO: add tests for conv and linear
1880     def test_rnn_cell(self):
1881         """ Checks that the RNN cell reference quantized modules have correct numerics.
1882         This includes LSTMCell, GRUCell and RNNCell.
1883         """
1884         batch = 7
1885         input_size = 3
1886         hidden_size = 7
1887         bias = True
1888 
1889         x = torch.rand(batch, input_size)
1890         h = torch.rand(batch, hidden_size)
1891         cell_dict = {'LSTMCell': torch.nn.LSTMCell,
1892                      'GRUCell': torch.nn.GRUCell,
1893                      'RNNTanh': torch.nn.RNNCell,
1894                      'RNNReLU': torch.nn.RNNCell
1895                      }
1896         state = {'LSTMCell': (h, h),
1897                  'GRUCell': h,
1898                  'RNNTanh': h,
1899                  'RNNReLU': h}
1900 
1901         qfn_dict = {'LSTMCell': nnqr.LSTMCell,
1902                     'GRUCell': nnqr.GRUCell,
1903                     'RNNTanh': nnqr.RNNCell,
1904                     'RNNReLU': nnqr.RNNCell}
1905 
1906         for rnn_type in cell_dict.keys():
1907             kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias}
1908             if rnn_type == 'RNNReLU':
1909                 kwargs['nonlinearity'] = "relu"
1910             elif rnn_type == 'RNNTanh':
1911                 kwargs['nonlinearity'] = "tanh"
1912 
1913             fp_cell = cell_dict[rnn_type](**kwargs)
1914             # initialize ref rnn cell module
1915             weight_qparams = {
1916                 'qscheme': torch.per_tensor_affine,
1917                 'dtype': torch.quint8,
1918                 'scale': 2.0,
1919                 'zero_point': 5
1920             }
1921             weight_qparams_dict = {
1922                 "weight_ih": weight_qparams,
1923                 "weight_hh": weight_qparams,
1924                 "is_decomposed": False,
1925             }
1926             ref_kwargs = kwargs.copy()
1927             ref_kwargs["weight_qparams_dict"] = weight_qparams_dict
1928             ref_cell = qfn_dict[rnn_type](**ref_kwargs)
1929             # reassign the weights and biases from the fp32 rnn cell module
1930             ref_cell.weight_ih = fp_cell.weight_ih
1931             ref_cell.weight_hh = fp_cell.weight_hh
1932             ref_cell.bias_ih = fp_cell.bias_ih
1933             ref_cell.bias_hh = fp_cell.bias_hh
1934 
1935             ref_res = ref_cell(x, state[rnn_type])
1936 
1937             # change the weights of fp_cell: quantize and then dequantize them so that
1938             # fp_cell reproduces the numerics of the reference module
1939             fp_cell.weight_ih = torch.nn.Parameter(self._quant_dequant_weight(fp_cell.weight_ih, weight_qparams_dict["weight_ih"]))
1940             fp_cell.weight_hh = torch.nn.Parameter(self._quant_dequant_weight(fp_cell.weight_hh, weight_qparams_dict["weight_hh"]))
1941             fp_res = fp_cell(x, state[rnn_type])
1942             self.assertEqual(ref_res[0], fp_res[0], msg="RNNCell module API failed")
1943             self.assertEqual(ref_res[1], fp_res[1], msg="RNNCell module API failed")
1944 
1945     def test_rnn(self):
1946         """ Checks that the RNN reference quantized modules have correct numerics.
1947         This includes LSTM.
1948         """
1949         seq_len = 4
1950         batch = 2
1951         input_size = 3
1952         hidden_size = 7
1953         num_layers = 2
1954         bias = True
1955         for bidirectional in [True, False]:
1956             x = torch.randn(seq_len, batch, input_size)
1957             h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1958             c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1959             fp32_rnn = torch.nn.LSTM(
1960                 input_size=input_size,
1961                 hidden_size=hidden_size,
1962                 num_layers=num_layers,
1963                 bias=bias,
1964                 batch_first=False,
1965                 dropout=0.0,
1966                 bidirectional=bidirectional)
1967             # initialize ref rnn module
1968             weight_qparams = {
1969                 "qscheme": torch.per_tensor_affine,
1970                 "dtype": torch.qint8,
1971                 "scale": 2.0,
1972                 "zero_point": 5
1973             }
1974             weight_qparams_dict = {key: weight_qparams for key in fp32_rnn._flat_weights_names if key.startswith("weight")}
1975             weight_qparams_dict["is_decomposed"] = False
1976             ref_rnn = nnqr.LSTM(
1977                 input_size=input_size,
1978                 hidden_size=hidden_size,
1979                 num_layers=num_layers,
1980                 bias=bias,
1981                 batch_first=False,
1982                 dropout=0.0,
1983                 bidirectional=bidirectional,
1984                 weight_qparams_dict=weight_qparams_dict)
1985             for wn in fp32_rnn._flat_weights_names:
1986                 setattr(ref_rnn, wn, copy.deepcopy(getattr(fp32_rnn, wn)))
1987 
1988             ref_rnn._flat_weights = copy.deepcopy(fp32_rnn._flat_weights)
1989 
1990             # quantize and dequantize the weights for fp32_rnn module
1991             flat_weights = []
1992             for wn in fp32_rnn._flat_weights_names:
1993                 if wn.startswith("weight"):
1994                     weight = self._quant_dequant_weight(getattr(fp32_rnn, wn), weight_qparams)
1995                 else:
1996                     weight = getattr(fp32_rnn, wn)
1997                 flat_weights.append(weight)
1998             fp32_rnn._flat_weights = flat_weights
1999 
2000             fp32_res = fp32_rnn(x, (h, c))
2001             ref_res = ref_rnn(x, (h, c))
2002             self.assertEqual(fp32_res, ref_res)
2003 
2004     def test_sparse(self):
2005         """ Checks that the reference quantized Embedding and EmbeddingBag modules have correct numerics.
2006         """
2007         num_embeddings = 10
2008         embedding_dim = 3
2009         # embedding input
2010         ex = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
2011 
2012         # embedding bag input
2013         ebx = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
2014         offsets = torch.tensor([0, 4], dtype=torch.long)
2015 
2016         fp_to_ref = {
2017             nn.Embedding: (nnqr.Embedding, (ex,)),
2018             nn.EmbeddingBag: (nnqr.EmbeddingBag, (ebx, offsets)),
2019         }
2020 
2021         per_tensor_weight_qparams = {
2022             "qscheme": torch.per_tensor_affine,
2023             "dtype": torch.quint8,
2024             "scale": 2.0,
2025             "zero_point": 5,
2026             "is_decomposed": False,
2027         }
2028 
2029         per_channel_weight_qparams = {
2030             "qscheme": torch.per_channel_affine,
2031             "dtype": torch.quint8,
2032             "scale": torch.randn(10),
2033             "zero_point": torch.randint(0, 255, (10,)),
2034             "axis": 0,
2035             "is_decomposed": False,
2036         }
2037 
2038         per_channel_weight_qparams_quint4x2 = {
2039             "qscheme": torch.per_channel_affine_float_qparams,
2040             "dtype": torch.quint4x2,
2041             "scale": torch.randn(10),
2042             "zero_point": torch.randint(0, 255, (10,)),
2043             "axis": 0,
2044             "is_decomposed": False,
2045         }
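        # quint4x2 is a packed 4-bit dtype (two values per byte); it is listed here but skipped
        # in the loop below because quantize_per_channel does not support it yet (see TODO)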
2046 
2047         weight_qparams_options = [
2048             per_tensor_weight_qparams,
2049             per_channel_weight_qparams,
2050             per_channel_weight_qparams_quint4x2,
2051         ]
2052         for fp_cls, weight_qparams in itertools.product([nn.Embedding, nn.EmbeddingBag], weight_qparams_options):
2053             # TODO: torch.quint4x2 not supported in quantize_per_channel, need to add support
2054             if weight_qparams["dtype"] == torch.quint4x2:
2055                 continue
2056             ref_cls, args = fp_to_ref[fp_cls]
2057 
2058             fp32_embedding = fp_cls(num_embeddings, embedding_dim)
2059 
2060             ref_embedding = ref_cls(num_embeddings, embedding_dim, weight_qparams=weight_qparams)
2061             ref_embedding.weight = fp32_embedding.weight
2062 
2063             # quantize and dequantize the weight for fp32 module
2064             fp32_embedding.weight = torch.nn.Parameter(self._quant_dequant_weight(fp32_embedding.weight, weight_qparams))
2065 
2066             fp32_res = fp32_embedding(*args)
2067             ref_res = ref_embedding(*args)
2068             self.assertEqual(fp32_res, ref_res)
2069 
2070     def test_linear_decomposed_weight_custom_qmin_qmax(self):
2071         """Verify that reference Linear respects custom qmin/qmax for weight
2072         """
2073         linear_fp32 = torch.nn.Linear(2, 2)
2074         qconfig = torch.ao.quantization.default_symmetric_qnnpack_qconfig
2075         w_obs = qconfig.weight()
2076         self.assertTrue(w_obs.quant_min == -127)
2077         self.assertTrue(w_obs.quant_max == 127)
2078         w_obs(linear_fp32.weight)
2079         weight_qparams = torch.ao.quantization.utils.get_qparam_dict(w_obs)
2080         weight_qparams["is_decomposed"] = True
2081         linear_ref = nnqr.Linear.from_float(linear_fp32, weight_qparams)
2082         linear_ref_traced = torch.fx.symbolic_trace(linear_ref)
2083 
2084         # verify that the qmin/qmax arguments for weight q/dq are correctly
2085         # taken from the observer
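        # exactly one quantize and one dequantize call are expected on the weight, hence
        # the final check that found == 2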
2086         found = 0
2087         for n in linear_ref_traced.graph.nodes:
2088             if n.op != 'call_function':
2089                 continue
2090             if n.target in (
2091                 torch.ops.quantized_decomposed.quantize_per_tensor,
2092                 torch.ops.quantized_decomposed.dequantize_per_tensor,
2093             ):
2094                 _0, _1, _2, qmin, qmax, _5 = n.args
2095                 self.assertTrue(qmin == -127)
2096                 self.assertTrue(qmax == 127)
2097                 found += 1
2098         self.assertTrue(found == 2)
2099