xref: /aosp_15_r20/external/executorch/backends/qualcomm/quantizer/quantizer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Workerfrom enum import IntEnum, unique
7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Optional, Sequence, Set
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
11*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
13*523fa7a6SAndroid Build Coastguard Worker    RecomposePixelUnshuffle,
14*523fa7a6SAndroid Build Coastguard Worker)
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.transforms.decompose_sdpa import (
18*523fa7a6SAndroid Build Coastguard Worker    DecomposeScaledDotProductAttention,
19*523fa7a6SAndroid Build Coastguard Worker)
20*523fa7a6SAndroid Build Coastguard Worker
21*523fa7a6SAndroid Build Coastguard Workerfrom torch._ops import OpOverload
22*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer import Quantizer
23*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import GraphModule
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Workerfrom .annotators import OP_ANNOTATOR
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Workerfrom .qconfig import (
28*523fa7a6SAndroid Build Coastguard Worker    get_16a16w_qnn_ptq_config,
29*523fa7a6SAndroid Build Coastguard Worker    get_16a4w_qnn_ptq_config,
30*523fa7a6SAndroid Build Coastguard Worker    get_16a4w_qnn_qat_config,
31*523fa7a6SAndroid Build Coastguard Worker    get_16a8w_qnn_ptq_config,
32*523fa7a6SAndroid Build Coastguard Worker    get_8a8w_qnn_ptq_config,
33*523fa7a6SAndroid Build Coastguard Worker    get_8a8w_qnn_qat_config,
34*523fa7a6SAndroid Build Coastguard Worker    get_ptq_per_channel_quant_config,
35*523fa7a6SAndroid Build Coastguard Worker    get_qat_per_channel_quant_config,
36*523fa7a6SAndroid Build Coastguard Worker    QuantizationConfig,
37*523fa7a6SAndroid Build Coastguard Worker)
38*523fa7a6SAndroid Build Coastguard Worker
# Backward-compatible alias for the 16a16w PTQ config factory. The original
# note says it exists "to bypass the meta internal test error" -- presumably
# some internal caller still imports the factory under this older name.
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

