# mypy: allow-untyped-defs
import numpy as np

import torch
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT


class LinearAPoT(WeightedQuantizedModule):
    r"""
    A quantized linear module with quantized tensors as inputs and outputs,
    supporting APoT (additive powers-of-two) quantization.
    We adopt the same interface as `torch.nn.Linear`; see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`~torch.nn.Linear`, attributes are initialized at
    module creation time and may be overwritten later.

    Attributes:
        alpha: `alpha` qparam of output Quantized Tensor, type: Tensor
        gamma: `gamma` qparam of output Quantized Tensor, type: Tensor
        quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor
        level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor
        weight: APoT quantized tensor derived from weight2quantize
        weight_transposed: transposed weight tensor, used in the linear transformation calculation (y = x * A^T + b)
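
    Example (a minimal usage sketch; the weight and activation values below
    are arbitrary placeholders, not calibrated data)::

        float_weight = torch.rand(4, 4)
        qlinear = LinearAPoT(float_weight, b=4, k=2)
        activation = torch.randint(0, 16, (2, 4)).float()
        output = qlinear(activation)  # shape (2, 4), dtype torch.float32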
    """

    def __init__(self, weight2quantize: torch.Tensor, b: int, k: int):
        assert weight2quantize.dim() == 2
        # b must decompose into a whole number of k-bit blocks
        assert b % k == 0

        super().__init__()

        self.b = b
        self.k = k
        # number of k-bit blocks that make up each b-bit value
        self.n = self.b // self.k

        observer = APoTObserver(b=self.b, k=self.k)

        # record statistics of the weight tensor
        observer(weight2quantize)

        (
            self.alpha,
            self.gamma,
            self.quantization_levels,
            self.level_indices,
        ) = observer.calculate_qparams(signed=False)

        quantized_weight = quantize_APoT(
            weight2quantize,
            self.alpha,
            self.gamma,
            self.quantization_levels,
            self.level_indices,
        )
        self.weight = quantized_weight.data
        self.weight_transposed = torch.transpose(self.weight, 0, 1)

    def decompose_APoT(self, x):
        r"""
        Decompose the binary representation of an APoT value into a list of k-sized blocks
        Args:
            x (str): binary string (as returned by ``bin``) of an APoT quantized value
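
        Example (a minimal sketch, assuming ``self.k == 2``)::

            >>> m = LinearAPoT(torch.rand(2, 2), 4, 2)
            >>> m.decompose_APoT(bin(10))  # bin(10) == "0b1010"
            ['10', '10']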
        """
        # remove "0b" prefix from binary representation
        x = x[2:]

        # initialize list of blocks
        blocks = []

        # split left to right into k-sized chunks; bin() drops leading zeros,
        # so the final chunk may be shorter than k
        while x:
            blocks.append(x[0 : self.k])
            x = x[self.k :]

        return blocks

    def bitshift_mul(self, weight_val, r):
        r"""
        Compute the product weight_val * r using the bitshifting
        method discussed in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf
        Args:
            weight_val: list of binary digit blocks representing an APoT quantized weight value
            r: int representing a uniformly quantized activation value
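
        Example (a minimal sketch; ["10", "10"] is the k=2 block decomposition
        of 0b1010 == 10)::

            >>> m = LinearAPoT(torch.rand(2, 2), 4, 2)
            >>> m.bitshift_mul(["10", "10"], 3)  # 10 * 3
            30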
        """
        product = 0

        # walk blocks from the least significant (rightmost) block upward
        idx = len(weight_val) - 1
        # bit position within the full value; carries across blocks
        place = 0

        while idx >= 0:
            block = weight_val[idx]

            # reverse digits in block so iteration runs from LSB to MSB
            block = block[::-1]

            curr_block_result = 0

            for ele in block:
                if int(ele):
                    curr_block_result += r << place
                place += 1

            idx -= 1
            product += curr_block_result

        return product

    def matmul(self, decomposed_weight, activation):
        r"""
        Perform matrix multiplication between decomposed_weight and
        activation by calling the bitshift_mul function for each value
        Args:
            decomposed_weight (np.ndarray): APoT quantized weight decomposed into binary blocks
            activation (Tensor): uniformly quantized activation
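
        Example (a minimal sketch with a 1x1 decomposed weight holding the
        blocks for 0b1010 == 10)::

            >>> m = LinearAPoT(torch.rand(2, 2), 4, 2)
            >>> w = np.empty((1, 1), dtype=object)
            >>> w[0][0] = ["10", "10"]
            >>> m.matmul(w, torch.tensor([[3.0]]))  # 10 * 3
            tensor([[30.]])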
        """
        rows1 = activation.size(dim=0)
        cols1 = activation.size(dim=1)

        rows2 = decomposed_weight.shape[0]
        cols2 = decomposed_weight.shape[1]

        # inner dimensions must agree: cols1 == rows2
        result = torch.zeros(rows1, cols2)

        # compute matrix multiplication with bitshifts
        for i in range(rows1):
            for j in range(cols2):
                for k in range(rows2):
                    weight_val = decomposed_weight[k][j]
                    r = int(activation[i][k])

                    product = self.bitshift_mul(weight_val, r)

                    result[i][j] += product

        return result

    def forward(self, activation: torch.Tensor) -> torch.FloatTensor:
        r"""
        Multiply the APoT quantized weight and the uniformly quantized activation (dtype: quint8)
        with bitshifting instead of matrix multiplication.
        The result has dtype torch.float32.
        Args:
            activation (Tensor): uniformly quantized activation tensor
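
        Example (a minimal sketch; the activation values are arbitrary
        placeholders, not calibrated data)::

            >>> m = LinearAPoT(torch.rand(2, 2), 4, 2)
            >>> m(torch.tensor([[0.0, 1.0], [2.0, 3.0]])).dtype
            torch.float32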
        """
        assert activation.dim() == 2

        weight_rows = self.weight_transposed.size()[0]
        weight_cols = self.weight_transposed.size()[1]

        # decompose the binary representation of each weight value into k-sized blocks
        decomposed_weight: np.ndarray = np.empty(
            shape=(weight_rows, weight_cols), dtype=object
        )
        for row in range(weight_rows):
            for col in range(weight_cols):
                decomposed_weight[row][col] = self.decompose_APoT(
                    bin(self.weight_transposed[row][col])
                )

        result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor)

        return result

    @classmethod
    def from_reference(  # type: ignore[override]
        cls,
        ref_qlinear,
        alpha: torch.Tensor,
        gamma: torch.Tensor,
        quantization_levels: torch.Tensor,
        level_indices: torch.Tensor,
    ):
        raise NotImplementedError
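

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): quantize a random 4x4 weight with b=4, k=2 and run a single
    # forward pass on placeholder activations.
    torch.manual_seed(0)
    float_weight = torch.rand(4, 4)
    qlinear = LinearAPoT(float_weight, b=4, k=2)
    activation = torch.randint(0, 16, (2, 4)).float()
    output = qlinear(activation)
    print(output.shape, output.dtype)  # torch.Size([2, 4]) torch.float32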