xref: /aosp_15_r20/external/pytorch/test/quantization/pt2e/test_representation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2import copy
3from typing import Any, Dict, Tuple
4
5import torch
6from torch._export import capture_pre_autograd_graph
7from torch._higher_order_ops.out_dtype import out_dtype  # noqa: F401
8from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
9from torch.ao.quantization.quantizer import Quantizer
10from torch.ao.quantization.quantizer.xnnpack_quantizer import (
11    get_symmetric_quantization_config,
12    XNNPACKQuantizer,
13)
14from torch.testing._internal.common_quantization import (
15    NodeSpec as ns,
16    QuantizationTestCase,
17    skipIfNoQNNPACK,
18    TestHelperModules,
19)
20
21
22@skipIfNoQNNPACK
23class TestPT2ERepresentation(QuantizationTestCase):
24    def _test_representation(
25        self,
26        model: torch.nn.Module,
27        example_inputs: Tuple[Any, ...],
28        quantizer: Quantizer,
29        ref_node_occurrence: Dict[ns, int],
30        non_ref_node_occurrence: Dict[ns, int],
31        fixed_output_tol: float = None,
32        output_scale_idx: int = 2,
33    ) -> torch.nn.Module:
34        # resetting dynamo cache
35        torch._dynamo.reset()
36        model = capture_pre_autograd_graph(
37            model,
38            example_inputs,
39        )
40        model_copy = copy.deepcopy(model)
41
42        model = prepare_pt2e(model, quantizer)
43        # Calibrate
44        model(*example_inputs)
45        model = convert_pt2e(model, use_reference_representation=True)
46        self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
47        # make sure it runs
48        pt2e_quant_output = model(*example_inputs)
49
50        # TODO: torchdynamo times out when we do this, we can enable numerical checking
51        # after that is fixed
52        model_copy = prepare_pt2e(model_copy, quantizer)
53        # Calibrate
54        model_copy(*example_inputs)
55        model_copy = convert_pt2e(model_copy, use_reference_representation=False)
56        self.checkGraphModuleNodes(
57            model_copy, expected_node_occurrence=non_ref_node_occurrence
58        )
59        pt2e_quant_output_copy = model_copy(*example_inputs)
60
61        output_tol = None
62        if fixed_output_tol is not None:
63            output_tol = fixed_output_tol
64        else:
65            idx = 0
66            for n in model_copy.graph.nodes:
67                if (
68                    n.target
69                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
70                ):
71                    idx += 1
72                    if idx == output_scale_idx:
73                        output_tol = n.args[1]
74            assert output_tol is not None
75
76        # make sure the result is off by one at most in the quantized integer representation
77        self.assertTrue(
78            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output))
79            <= (2 * output_tol + 1e-5)
80        )
81
82    def test_static_linear(self):
83        class M(torch.nn.Module):
84            def __init__(self) -> None:
85                super().__init__()
86                self.linear = torch.nn.Linear(5, 5)
87
88            def forward(self, x):
89                return self.linear(x)
90
91        quantizer = XNNPACKQuantizer()
92        operator_config = get_symmetric_quantization_config(is_per_channel=False)
93        quantizer.set_global(operator_config)
94        example_inputs = (torch.randn(2, 5),)
95
96        self._test_representation(
97            M().eval(),
98            example_inputs,
99            quantizer,
100            ref_node_occurrence={},
101            non_ref_node_occurrence={},
102        )
103
104    def test_dynamic_linear(self):
105        class M(torch.nn.Module):
106            def __init__(self) -> None:
107                super().__init__()
108                self.linear = torch.nn.Linear(5, 5)
109
110            def forward(self, x):
111                return self.linear(x)
112
113        quantizer = XNNPACKQuantizer()
114        operator_config = get_symmetric_quantization_config(
115            is_per_channel=False, is_dynamic=True
116        )
117        quantizer.set_global(operator_config)
118        example_inputs = (torch.randn(2, 5),)
119
120        self._test_representation(
121            M().eval(),
122            example_inputs,
123            quantizer,
124            ref_node_occurrence={},
125            non_ref_node_occurrence={},
126            fixed_output_tol=1e-4,
127        )
128
129    def test_conv2d(self):
130        class M(torch.nn.Module):
131            def __init__(self) -> None:
132                super().__init__()
133                self.conv2d = torch.nn.Conv2d(3, 3, 3)
134
135            def forward(self, x):
136                return self.conv2d(x)
137
138        quantizer = XNNPACKQuantizer()
139        operator_config = get_symmetric_quantization_config(is_per_channel=False)
140        quantizer.set_global(operator_config)
141        example_inputs = (torch.randn(1, 3, 3, 3),)
142
143        self._test_representation(
144            M().eval(),
145            example_inputs,
146            quantizer,
147            ref_node_occurrence={},
148            non_ref_node_occurrence={},
149        )
150
151    def test_add(self):
152        class M(torch.nn.Module):
153            def __init__(self) -> None:
154                super().__init__()
155
156            def forward(self, x, y):
157                return x + y
158
159        quantizer = XNNPACKQuantizer()
160        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
161        quantizer.set_global(quantization_config)
162        m_eager = M().eval()
163
164        example_inputs = (
165            torch.randn(1, 3, 3, 3),
166            torch.randn(1, 3, 3, 3),
167        )
168
169        self._test_representation(
170            M().eval(),
171            example_inputs,
172            quantizer,
173            ref_node_occurrence={},
174            non_ref_node_occurrence={},
175        )
176
177    def test_add_relu(self):
178        class M(torch.nn.Module):
179            def __init__(self) -> None:
180                super().__init__()
181
182            def forward(self, x, y):
183                out = x + y
184                out = torch.nn.functional.relu(out)
185                return out
186
187        quantizer = XNNPACKQuantizer()
188        operator_config = get_symmetric_quantization_config(is_per_channel=True)
189        quantizer.set_global(operator_config)
190
191        example_inputs = (
192            torch.randn(1, 3, 3, 3),
193            torch.randn(1, 3, 3, 3),
194        )
195        ref_node_occurrence = {
196            ns.call_function(out_dtype): 2,
197        }
198
199        self._test_representation(
200            M().eval(),
201            example_inputs,
202            quantizer,
203            ref_node_occurrence=ref_node_occurrence,
204            non_ref_node_occurrence={},
205        )
206
207    def test_maxpool2d(self):
208        quantizer = XNNPACKQuantizer()
209        operator_config = get_symmetric_quantization_config(is_per_channel=True)
210        quantizer.set_global(operator_config)
211        m_eager = TestHelperModules.ConvMaxPool2d().eval()
212
213        example_inputs = (torch.randn(1, 2, 2, 2),)
214
215        self._test_representation(
216            m_eager,
217            example_inputs,
218            quantizer,
219            ref_node_occurrence={},
220            non_ref_node_occurrence={},
221        )
222
223    def test_qdq_per_channel(self):
224        """Test representation for quantize_per_channel and dequantize_per_channel op"""
225
226        class M(torch.nn.Module):
227            def __init__(self) -> None:
228                super().__init__()
229                self.linear = torch.nn.Linear(5, 5)
230
231            def forward(self, x):
232                return self.linear(x)
233
234        quantizer = XNNPACKQuantizer()
235        # use per channel quantization for weight
236        operator_config = get_symmetric_quantization_config(is_per_channel=True)
237        quantizer.set_global(operator_config)
238        m_eager = M().eval()
239
240        inputs = [
241            (torch.randn(1, 5),),
242            (torch.randn(1, 3, 5),),
243            (torch.randn(1, 3, 3, 5),),
244            (torch.randn(1, 3, 3, 3, 5),),
245        ]
246        for example_inputs in inputs:
247            ref_node_occurrence = {
248                ns.call_function(
249                    torch.ops.quantized_decomposed.quantize_per_channel.default
250                ): 0,
251                ns.call_function(
252                    torch.ops.quantized_decomposed.dequantize_per_channel.default
253                ): 0,
254            }
255            non_ref_node_occurrence = {
256                # quantize_per_channel is folded
257                ns.call_function(
258                    torch.ops.quantized_decomposed.quantize_per_channel.default
259                ): 0,
260                ns.call_function(
261                    torch.ops.quantized_decomposed.dequantize_per_channel.default
262                ): 1,
263            }
264
265            self._test_representation(
266                M().eval(),
267                example_inputs,
268                quantizer,
269                ref_node_occurrence,
270                non_ref_node_occurrence,
271                output_scale_idx=2,
272            )
273
274    def test_qdq(self):
275        """Test representation for quantize and dequantize op"""
276
277        class M(torch.nn.Module):
278            def __init__(self) -> None:
279                super().__init__()
280
281            def forward(self, x, y):
282                return x + y
283
284        quantizer = XNNPACKQuantizer()
285        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
286        quantizer.set_global(quantization_config)
287        m_eager = M().eval()
288
289        example_inputs = (
290            torch.randn(1, 3, 3, 3),
291            torch.randn(1, 3, 3, 3),
292        )
293        ref_node_occurrence = {
294            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0,
295            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0,
296        }
297        non_ref_node_occurrence = {
298            ns.call_function(
299                torch.ops.quantized_decomposed.quantize_per_tensor.default
300            ): 3,
301            ns.call_function(
302                torch.ops.quantized_decomposed.dequantize_per_tensor.default
303            ): 3,
304        }
305        self._test_representation(
306            M().eval(),
307            example_inputs,
308            quantizer,
309            ref_node_occurrence,
310            non_ref_node_occurrence,
311        )
312