# flake8: noqa: E266, C417, B950
import torch
import torch.nn as nn
import torch.nn.functional as F


##### Quantization Primitives ######


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
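

# Illustrative helper (a minimal sketch, not part of the quantized path below);
# handy for sanity-checking the roundtrip error of dynamically_quantize_per_channel.
def dequantize_per_channel(quant, scales, zero_points):
    # x_hat = (q - zp) * scale, broadcast along axis == 0
    return (quant.to(scales.dtype) - zero_points.unsqueeze(-1)) * scales.unsqueeze(-1)


# Example:
#     w = torch.randn(16, 32)
#     q, s, zp = dynamically_quantize_per_channel(w, -128, 127, torch.int8)
#     err = (w - dequantize_per_channel(q, s, zp)).abs().max()
#     # err is bounded by the largest per-channel scale (one quantization step)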


##### Weight-only int8 per-channel quantized code ######


def replace_linear_weight_only_int8_per_channel(module):
    # recursively swap every nn.Linear for a WeightOnlyInt8Linear of the same shape
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(child.in_features, child.out_features),
            )
        else:
            replace_linear_weight_only_int8_per_channel(child)


class WeightOnlyInt8QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        # build a state_dict in which every nn.Linear weight is replaced by its
        # int8 per-channel quantized version plus the matching scales
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(), -128, 127, torch.int8
                )
                cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu")
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        # swap the modules in place so the quantized state_dict can be loaded
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        # bias and factory_kwargs are accepted for nn.Linear API parity but unused
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # upcast the int8 weight to the activation dtype, matmul, then apply the
        # per-output-channel scales to complete the dequantization
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
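

# Example flow (illustrative sketch; `model` stands in for any module whose nn.Linear
# layers have no bias, since WeightOnlyInt8Linear does not register one):
#     handler = WeightOnlyInt8QuantHandler(model)
#     quantized_state_dict = handler.create_quantized_state_dict()
#     model = handler.convert_for_runtime()
#     model.load_state_dict(quantized_state_dict)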