xref: /aosp_15_r20/external/pytorch/benchmarks/gpt_fast/quantize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# flake8: noqa: E266, C417, B950
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
4*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker##### Quantization Primitives ######
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerdef dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
11*da0073e9SAndroid Build Coastguard Worker    # assumes symmetric quantization
12*da0073e9SAndroid Build Coastguard Worker    # assumes axis == 0
13*da0073e9SAndroid Build Coastguard Worker    # assumes dense memory format
14*da0073e9SAndroid Build Coastguard Worker    # TODO(future): relax ^ as needed
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker    # default setup for affine quantization of activations
17*da0073e9SAndroid Build Coastguard Worker    eps = torch.finfo(torch.float32).eps
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker    # get min and max
20*da0073e9SAndroid Build Coastguard Worker    min_val, max_val = torch.aminmax(x, dim=1)
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker    # calculate scales and zero_points based on min and max
23*da0073e9SAndroid Build Coastguard Worker    # reference: https://fburl.com/code/srbiybme
24*da0073e9SAndroid Build Coastguard Worker    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
25*da0073e9SAndroid Build Coastguard Worker    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
26*da0073e9SAndroid Build Coastguard Worker    device = min_val_neg.device
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker    # reference: https://fburl.com/code/4wll53rk
29*da0073e9SAndroid Build Coastguard Worker    max_val_pos = torch.max(-min_val_neg, max_val_pos)
30*da0073e9SAndroid Build Coastguard Worker    scales = max_val_pos / (float(quant_max - quant_min) / 2)
31*da0073e9SAndroid Build Coastguard Worker    # ensure scales is the same dtype as the original tensor
32*da0073e9SAndroid Build Coastguard Worker    scales = torch.clamp(scales, min=eps).to(x.dtype)
33*da0073e9SAndroid Build Coastguard Worker    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker    # quantize based on qmin/qmax/scales/zp
36*da0073e9SAndroid Build Coastguard Worker    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
37*da0073e9SAndroid Build Coastguard Worker    x_div = x / scales.unsqueeze(-1)
38*da0073e9SAndroid Build Coastguard Worker    x_round = torch.round(x_div)
39*da0073e9SAndroid Build Coastguard Worker    x_zp = x_round + zero_points.unsqueeze(-1)
40*da0073e9SAndroid Build Coastguard Worker    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker    return quant, scales, zero_points
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker##### Weight-only int8 per-channel quantized code ######
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap each ``nn.Linear`` below ``module`` for ``WeightOnlyInt8Linear``.

    The swap is done in place on the parent, via ``setattr``. Other children
    are descended into, and ``module`` itself is never replaced.

    Only ``in_features`` and ``out_features`` are forwarded to the
    replacement. The original layer's weight values and its bias (if any)
    are therefore not carried over.
    """
    for child_name, sub in module.named_children():
        if not isinstance(sub, nn.Linear):
            # Not a linear layer: look for linear layers further down.
            replace_linear_weight_only_int8_per_channel(sub)
            continue
        replacement = WeightOnlyInt8Linear(sub.in_features, sub.out_features)
        setattr(module, child_name, replacement)
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker
class WeightOnlyInt8QuantHandler:
    """Drive weight-only int8 quantization of a model's ``nn.Linear`` layers.

    Two steps are provided:

    - ``create_quantized_state_dict`` builds a state dict in which every
      linear weight is int8 with per-row scales.
    - ``convert_for_runtime`` swaps the model's linear layers for
      ``WeightOnlyInt8Linear`` so that such a state dict can be loaded.
    """

    def __init__(self, mod):
        # The model to quantize; convert_for_runtime mutates it in place.
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return ``self.mod``'s state dict with every ``nn.Linear`` weight quantized.

        For each linear layer:

        - its ``weight`` entry is replaced by the int8 per-row quantization
          of the weight;
        - a new ``scales`` entry, in the layer's original weight dtype, is
          added.

        Both tensors are moved to CPU. All other entries are left unchanged.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, nn.Linear):
                continue
            # named_modules() reports the root module with an empty name. Its
            # state-dict keys have no prefix ("weight", not ".weight").
            prefix = f"{fqn}." if fqn else ""
            # Quantize in float32 regardless of the stored weight dtype.
            int8_weight, scales, _ = dynamically_quantize_per_channel(
                mod.weight.float(), -128, 127, torch.int8
            )
            cur_state_dict[f"{prefix}weight"] = int8_weight.to("cpu")
            cur_state_dict[f"{prefix}scales"] = (
                scales.to(mod.weight.dtype).to("cpu")
            )

        return cur_state_dict

    def convert_for_runtime(self):
        """Swap the model's linear layers in place and return the model.

        Delegates to ``replace_linear_weight_only_int8_per_channel``, which
        mutates ``self.mod`` rather than returning a copy.
        """
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
class WeightOnlyInt8Linear(torch.nn.Module):
    """Bias-free linear layer storing an int8 weight plus one scale per output feature.

    ``forward`` works in three steps:

    1. cast the int8 weight to the input's dtype;
    2. apply ``F.linear`` with no bias;
    3. multiply each output feature by its scale.

    The buffers start as placeholders: the weight is uninitialized and the
    scales are all ones. Presumably real values are loaded from a quantized
    state dict before use.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor
    scales: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Allocate the int8 ``weight`` and the ``scales`` buffers.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample; one scale per feature.
            bias: accepted for signature compatibility with ``nn.Linear``
                but ignored. This layer never has a bias.
            device: device for both buffers. ``None`` uses the current
                default device.
            dtype: dtype of ``scales``; ``None`` means ``torch.bfloat16``.
                The weight is always ``torch.int8``.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        scales_dtype = torch.bfloat16 if dtype is None else dtype
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=scales_dtype, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dequantize lazily: cast the int8 weight to the activation dtype,
        # then apply the per-output-feature scale to the product.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
107