# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum, unique
from typing import Callable, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .annotators import OP_ANNOTATOR

from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_8a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
    QuantizationConfig,
)

# To bypass the meta internal test error
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

__all__ = [
    "QnnQuantizer",
    "QuantDtype",
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]


@unique
class QuantDtype(IntEnum):
    """
    bits of activation and bits of weight
    """

    use_16a16w = 0
    use_16a8w = 1
    use_16a4w = 2
    use_8a8w = 3


# Lookup table keyed by (QuantDtype, is_qat). Each value is a pair of:
#   1. a factory that builds the general QuantizationConfig when called, and
#   2. an already-built per-channel quantization config.
# Combinations that are absent from this table are rejected by
# QnnQuantizer.set_quant_config.
quant_config_dict = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
    ),
    (QuantDtype.use_16a8w, False): (
        get_16a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
    ),
    (QuantDtype.use_16a4w, False): (
        get_16a4w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, False): (
        get_8a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(),
    ),
    # QAT
    (QuantDtype.use_16a4w, True): (
        get_16a4w_qnn_qat_config,
        get_qat_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, True): (
        get_8a8w_qnn_qat_config,
        get_qat_per_channel_quant_config(),
    ),
}


class QnnQuantizer(Quantizer):
    """
    Quantizer that annotates graph nodes using the callables in OP_ANNOTATOR.

    A fresh instance defaults to the 8a8w PTQ configuration, has no
    per-channel ops enabled, no custom annotations and no discarded nodes.
    """

    # Every op that has an annotator registered in OP_ANNOTATOR.
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        # Per-instance copy so add_discard_ops never mutates SUPPORTED_OPS.
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        self.custom_quant_annotations: Sequence[Callable] = []
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        """
        Annotate every node of ``gm`` that has a quant config.

        Nodes whose name is listed in ``self.discard_nodes`` are skipped.
        """
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        """Call each registered custom annotation function with ``gm``, in order."""
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """
        Priority:
            1. is one of use_per_channel_weight_quant_ops
            2. quant config

        Returns None for string targets and for ops that are in neither
        ``use_per_channel_weight_quant_ops`` nor ``quant_ops``; the latter
        case also prints a message.
        """
        # String targets are never looked up in the op sets.
        if isinstance(op, str):
            return None

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        if op in self.quant_ops:
            return self.quant_config

        print(f"No quant config is implemented for op, {op}")
        return None

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        """Add ``ops`` to (enable=True) or remove them from the per-channel op set."""
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        """
        Set the custom annotation callables run after the built-in annotation.

        NOTE: despite the name, this replaces any previously registered
        sequence rather than appending to it.
        """
        self.custom_quant_annotations = custom_quant_annotations

    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        """
        Set the node names that ``_annotate`` must skip.

        NOTE: despite the name, this replaces the previous set of names.
        """
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        """
        Remove ``ops`` from this instance's ``quant_ops``.

        Raises:
            KeyError: if an op is not currently in ``quant_ops``.
        """
        for op in ops:
            self.quant_ops.remove(op)

    def annotate(self, model: GraphModule) -> GraphModule:
        """Run the built-in annotation, then the custom ones, and return ``model``."""
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        """Return the class-level ``SUPPORTED_OPS`` set (the shared object, not a copy)."""
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
        """
        Select the quantization configuration for this quantizer.

        Args:
            quant_dtype: activation/weight bit-width selection.
            is_qat: whether to look up the QAT entry instead of the PTQ one.
            act_observer: when truthy, passed as the single positional
                argument to the config factory; otherwise the factory is
                called with no arguments.

        Raises:
            RuntimeError: if ``(quant_dtype, is_qat)`` is not a key of
                ``quant_config_dict``. The quantizer's state is left
                untouched in that case.
        """
        key = (quant_dtype, is_qat)
        # Validate before touching any attribute so a rejected request cannot
        # leave quant_dtype / is_qat out of sync with the stored configs.
        if key not in quant_config_dict:
            raise RuntimeError(
                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
            )

        quant_config_fuc, per_channel_quant_config = quant_config_dict[key]
        quant_config = (
            quant_config_fuc(act_observer) if act_observer else quant_config_fuc()
        )

        # Commit everything together once all pieces were built successfully.
        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        self.per_channel_quant_config = per_channel_quant_config
        self.quant_config = quant_config

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        """Enable or disable per-channel weight quantization for aten conv1d/conv2d."""
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        """Enable or disable per-channel weight quantization for aten linear."""
        linear_ops = {
            torch.ops.aten.linear.default,
        }
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """
        Apply the pre-annotation graph passes and return the resulting module.

        The passes run in this order, each on the previous pass's
        ``graph_module``: ReduceDynamicRange, RecomposePixelUnshuffle
        (with ``quantization_capture=True``),
        DecomposeScaledDotProductAttention, DecomposeSilu, DecomposeEinsum,
        ReplaceInfBuffer.
        """
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module
        return model

    def validate(self, model: GraphModule) -> None:
        """No validation is performed."""
        pass