# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum, unique
from typing import Callable, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .annotators import OP_ANNOTATOR

from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_8a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
    QuantizationConfig,
)

# To bypass the meta internal test error
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

__all__ = [
    "QnnQuantizer",
    "QuantDtype",
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]


@unique
class QuantDtype(IntEnum):
    """
    Bit-widths of activations and weights; e.g. use_16a4w means 16-bit
    activations and 4-bit weights.
    """

    use_16a16w = 0
    use_16a8w = 1
    use_16a4w = 2
    use_8a8w = 3


# Maps (QuantDtype, is_qat) to a pair: a factory for the default (per-tensor)
# QuantizationConfig, and a prebuilt per-channel weight quantization config.
quant_config_dict = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
    ),
    (QuantDtype.use_16a8w, False): (
        get_16a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
    ),
    (QuantDtype.use_16a4w, False): (
        get_16a4w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, False): (
        get_8a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(),
    ),
    # QAT
    (QuantDtype.use_16a4w, True): (
        get_16a4w_qnn_qat_config,
        get_qat_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, True): (
        get_8a8w_qnn_qat_config,
        get_qat_per_channel_quant_config(),
    ),
}
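
# Sketch of how an entry is consumed (see set_quant_config below); the names
# config_func and per_channel_config are illustrative, not part of this module:
#
#   config_func, per_channel_config = quant_config_dict[(QuantDtype.use_16a4w, False)]
#   quant_config = config_func()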


class QnnQuantizer(Quantizer):
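    """
    Quantizer for the Qualcomm QNN backend.

    Annotates a torch.fx.GraphModule with the quantization configs selected
    via set_quant_config, so the graph can be processed by the PT2E
    quantization flow (prepare_pt2e / convert_pt2e).
    """
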
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        # Ops in this set get per-channel weight quantization instead of the
        # default per-tensor config; see _get_quant_config.
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        self.custom_quant_annotations: Sequence[Callable] = []
        # Node names (not op types) to skip entirely during annotation.
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """
        Resolution priority:
            1. op is in use_per_channel_weight_quant_ops: per-channel config
            2. op is in quant_ops: default quant config
        """
        if isinstance(op, str):
            # String targets (e.g. placeholders) are not quantizable ops.
            return

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        if op in self.quant_ops:
            return self.quant_config

        print(f"No quant config is implemented for op: {op}")

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        self.custom_quant_annotations = custom_quant_annotations

    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        for op in ops:
            # discard (rather than remove) avoids a KeyError when the op was
            # already excluded or was never supported.
            self.quant_ops.discard(op)

    def annotate(self, model: GraphModule) -> GraphModule:
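        """
        Entry point invoked by the PT2E flow (e.g. prepare_pt2e): built-in
        per-op annotators run first, then any user-registered custom
        annotations.
        """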
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
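        """
        Select the activation/weight bit-width scheme and PTQ vs. QAT mode.
        act_observer, if given, is forwarded to the chosen config factory and
        must be an observer it accepts (e.g. an observer class such as
        torch.ao.quantization.observer.MovingAverageMinMaxObserver).
        """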
        if (quant_dtype, is_qat) not in quant_config_dict:
            raise RuntimeError(
                f"The quant config (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not supported"
            )

        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        quant_config_func, self.per_channel_quant_config = quant_config_dict[
            (quant_dtype, is_qat)
        ]
        self.quant_config = (
            quant_config_func(act_observer=act_observer)
            if act_observer
            else quant_config_func()
        )

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        linear_ops = {torch.ops.aten.linear.default}
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
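        """
        Source transforms applied before annotation: recompose/decompose ops
        into forms the per-op annotators understand, and normalize values
        (inf buffers, extreme dynamic range) that would otherwise skew
        observer statistics.
        """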
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module
        return model

    def validate(self, model: GraphModule) -> None:
        pass
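

# A minimal end-to-end sketch (not part of the original module), assuming a
# PT2E-style capture; model and example_inputs are supplied by the caller:
#
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   quantizer = QnnQuantizer()
#   quantizer.set_quant_config(QuantDtype.use_8a8w)
#   graph_module = torch.export.export(model, example_inputs).module()
#   prepared = prepare_pt2e(graph_module, quantizer)
#   prepared(*example_inputs)  # calibration
#   quantized = convert_pt2e(prepared)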
207