# mypy: allow-untyped-defs
import numpy as np

import torch
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT


class LinearAPoT(WeightedQuantizedModule):
    r"""
    A quantized linear module with quantized tensors as inputs and outputs,
    to support APoT (additive powers-of-two) quantization.
    We adopt the same interface as `torch.nn.Linear`, see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`~torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        alpha: `alpha` qparam of output Quantized Tensor, type: Tensor
        gamma: `gamma` qparam of output Quantized Tensor, type: Tensor
        quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor
        level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor
        weight: APoT quantized tensor derived from weight2quantize
        weight_transposed: transposed weight tensor, used in the linear transformation calculation (y = x * A^T + b)
    """

    def __init__(self, weight2quantize: torch.Tensor, b: int, k: int):
        assert weight2quantize.dim() == 2
        assert b % k == 0

        super().__init__()

        self.b = b  # total bitwidth of the quantized representation
        self.k = k  # base bitwidth; each APoT term uses k bits
        self.n = self.b // self.k  # number of additive power-of-two terms

        # observe the float weight to derive the APoT qparams
        observer = APoTObserver(b=self.b, k=self.k)
        observer(weight2quantize)

        (
            self.alpha,
            self.gamma,
            self.quantization_levels,
            self.level_indices,
        ) = observer.calculate_qparams(signed=False)

        quantized_weight = quantize_APoT(
            weight2quantize,
            self.alpha,
            self.gamma,
            self.quantization_levels,
            self.level_indices,
        )
        self.weight = quantized_weight.data
        self.weight_transposed = torch.transpose(self.weight, 0, 1)

    def decompose_APoT(self, x):
        r"""
        Decompose the binary representation of an APoT value into a list of k-sized blocks
        Args:
            x (str): binary representation (as produced by ``bin``) of an APoT quantized value
        """
        # remove "0b" prefix from binary representation
        x = x[2:]

        # split the digit string into blocks of k digits, left to right
        blocks = []
        while x:
            blocks.append(x[0 : self.k])
            x = x[self.k :]

        return blocks

    def bitshift_mul(self, weight_val, r):
        r"""
        Compute the product weight_val * r using the bitshifting
        method discussed in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf
        Args:
            weight_val: list of k-sized binary digit blocks (see ``decompose_APoT``) representing an APoT quantized weight value
            r: int representing a uniformly quantized activation value
        """
        product = 0

        idx = len(weight_val) - 1
        place = 0

        # walk the blocks from least to most significant
        while idx >= 0:
            block = weight_val[idx]

            # reverse digits in block so they are traversed least-significant first
            block = block[::-1]

            curr_block_result = 0

            for ele in block:
                if int(ele):
                    # a set bit at position `place` contributes r << place
                    curr_block_result += r << place
                place += 1

            idx -= 1
            product += curr_block_result

        return product
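
    # Worked example (illustrative; an assumption of this edit, not part of
    # the original module): with k = 2, the APoT code 0b0110 (decimal 6)
    # decomposes into the blocks ["01", "10"]. ``bitshift_mul`` visits the
    # digits from least to most significant and accumulates ``r << place``
    # for every set bit, so
    #
    #     self.bitshift_mul(["01", "10"], 3)
    #
    # returns (3 << 1) + (3 << 2) == 18 == 6 * 3, i.e. the product is formed
    # entirely from shifts and adds, with no integer multiply.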

    def matmul(self, decomposed_weight, activation):
        r"""
        Perform matrix multiplication between decomposed_weight and
        activation by calling the bitshift_mul function for each value
        Args:
            decomposed_weight (np.ndarray): APoT quantized weight, with each element decomposed into binary digit blocks
            activation (Tensor): uniformly quantized activation
        """
        rows1 = activation.size(dim=0)
        cols1 = activation.size(dim=1)

        rows2 = decomposed_weight.shape[0]
        cols2 = decomposed_weight.shape[1]

        # inner dimensions must agree for a valid matrix product
        assert cols1 == rows2

        result = torch.zeros(rows1, cols2)

        # compute matrix multiplication with bitshifts
        for i in range(rows1):
            for j in range(cols2):
                for k in range(rows2):
                    weight_val = decomposed_weight[k][j]
                    r = int(activation[i][k])

                    product = self.bitshift_mul(weight_val, r)

                    result[i][j] += product

        return result

    def forward(self, activation: torch.Tensor) -> torch.FloatTensor:
        r"""
        Multiply APoT quantized weight and uniformly quantized activation (dtype: quint8)
        with bitshifting instead of matrix multiplication.
        Result has dtype torch.float32
        Args:
            activation (Tensor): uniformly quantized activation tensor
        """
        assert activation.dim() == 2

        weight_rows = self.weight_transposed.size()[0]
        weight_cols = self.weight_transposed.size()[1]

        # decompose each APoT weight value into its binary digit blocks
        decomposed_weight: np.ndarray = np.empty(
            shape=(weight_rows, weight_cols), dtype=object
        )
        for row in range(weight_rows):
            for col in range(weight_cols):
                decomposed_weight[row][col] = self.decompose_APoT(
                    bin(self.weight_transposed[row][col])
                )

        result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor)

        return result

    @classmethod
    def from_reference(  # type: ignore[override]
        cls,
        ref_qlinear,
        alpha: torch.Tensor,
        gamma: torch.Tensor,
        quantization_levels: torch.Tensor,
        level_indices: torch.Tensor,
    ):
        raise NotImplementedError
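

# Minimal usage sketch (an illustrative assumption of this edit, not part of
# the original module): quantize a random 4x4 float weight with b=4, k=2 and
# run a forward pass on integer-valued activations.
if __name__ == "__main__":
    float_weight = torch.rand(4, 4)
    linear = LinearAPoT(float_weight, b=4, k=2)

    # activations are assumed to already hold uniformly quantized integer values
    activation = torch.randint(0, 16, (2, 4), dtype=torch.float)
    out = linear(activation)
    print(out.shape)  # torch.Size([2, 4])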