/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <limits>

/**
 * For an input tensor, use the scale and zero_point arguments to quantize it.
 */
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Asserts that the parameters are valid.
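 *
 * The bounds checks below imply the following valid ranges per output dtype:
 * Byte -> [0, 255], Char -> [-128, 127], Bits16/UInt16 -> [0, 65535],
 * Short -> [-32768, 32767], Int -> [INT32_MIN, INT32_MAX]; quant_min and
 * quant_max must fall inside the range of the requested dtype.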
 */
void check_quantize_per_tensor_args(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ensure the input is a floating point tensor
  ET_CHECK_MSG(
      torch::executor::isFloatingType(input.scalar_type()),
      "input.scalar_type() %" PRId8 " is not floating type",
      static_cast<int8_t>(input.scalar_type()));

  int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0;
  ScalarType out_dtype = out.scalar_type();
  ET_CHECK_MSG(
      out_dtype == dtype,
      "out.scalar_type() %" PRId8 " is not matching dtype argument %" PRId8,
      static_cast<int8_t>(out_dtype),
      static_cast<int8_t>(dtype));
  if (out_dtype == ScalarType::Byte) {
    quant_min_lower_bound =
        static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
    quant_max_upper_bound =
        static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
  } else if (dtype == ScalarType::Char) {
    quant_min_lower_bound =
        static_cast<int32_t>(std::numeric_limits<int8_t>::min());
    quant_max_upper_bound =
        static_cast<int32_t>(std::numeric_limits<int8_t>::max());
  } else if (dtype == ScalarType::Bits16 || dtype == ScalarType::UInt16) {
    quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
    quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
  } else if (dtype == ScalarType::Short) {
    quant_min_lower_bound = std::numeric_limits<int16_t>::min();
    quant_max_upper_bound = std::numeric_limits<int16_t>::max();
  } else if (dtype == ScalarType::Int) {
    quant_min_lower_bound = std::numeric_limits<int32_t>::min();
    quant_max_upper_bound = std::numeric_limits<int32_t>::max();
  } else {
    ET_CHECK_MSG(
        false, "Unsupported dtype: %" PRId8, static_cast<int8_t>(out_dtype));
  }
  ET_CHECK_MSG(
      quant_min >= quant_min_lower_bound,
      "quant_min out of bound for dtype, expected quant_min_lower_bound: %" PRId32
      " actual quant_min: %" PRId64,
      quant_min_lower_bound,
      quant_min);

  ET_CHECK_MSG(
      quant_max <= quant_max_upper_bound,
      "quant_max out of bound for dtype, expected quant_max_upper_bound: %" PRId32
      " actual quant_max: %" PRId64,
      quant_max_upper_bound,
      quant_max);
}

} // namespace

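/**
 * Quantizes a single floating point value to the integer type T using the
 * affine mapping implemented below:
 *   q = clamp(nearbyint(value / scale) + zero_point, quant_min, quant_max)
 * For example, with scale = 0.5, zero_point = 10, quant_min = 0 and
 * quant_max = 255, a value of 3.2 maps to nearbyint(6.4) + 10 = 16.
 */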
template <typename T, typename K>
T quantize_val(
    double scale,
    int64_t zero_point,
    K value,
    int64_t quant_min,
    int64_t quant_max) {
  int64_t qvalue;
  float inv_scale = 1.0f / static_cast<float>(scale);
  qvalue = static_cast<int64_t>(
      static_cast<int32_t>(zero_point) +
      std::nearbyint(static_cast<float>(inv_scale * value)));

  qvalue = std::max<int64_t>(qvalue, quant_min);
  qvalue = std::min<int64_t>(qvalue, quant_max);
  return static_cast<T>(qvalue);
}

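/**
 * Quantizes a floating point tensor into `out` using a single scale and
 * zero_point for the whole tensor. A minimal usage sketch (illustrative only;
 * how the input and output tensors are constructed depends on the surrounding
 * runtime or test harness):
 *
 *   // int8 quantization of an fp32 tensor with scale 0.05 and zero_point 0
 *   quantize_per_tensor_out(
 *       fp32_input, 0.05, 0, -128, 127, ScalarType::Char, int8_out);
 */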
Tensor& quantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_tensor_out");

  check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

  // calculate the quantized input
#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype)                          \
  case ScalarType::out_dtype: {                                                \
    /* Hoist these function calls out of our inner loop because they might not \
     * get inlined without LTO, particularly in ATen mode. */                  \
    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>();                    \
    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>();             \
    const auto input_numel = input.numel();                                    \
    for (size_t i = 0; i < input_numel; i++) {                                 \
      IN_CTYPE value = input_data_ptr[i];                                      \
      out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>(                     \
          scale, zero_point, value, quant_min, quant_max);                     \
    }                                                                          \
  } break;
#define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype)         \
  case ScalarType::in_dtype:                             \
    switch (out.scalar_type()) {                         \
      ET_FORALL_INT_TYPES_WITH(IN_CTYPE, QUANTIZE_IMPL); \
      QUANTIZE_IMPL(IN_CTYPE, uint16_t, Bits16)          \
      QUANTIZE_IMPL(IN_CTYPE, uint16_t, UInt16)          \
      default:                                           \
        ET_CHECK_MSG(                                    \
            false,                                       \
            "Unhandled output dtype %" PRId8,            \
            static_cast<int8_t>(out.scalar_type()));     \
    }                                                    \
    break;

  switch (input.scalar_type()) {
    ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }
#undef CALCULATE_FLOAT_TYPE
#undef QUANTIZE_IMPL
  return out;
}

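/**
 * Variant of per-tensor quantization that receives scale and zero_point as
 * single-element tensors (Double and Long respectively, as checked below)
 * rather than as scalars, and forwards their values to
 * quantize_per_tensor_out above.
 */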
Tensor& quantize_per_tensor_tensor_args_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Temporary change to allow a non-fatal failure for now, to unblock some
  // expected-failure tests that were dying instead of failing. Will revisit
  // after ET_KERNEL_CHECK is fully implemented and properly allows non-fatal
  // failures.
  if (scale.scalar_type() != ScalarType::Double) {
    context.fail(torch::executor::Error::InvalidArgument);
    return out;
  }
  ET_CHECK_MSG(
      scale.scalar_type() == ScalarType::Double,
      "Expected scale to be a Double tensor, received: %" PRId8,
      static_cast<int8_t>(scale.scalar_type()));
  ET_CHECK_MSG(
      zero_point.scalar_type() == ScalarType::Long,
      "Expected zero_point to be a Long tensor, received: %" PRId8,
      static_cast<int8_t>(zero_point.scalar_type()));
  ET_CHECK_MSG(
      scale.numel() == 1,
      "Expected scale to have exactly one element, received: %zd",
      ssize_t(scale.numel()));
  ET_CHECK_MSG(
      zero_point.numel() == 1,
      "Expected zero_point to have exactly one element, received: %zd",
      ssize_t(zero_point.numel()));

  quantize_per_tensor_out(
      input,
      scale.const_data_ptr<double>()[0],
      zero_point.const_data_ptr<int64_t>()[0],
      quant_min,
      quant_max,
      dtype,
      out);
  return out;
}

Tensor& quantize_per_tensor_tensor_args_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  auto context = executorch::runtime::KernelRuntimeContext();
  auto& res = quantize_per_tensor_tensor_args_out(
      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
  ET_CHECK(context.failure_state() == Error::Ok);
  return res;
}

Tensor& quantize_per_tensor_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  return quantize_per_tensor_out(
      input, scale, zero_point, quant_min, quant_max, dtype, out);
}

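/**
 * Per-channel quantization: each slice of `input` along `axis` is quantized
 * with its own scale / zero_point pair, so `scale` and `zero_point` must have
 * input.size(axis) elements. For example (shapes are illustrative), for an
 * input of shape (N, C, H, W) with axis == 1, scale and zero_point each hold
 * C values and channel c uses scale[c] / zero_point[c].
 */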
Tensor& quantize_per_channel_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // normalize axis
  ET_CHECK_MSG(
      tensor_has_dim(input, axis),
      "axis %zd is not legal; it should satisfy -input.dim() <= axis < input.dim() %zd",
      ssize_t(axis),
      ssize_t(input.dim()));

  if (axis < 0) {
    axis += nonzero_dim(input);
  }

  ET_CHECK_MSG(
      scale.scalar_type() == ScalarType::Double,
      "scale.scalar_type() %" PRId8 " is not double type",
      static_cast<int8_t>(scale.scalar_type()));

  ET_CHECK_MSG(
      scale.numel() == input.size(axis),
      "scale.numel() %zd != input.size(axis) %zd",
      scale.numel(),
      input.size(axis));

  ET_CHECK_MSG(
      zero_point.scalar_type() == ScalarType::Long,
      "zero_point.scalar_type() %" PRId8 " is not integer type",
      static_cast<int8_t>(zero_point.scalar_type()));

  ET_CHECK_MSG(
      zero_point.numel() == input.size(axis),
      "zero_point.numel() %zd != input.size(axis) %zd",
      zero_point.numel(),
      input.size(axis));

  check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

  // a list contains all dimensions except axis
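  // (e.g. with a 4-D input and axis == 1 this is {0, 2, 3})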
  int64_t dims[kTensorDimensionLimit];
  for (int64_t i = 0; i < input.dim() - 1; i++) {
    if (i < axis) {
      dims[i] = i;
    } else {
      // skip over `axis` so that dims really excludes it
      dims[i] = i + 1;
    }
  }
  const double* scale_data = scale.const_data_ptr<double>();
  const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();

  exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
      exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};

  // Actual quantization logic
  // input, out are the input and output tensors
  // channel_ix is the index along the axis dimension. 0 <= channel_ix <
  // input.size(axis).
  //   i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
  //   will be 0, 1, 2, ... C-1
  // in_ix is the flat index of the element you are quantizing.
  //   in other words you are quantizing in_data[in_ix]
#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                          \
  case ScalarType::out_dtype:                                                  \
    for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
      double _scale = scale_data[channel_ix];                                  \
      int64_t _zero_point = zero_point_data[channel_ix];                       \
      auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                  \
      const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();           \
      apply_over_dim_list(                                                     \
          [input_data_ptr,                                                     \
           out_data_ptr,                                                       \
           _scale,                                                             \
           _zero_point,                                                        \
           quant_min,                                                          \
           quant_max](size_t in_ix) {                                          \
            out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>(           \
                _scale,                                                        \
                _zero_point,                                                   \
                input_data_ptr[in_ix],                                         \
                quant_min,                                                     \
                quant_max);                                                    \
          },                                                                   \
          input,                                                               \
          optional_dim_list,                                                   \
          channel_ix);                                                         \
    }                                                                          \
    break;
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype)         \
  case ScalarType::in_dtype:                             \
    switch (out.scalar_type()) {                         \
      ET_FORALL_INT_TYPES_WITH(CTYPE_IN, QUANTIZE_IMPL); \
      QUANTIZE_IMPL(CTYPE_IN, uint16_t, Bits16)          \
      QUANTIZE_IMPL(CTYPE_IN, uint16_t, UInt16)          \
      default:                                           \
        ET_CHECK_MSG(                                    \
            false,                                       \
            "Unhandled output dtype %" PRId8,            \
            static_cast<int8_t>(out.scalar_type()));     \
    }                                                    \
    break;

  switch (input.scalar_type()) {
    ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }
#undef CALCULATE_FLOAT_TYPE
#undef QUANTIZE_IMPL

  return out;
}

Tensor& quantize_per_channel_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  (void)context;
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_channel_out");

  return quantize_per_channel_out(
      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
}

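/**
 * Per-token quantization: the input is viewed as a 2-D tensor of shape
 * (num_tokens, input.size(-1)), where num_tokens is the product of all
 * leading dimensions (e.g. for a (B, S, D) activation, num_tokens == B * S),
 * and each token/row gets its own scale and zero_point. The work is then
 * delegated to quantize_per_channel_out with axis == 0.
 */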
Tensor& quantize_per_token_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  size_t num_tokens = 1;
  for (size_t i = 0; i < input.dim() - 1; i++) {
    num_tokens *= input.size(i);
  }
// This unfortunate branching is needed because we compile op_quantize for
// ATen mode as well
#ifdef USE_ATEN_LIB
  std::vector<int64_t> sizes(2);
  sizes[0] = num_tokens;
  sizes[1] = input.size(input.dim() - 1);
  Tensor reshaped_input = at::from_blob(
      input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type()));
#else
  std::array<exec_aten::DimOrderType, 2> input_dim_order{0, 1};
  std::array<exec_aten::SizesType, 2> input_sizes;
  input_sizes[0] = num_tokens;
  input_sizes[1] = input.size(input.dim() - 1);
  std::array<exec_aten::StridesType, 2> input_strides;
  dim_order_to_stride_nocheck(
      input_sizes.data(), input_dim_order.data(), 2, input_strides.data());
  void* input_data = input.mutable_data_ptr();
  TensorImpl reshaped_input_impl = TensorImpl(
      input.scalar_type(),
      2,
      input_sizes.data(),
      input_data,
      input_dim_order.data(),
      input_strides.data(),
      TensorShapeDynamism::STATIC);
  Tensor reshaped_input(&reshaped_input_impl);
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantize_per_token_out");
#endif

  return quantize_per_channel_out(
      reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out);
}

Tensor& quantize_per_token_out(
    RuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  (void)context;
  return quantize_per_token_out(
      input, scale, zero_point, quant_min, quant_max, dtype, out);
}
} // namespace native
} // namespace executor
} // namespace torch