# Public names exported by ``from <this module> import *``.
__all__ = [
    # Quantizer entry point and its bit-width enum.
    "QnnQuantizer",
    "QuantDtype",
    # Post-training-quantization config factories (re-exported from .qconfig).
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    # Quantization-aware-training config factories (re-exported from .qconfig).
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]
52*523fa7a6SAndroid Build Coastguard Worker
53*523fa7a6SAndroid Build Coastguard Worker
@unique
class QuantDtype(IntEnum):
    """
    Quantization bit-width presets.

    Member names follow ``use_<A>a<W>w``: ``A`` is the number of bits used for
    activations and ``W`` the number of bits used for weights.
    """

    use_16a16w = 0  # 16-bit activations, 16-bit weights
    use_16a8w = 1  # 16-bit activations, 8-bit weights
    use_16a4w = 2  # 16-bit activations, 4-bit weights
    use_8a8w = 3  # 8-bit activations, 8-bit weights
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Worker
# Lookup table used by QnnQuantizer.set_quant_config:
#   (QuantDtype, is_qat) -> (factory returning the per-tensor QuantizationConfig,
#                            per-channel weight quant config object)
# The per-channel entries are built once, when this module is imported, and are
# shared by every quantizer that selects them.
quant_config_dict = dict(
    [
        # ---- post-training quantization (is_qat=False) ----
        (
            (QuantDtype.use_16a16w, False),
            (
                get_16a16w_qnn_ptq_config,
                get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
            ),
        ),
        (
            (QuantDtype.use_16a8w, False),
            (
                get_16a8w_qnn_ptq_config,
                get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
            ),
        ),
        (
            (QuantDtype.use_16a4w, False),
            (
                get_16a4w_qnn_ptq_config,
                get_ptq_per_channel_quant_config(torch.uint16, "int4"),
            ),
        ),
        (
            (QuantDtype.use_8a8w, False),
            (
                get_8a8w_qnn_ptq_config,
                get_ptq_per_channel_quant_config(),
            ),
        ),
        # ---- quantization-aware training (is_qat=True) ----
        (
            (QuantDtype.use_16a4w, True),
            (
                get_16a4w_qnn_qat_config,
                get_qat_per_channel_quant_config(torch.uint16, "int4"),
            ),
        ),
        (
            (QuantDtype.use_8a8w, True),
            (
                get_8a8w_qnn_qat_config,
                get_qat_per_channel_quant_config(),
            ),
        ),
    ]
)
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Worker
class QnnQuantizer(Quantizer):
    """Quantizer that annotates an FX ``GraphModule`` for the QNN backend.

    For every graph node whose target is an enabled op, the annotator
    registered in ``OP_ANNOTATOR`` is called with either the per-tensor
    ``quant_config`` or, for ops enabled through ``set_per_channel_conv_quant``
    / ``set_per_channel_linear_quant``, the ``per_channel_quant_config``.
    User-supplied annotation callables run after that built-in pass.
    """

    # Every op that has a registered annotator. This is a class attribute shared
    # by all instances; instances work on their own copy (``quant_ops``).
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        # Ops this instance annotates. Starts as a copy of SUPPORTED_OPS and is
        # only ever shrunk by this class (see ``add_discard_ops``).
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        # Default preset: 8-bit activations / 8-bit weights, not QAT.
        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        # Ops that should receive ``per_channel_quant_config`` instead of
        # ``quant_config`` (only honoured while the op is also in quant_ops).
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        # Callables invoked as ``fn(gm)`` after the built-in annotation pass.
        self.custom_quant_annotations: Sequence[Callable] = []
        # Names of nodes that the built-in annotation pass skips.
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        """Run the registered op annotator on every eligible node of ``gm``."""
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config is not None:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        """Apply each user-supplied annotation callable to ``gm`` in order."""
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """Return the quantization config for ``op``, or ``None`` to skip it.

        Ops not in ``self.quant_ops`` (never supported, or removed with
        ``add_discard_ops``) are never annotated. Among enabled ops:
            1. is one of use_per_channel_weight_quant_ops -> per-channel config
            2. otherwise -> the per-tensor ``quant_config``
        """
        # String targets are never annotated.
        if isinstance(op, str):
            return None

        # Check enablement first: previously a discarded op that was also in
        # the per-channel set was still annotated (and an op missing from
        # OP_ANNOTATOR would raise KeyError in _annotate). Since quant_ops is
        # derived from OP_ANNOTATOR's keys, this also guarantees an annotator
        # exists for every op we return a config for.
        if op not in self.quant_ops:
            print(f"No quant config is implemented for op, {op}")
            return None

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        return self.quant_config

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        """Add ``ops`` to (``enable=True``) or remove them from the per-channel set."""
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        """Set the annotation callables run after the built-in pass.

        Note: this replaces any previously registered callables; it does not
        append to them.
        """
        self.custom_quant_annotations = custom_quant_annotations

    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        """Set the node names skipped by the built-in pass (replaces prior set)."""
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        """Stop annotating the given ops.

        Idempotent: ops that are already discarded or were never supported are
        ignored instead of raising ``KeyError``.
        """
        for op in ops:
            self.quant_ops.discard(op)

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate ``model`` in place and return it.

        The built-in per-op pass runs first, then the custom annotation
        callables, so custom callables see (and may override) its results.
        """
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        """Return the class-wide set of ops that have a registered annotator."""
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
        """Select the bit-width preset and PTQ vs. QAT configs.

        Args:
            quant_dtype: activation/weight bit-width preset.
            is_qat: use the quantization-aware-training config factories.
            act_observer: optional activation observer forwarded to the config
                factory; the factory's own default is used when ``None``.

        Raises:
            RuntimeError: if ``(quant_dtype, is_qat)`` is not in
                ``quant_config_dict``. The quantizer is left unchanged.
        """
        key = (quant_dtype, is_qat)
        if key not in quant_config_dict:
            raise RuntimeError(
                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
            )

        quant_config_fn, per_channel_quant_config = quant_config_dict[key]
        quant_config = (
            quant_config_fn(act_observer)
            if act_observer is not None
            else quant_config_fn()
        )

        # Commit only after validation and config construction succeeded, so a
        # failure cannot leave dtype/is_qat out of sync with the configs.
        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        self.per_channel_quant_config = per_channel_quant_config
        self.quant_config = quant_config

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        """Enable/disable per-channel weight quantization for conv1d/conv2d."""
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        """Enable/disable per-channel weight quantization for linear."""
        linear_ops = {
            torch.ops.aten.linear.default,
        }
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Apply the backend's graph-rewrite passes before annotation.

        The passes run in this fixed order; each returns a result whose
        ``graph_module`` feeds the next pass.
        """
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module
        return model

    def validate(self, model: GraphModule) -> None:
        """No validation is performed."""
        pass
207