1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/runtime/kernel/kernel_includes.h>
10 #include <algorithm>
11 #include <cinttypes>
12 #include <cmath>
13
14 namespace torch {
15 namespace executor {
16 namespace native {
17
18 using Tensor = exec_aten::Tensor;
19 using Scalar = exec_aten::Scalar;
20 using ScalarType = exec_aten::ScalarType;
21
22 namespace {
23
template <typename INPUT_T, typename OUTPUT_T>
OUTPUT_T quantize_val(
    double scale,
    int64_t zero_point,
    INPUT_T value,
    int64_t quant_min,
    int64_t quant_max) {
  // Q(x) = clamp(round(x / scale) + zero_point, quant_min, quant_max).
  // The scale is narrowed to float before inversion to maintain numerical
  // consistency with the float-based reference implementation.
  const float inv_scale = 1.0f / static_cast<float>(scale);
  // std::nearbyint rounds using the current rounding mode (round-to-nearest-
  // even by default), matching typical quantization reference behavior.
  const int64_t rounded =
      static_cast<int64_t>(zero_point + std::nearbyint(inv_scale * value));
  const int64_t clamped =
      std::min<int64_t>(std::max<int64_t>(rounded, quant_min), quant_max);
  return static_cast<OUTPUT_T>(clamped);
}
38
template <typename INPUT_T, typename OUTPUT_T>
OUTPUT_T dequantize_val(double scale, int64_t zero_point, INPUT_T value) {
  // Dq(q) = (q - zero_point) * scale, implicitly converted to OUTPUT_T on
  // return. Integer inputs promote to int64_t in the subtraction, then the
  // product is computed in double precision.
  const auto shifted = value - zero_point;
  return shifted * scale;
}
43
44 /**
45 * Perform element wise addition of the input tensors into out.
46 * Should be numerically equivalent to Dq -> fp add -> Q
47 */
48 template <class CTYPE>
add_tensors(const Tensor & a,float a_scale,int32_t a_zero_point,const Tensor & b,float b_scale,int32_t b_zero_point,Tensor & out,float out_scale,int32_t out_zero_point,int64_t out_quant_min,int64_t out_quant_max)49 void add_tensors(
50 const Tensor& a,
51 float a_scale,
52 int32_t a_zero_point,
53 const Tensor& b,
54 float b_scale,
55 int32_t b_zero_point,
56 Tensor& out,
57 float out_scale,
58 int32_t out_zero_point,
59 int64_t out_quant_min,
60 int64_t out_quant_max) {
61 const size_t n = a.numel();
62
63 const auto data_a = a.const_data_ptr<CTYPE>();
64 const auto data_b = b.const_data_ptr<CTYPE>();
65 auto data_out = out.mutable_data_ptr<CTYPE>();
66
67 for (size_t i = 0; i < n; ++i) {
68 // Dq -> fp add -> Q. Can be optimized further
69 const auto dqa =
70 dequantize_val<CTYPE, float>(a_scale, a_zero_point, data_a[i]);
71 const auto dqb =
72 dequantize_val<CTYPE, float>(b_scale, b_zero_point, data_b[i]);
73 const auto accumulate = dqa + dqb;
74
75 data_out[i] = quantize_val<float, CTYPE>(
76 out_scale, out_zero_point, accumulate, out_quant_min, out_quant_max);
77 }
78 }
79
80 } // namespace
81
82 /**
83 * Perform element wise addition of the input tensors into out. Should be
84 * numerically equivalent to Dq -> fp add -> Q
85 *
86 * PREREQ: a and b should be the same shape, quant_min and max should be in
87 * range [0,255]. a and b and out should be the same dtype.
88 */
quantized_add_out(const Tensor & a,double a_scale_d,int64_t a_zero_point_l,int64_t a_quant_min,int64_t a_quant_max,const Tensor & b,double b_scale_d,int64_t b_zero_point_l,int64_t b_quant_min,int64_t b_quant_max,double out_scale_d,int64_t out_zero_point_l,int64_t out_quant_min,int64_t out_quant_max,Tensor & out)89 Tensor& quantized_add_out(
90 const Tensor& a,
91 double a_scale_d,
92 int64_t a_zero_point_l,
93 int64_t a_quant_min,
94 int64_t a_quant_max,
95 const Tensor& b,
96 double b_scale_d,
97 int64_t b_zero_point_l,
98 int64_t b_quant_min,
99 int64_t b_quant_max,
100 double out_scale_d,
101 int64_t out_zero_point_l,
102 int64_t out_quant_min,
103 int64_t out_quant_max,
104 Tensor& out) {
105 ET_CHECK_SAME_SHAPE_AND_DTYPE3(a, b, out);
106 ET_CHECK_MSG(
107 a_quant_min >= 0 && a_quant_max <= 255 && a_quant_min <= a_quant_max,
108 "invalid quant_min: %" PRId64 " or quant_max: %" PRId64
109 " for input tensor a. Min should be <= max and both should be in bounds [0,255]",
110 a_quant_min,
111 a_quant_max);
112 ET_CHECK_MSG(
113 b_quant_min >= 0 && b_quant_max <= 255 && b_quant_min <= b_quant_max,
114 "invalid quant_min: %" PRId64 " or quant_max: %" PRId64
115 " for input tensor b. Min should be <= max and both should be in bounds [0,255]",
116 b_quant_min,
117 b_quant_max);
118 ET_CHECK_MSG(
119 out_quant_min >= 0 && out_quant_max <= 255 &&
120 out_quant_min <= out_quant_max,
121 "invalid quant_min: %" PRId64 " or quant_max: %" PRId64
122 " for output tensor. Min should be <= max and both should be in bounds [0,255]",
123 out_quant_min,
124 out_quant_max);
125
126 // downsize to maintain numerical consistency with fbgemm
127 float a_scale = static_cast<float>(a_scale_d);
128 float b_scale = static_cast<float>(b_scale_d);
129 float out_scale = static_cast<float>(out_scale_d);
130
131 int32_t a_zero_point = static_cast<int32_t>(a_zero_point_l);
132 int32_t b_zero_point = static_cast<int32_t>(b_zero_point_l);
133 int32_t out_zero_point = static_cast<int32_t>(out_zero_point_l);
134
135 #define ADD_TENSORS(ctype, dtype) \
136 case ScalarType::dtype: \
137 add_tensors<ctype>( \
138 a, \
139 a_scale, \
140 a_zero_point, \
141 b, \
142 b_scale, \
143 b_zero_point, \
144 out, \
145 out_scale, \
146 out_zero_point, \
147 out_quant_min, \
148 out_quant_max); \
149 break;
150
151 switch (a.scalar_type()) {
152 ET_FORALL_INT_TYPES(ADD_TENSORS)
153 default:
154 ET_CHECK_MSG(
155 false,
156 "Unhandled dtype %" PRId8,
157 static_cast<int8_t>(a.scalar_type()));
158 }
159
160 #undef ADD_TENSORS
161
162 return out;
163 }
164
quantized_add_out(KernelRuntimeContext & context,const Tensor & a,double a_scale,int64_t a_zero_point,int64_t a_quant_min,int64_t a_quant_max,const Tensor & b,double b_scale,int64_t b_zero_point,int64_t b_quant_min,int64_t b_quant_max,double out_scale,int64_t out_zero_point,int64_t out_quant_min,int64_t out_quant_max,Tensor & out)165 Tensor& quantized_add_out(
166 KernelRuntimeContext& context,
167 const Tensor& a,
168 double a_scale,
169 int64_t a_zero_point,
170 int64_t a_quant_min,
171 int64_t a_quant_max,
172 const Tensor& b,
173 double b_scale,
174 int64_t b_zero_point,
175 int64_t b_quant_min,
176 int64_t b_quant_max,
177 double out_scale,
178 int64_t out_zero_point,
179 int64_t out_quant_min,
180 int64_t out_quant_max,
181 Tensor& out) {
182 // TODO(larryliu): Add a context arg to the real op function and remove this
183 // wrapper
184 (void)context;
185 return quantized_add_out(
186 a,
187 a_scale,
188 a_zero_point,
189 a_quant_min,
190 a_quant_max,
191 b,
192 b_scale,
193 b_zero_point,
194 b_quant_min,
195 b_quant_max,
196 out_scale,
197 out_zero_point,
198 out_quant_min,
199 out_quant_max,
200 out);
201 }
202
203 } // namespace native
204 } // namespace executor
205 } // namespace torch
206