/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <limits>

/**
 * For an input tensor, use the scale and zero_point arguments to quantize it.
 */
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Asserts that the parameters are valid.
 */
void check_quantize_per_tensor_args(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ensure the input is a floating point tensor
  ET_CHECK_MSG(
      torch::executor::isFloatingType(input.scalar_type()),
      "input.scalar_type() %" PRId8 " is not floating type",
      static_cast<int8_t>(input.scalar_type()));

  int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0;
  ScalarType out_dtype = out.scalar_type();
  ET_CHECK_MSG(
      out_dtype == dtype,
      "out.scalar_type() %" PRId8 " does not match dtype argument %" PRId8,
      static_cast<int8_t>(out_dtype),
      static_cast<int8_t>(dtype));
  if (out_dtype == ScalarType::Byte) {
    quant_min_lower_bound =
        static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
    quant_max_upper_bound =
        static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
  } else if (dtype == ScalarType::Char) {
    quant_min_lower_bound =
        static_cast<int32_t>(std::numeric_limits<int8_t>::min());
    quant_max_upper_bound =
        static_cast<int32_t>(std::numeric_limits<int8_t>::max());
  } else if (dtype == ScalarType::Bits16 || dtype == ScalarType::UInt16) {
    quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
    quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
  } else if (dtype == ScalarType::Short) {
    quant_min_lower_bound = std::numeric_limits<int16_t>::min();
    quant_max_upper_bound = std::numeric_limits<int16_t>::max();
  } else if (dtype == ScalarType::Int) {
    quant_min_lower_bound = std::numeric_limits<int32_t>::min();
    quant_max_upper_bound = std::numeric_limits<int32_t>::max();
  } else {
    ET_CHECK_MSG(
        false, "Unsupported dtype: %" PRId8, static_cast<int8_t>(out_dtype));
  }
  ET_CHECK_MSG(
      quant_min >= quant_min_lower_bound,
      "quant_min out of bound for dtype, expected quant_min_lower_bound: %" PRId32
      " actual quant_min: %" PRId64,
      quant_min_lower_bound,
      quant_min);

  ET_CHECK_MSG(
      quant_max <= quant_max_upper_bound,
      "quant_max out of bound for dtype, expected quant_max_upper_bound: %" PRId32
      " actual quant_max: %" PRId64,
      quant_max_upper_bound,
      quant_max);
}

} // namespace

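// Quantizes a single value: qvalue = zero_point + round(value / scale), then
// clamps the result to [quant_min, quant_max] before casting to the output
// integer type T.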
template <typename T, typename K>
T quantize_val(
    double scale,
    int64_t zero_point,
    K value,
    int64_t quant_min,
    int64_t quant_max) {
  int64_t qvalue;
  float inv_scale = 1.0f / static_cast<float>(scale);
  qvalue = static_cast<int64_t>(
      static_cast<int32_t>(zero_point) +
      std::nearbyint(static_cast<float>(inv_scale * value)));

  qvalue = std::max<int64_t>(qvalue, quant_min);
  qvalue = std::min<int64_t>(qvalue, quant_max);
  return static_cast<T>(qvalue);
}

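// Quantizes every element of `input` with a single (scale, zero_point) pair
// and writes the result to `out`, which is resized to match `input`'s shape.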
Tensor& quantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_tensor_out");

  check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

  // calculate the quantized input
#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype)                          \
  case ScalarType::out_dtype: {                                                \
    /* Hoist these function calls out of our inner loop because they might not \
     * get inlined without LTO, particularly in ATen mode. */                  \
    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>();                    \
    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>();             \
    const auto input_numel = input.numel();                                    \
    for (size_t i = 0; i < input_numel; i++) {                                 \
      IN_CTYPE value = input_data_ptr[i];                                      \
      out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>(                     \
          scale, zero_point, value, quant_min, quant_max);                     \
    }                                                                           \
  } break;
#define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype)         \
  case ScalarType::in_dtype:                             \
    switch (out.scalar_type()) {                         \
      ET_FORALL_INT_TYPES_WITH(IN_CTYPE, QUANTIZE_IMPL); \
      QUANTIZE_IMPL(IN_CTYPE, uint16_t, Bits16)          \
      QUANTIZE_IMPL(IN_CTYPE, uint16_t, UInt16)          \
      default:                                           \
        ET_CHECK_MSG(                                    \
            false,                                       \
            "Unhandled output dtype %" PRId8,            \
            static_cast<int8_t>(out.scalar_type()));     \
    }                                                    \
    break;

  switch (input.scalar_type()) {
    ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }
#undef CALCULATE_FLOAT_TYPE
#undef QUANTIZE_IMPL
  return out;
}

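// Variant that takes scale and zero_point as single-element tensors (Double
// and Long respectively) instead of as scalar arguments.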
Tensor& quantize_per_tensor_tensor_args_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Temporary change to allow non-fatal failure for now, to unblock some
  // expected-failure tests that are dying instead of failing. Will revisit
  // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal
  // failures.
  if (scale.scalar_type() != ScalarType::Double) {
    context.fail(torch::executor::Error::InvalidArgument);
    return out;
  }
  ET_CHECK_MSG(
      scale.scalar_type() == ScalarType::Double,
      "Expected scale to be a Double tensor, received: %" PRId8,
      static_cast<int8_t>(scale.scalar_type()));
  ET_CHECK_MSG(
      zero_point.scalar_type() == ScalarType::Long,
      "Expected zero_point to be a Long tensor, received: %" PRId8,
      static_cast<int8_t>(zero_point.scalar_type()));
  ET_CHECK_MSG(
      scale.numel() == 1,
      "Expected scale to have exactly one element, received: %zd",
      ssize_t(scale.numel()));
  ET_CHECK_MSG(
      zero_point.numel() == 1,
      "Expected zero_point to have exactly one element, received: %zd",
      ssize_t(zero_point.numel()));

  quantize_per_tensor_out(
      input,
      scale.const_data_ptr<double>()[0],
      zero_point.const_data_ptr<int64_t>()[0],
      quant_min,
      quant_max,
      dtype,
      out);
  return out;
}

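// Context-free overload: creates a local KernelRuntimeContext, forwards to the
// variant above, and turns any recorded failure into a fatal ET_CHECK.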
Tensor& quantize_per_tensor_tensor_args_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  auto context = executorch::runtime::KernelRuntimeContext();
  auto& res = quantize_per_tensor_tensor_args_out(
      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
  ET_CHECK(context.failure_state() == Error::Ok);
  return res;
}

Tensor& quantize_per_tensor_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  return quantize_per_tensor_out(
      input, scale, zero_point, quant_min, quant_max, dtype, out);
}

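// Quantizes `input` along `axis`: the slice at index c of that dimension is
// quantized with scale[c] and zero_point[c].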
Tensor& quantize_per_channel_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // normalize axis
  ET_CHECK_MSG(
      tensor_has_dim(input, axis),
      "axis %zd is not legal; it should satisfy -input.dim() <= axis < input.dim() %zd",
      ssize_t(axis),
      ssize_t(input.dim()));

  if (axis < 0) {
    axis += nonzero_dim(input);
  }

  ET_CHECK_MSG(
      scale.scalar_type() == ScalarType::Double,
      "scale.scalar_type() %" PRId8 " is not double type",
      static_cast<int8_t>(scale.scalar_type()));

  ET_CHECK_MSG(
      scale.numel() == input.size(axis),
      "scale.numel() %zd != input.size(axis) %zd",
      scale.numel(),
      input.size(axis));

  ET_CHECK_MSG(
      zero_point.scalar_type() == ScalarType::Long,
      "zero_point.scalar_type() %" PRId8 " is not integer type",
      static_cast<int8_t>(zero_point.scalar_type()));

  ET_CHECK_MSG(
      zero_point.numel() == input.size(axis),
      "zero_point.numel() %zd != input.size(axis) %zd",
      zero_point.numel(),
      input.size(axis));

  check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

  // Build the list of all dimensions except `axis`; apply_over_dim_list uses
  // it below to visit every element belonging to a given channel.
  int64_t dims[kTensorDimensionLimit];
  for (int64_t i = 0; i < input.dim() - 1; i++) {
    if (i < axis) {
      dims[i] = i;
    } else {
      dims[i] = i + 1;
    }
  }
  const double* scale_data = scale.const_data_ptr<double>();
  const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();

  exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
      exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};

  // Actual quantization logic
  // input, out are the input and output tensors
  // channel_ix is the index along the axis dimension. 0 <= channel_ix <
  // input.size(axis).
  //   i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
  //   will be 0, 1, 2, ... C-1
  // in_ix is the flat index of the element you are quantizing.
  //   in other words you are quantizing in_data[in_ix]
#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                           \
  case ScalarType::out_dtype:                                                   \
    for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) {  \
      double _scale = scale_data[channel_ix];                                   \
      int64_t _zero_point = zero_point_data[channel_ix];                        \
      auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                   \
      const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();            \
      apply_over_dim_list(                                                      \
          [input_data_ptr,                                                      \
           out_data_ptr,                                                        \
           _scale,                                                              \
           _zero_point,                                                         \
           quant_min,                                                           \
           quant_max](size_t in_ix) {                                           \
            out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>(            \
                _scale,                                                         \
                _zero_point,                                                    \
                input_data_ptr[in_ix],                                          \
                quant_min,                                                      \
                quant_max);                                                     \
          },                                                                    \
          input,                                                                \
          optional_dim_list,                                                    \
          channel_ix);                                                          \
    }                                                                           \
    break;
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype)         \
  case ScalarType::in_dtype:                             \
    switch (out.scalar_type()) {                         \
      ET_FORALL_INT_TYPES_WITH(CTYPE_IN, QUANTIZE_IMPL); \
      QUANTIZE_IMPL(CTYPE_IN, uint16_t, Bits16)          \
      QUANTIZE_IMPL(CTYPE_IN, uint16_t, UInt16)          \
      default:                                           \
        ET_CHECK_MSG(                                    \
            false,                                       \
            "Unhandled output dtype %" PRId8,            \
            static_cast<int8_t>(out.scalar_type()));     \
    }                                                    \
    break;

  switch (input.scalar_type()) {
    ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }
#undef CALCULATE_FLOAT_TYPE
#undef QUANTIZE_IMPL

  return out;
}

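// Overload taking a KernelRuntimeContext; resizes `out` to match `input` and
// forwards to the context-free variant above.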
Tensor& quantize_per_channel_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  (void)context;
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_channel_out");

  return quantize_per_channel_out(
      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
}

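// Per-token quantization: all leading dimensions of `input` are flattened into
// a single "token" dimension, and each token row is quantized with its own
// (scale, zero_point) pair by reusing the per-channel kernel on axis 0.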
Tensor& quantize_per_token_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  size_t num_tokens = 1;
  for (size_t i = 0; i < input.dim() - 1; i++) {
    num_tokens *= input.size(i);
  }
  // This unfortunate change is needed because we compile op_quantize for aten
  // mode as well
#ifdef USE_ATEN_LIB
  std::vector<int64_t> sizes(2);
  sizes[0] = num_tokens;
  sizes[1] = input.size(input.dim() - 1);
  Tensor reshaped_input = at::from_blob(
      input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type()));
#else
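  // Build a non-owning 2-D view of the input data with shape
  // (num_tokens, last_dim) so the per-channel kernel can be reused below.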
  std::array<exec_aten::DimOrderType, 2> input_dim_order{0, 1};
  std::array<exec_aten::SizesType, 2> input_sizes;
  input_sizes[0] = num_tokens;
  input_sizes[1] = input.size(input.dim() - 1);
  std::array<exec_aten::StridesType, 2> input_strides;
  dim_order_to_stride_nocheck(
      input_sizes.data(), input_dim_order.data(), 2, input_strides.data());
  void* input_data = input.mutable_data_ptr();
  TensorImpl reshaped_input_impl = TensorImpl(
      input.scalar_type(),
      2,
      input_sizes.data(),
      input_data,
      input_dim_order.data(),
      input_strides.data(),
      TensorShapeDynamism::STATIC);
  Tensor reshaped_input(&reshaped_input_impl);
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_token_out");
#endif

  return quantize_per_channel_out(
      reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out);
}

Tensor& quantize_per_token_out(
    RuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  (void)context;
  return quantize_per_token_out(
      input, scale, zero_point, quant_min, quant_max, dtype, out);
}
} // namespace native
} // namespace executor
} // namespace torch