// xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_add.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/runtime/kernel/kernel_includes.h>
10 #include <algorithm>
11 #include <cinttypes>
12 #include <cmath>
13 
14 namespace torch {
15 namespace executor {
16 namespace native {
17 
18 using Tensor = exec_aten::Tensor;
19 using Scalar = exec_aten::Scalar;
20 using ScalarType = exec_aten::ScalarType;
21 
22 namespace {
23 
/// Affine-quantizes `value`: q = clamp(zero_point + round(value / scale),
/// [quant_min, quant_max]), returned as OUTPUT_T.
/// `scale` is narrowed to float before inversion on purpose — the caller
/// does the same narrowing to stay numerically consistent with fbgemm.
/// Rounding uses std::nearbyint, i.e. the current FP rounding mode
/// (round-to-nearest-even by default).
template <typename INPUT_T, typename OUTPUT_T>
OUTPUT_T quantize_val(
    double scale,
    int64_t zero_point,
    INPUT_T value,
    int64_t quant_min,
    int64_t quant_max) {
  const float inv_scale = 1.0f / static_cast<float>(scale);
  // zero_point joins the floating-point sum before the integer cast,
  // matching the original expression's promotion behavior exactly.
  const auto unclamped =
      static_cast<int64_t>(zero_point + std::nearbyint(inv_scale * value));
  const int64_t clamped =
      std::min<int64_t>(std::max<int64_t>(unclamped, quant_min), quant_max);
  return static_cast<OUTPUT_T>(clamped);
}
38 
/// Affine-dequantizes `value`: (value - zero_point) * scale, converted to
/// OUTPUT_T. The subtraction happens in integer arithmetic; the product is
/// computed in double (scale's type) before the final conversion.
template <typename INPUT_T, typename OUTPUT_T>
OUTPUT_T dequantize_val(double scale, int64_t zero_point, INPUT_T value) {
  const int64_t shifted = value - zero_point;
  return static_cast<OUTPUT_T>(shifted * scale);
}
43 
44 /**
45  * Perform element wise addition of the input tensors into out.
46  * Should be numerically equivalent to Dq -> fp add -> Q
47  */
48 template <class CTYPE>
add_tensors(const Tensor & a,float a_scale,int32_t a_zero_point,const Tensor & b,float b_scale,int32_t b_zero_point,Tensor & out,float out_scale,int32_t out_zero_point,int64_t out_quant_min,int64_t out_quant_max)49 void add_tensors(
50     const Tensor& a,
51     float a_scale,
52     int32_t a_zero_point,
53     const Tensor& b,
54     float b_scale,
55     int32_t b_zero_point,
56     Tensor& out,
57     float out_scale,
58     int32_t out_zero_point,
59     int64_t out_quant_min,
60     int64_t out_quant_max) {
61   const size_t n = a.numel();
62 
63   const auto data_a = a.const_data_ptr<CTYPE>();
64   const auto data_b = b.const_data_ptr<CTYPE>();
65   auto data_out = out.mutable_data_ptr<CTYPE>();
66 
67   for (size_t i = 0; i < n; ++i) {
68     // Dq -> fp add -> Q. Can be optimized further
69     const auto dqa =
70         dequantize_val<CTYPE, float>(a_scale, a_zero_point, data_a[i]);
71     const auto dqb =
72         dequantize_val<CTYPE, float>(b_scale, b_zero_point, data_b[i]);
73     const auto accumulate = dqa + dqb;
74 
75     data_out[i] = quantize_val<float, CTYPE>(
76         out_scale, out_zero_point, accumulate, out_quant_min, out_quant_max);
77   }
78 }
79 
80 } // namespace
81 
82 /**
83  * Perform element wise addition of the input tensors into out. Should be
84  * numerically equivalent to Dq -> fp add -> Q
85  *
86  * PREREQ: a and b should be the same shape, quant_min and max should be in
87  * range [0,255]. a and b and out should be the same dtype.
88  */
quantized_add_out(const Tensor & a,double a_scale_d,int64_t a_zero_point_l,int64_t a_quant_min,int64_t a_quant_max,const Tensor & b,double b_scale_d,int64_t b_zero_point_l,int64_t b_quant_min,int64_t b_quant_max,double out_scale_d,int64_t out_zero_point_l,int64_t out_quant_min,int64_t out_quant_max,Tensor & out)89 Tensor& quantized_add_out(
90     const Tensor& a,
91     double a_scale_d,
92     int64_t a_zero_point_l,
93     int64_t a_quant_min,
94     int64_t a_quant_max,
95     const Tensor& b,
96     double b_scale_d,
97     int64_t b_zero_point_l,
98     int64_t b_quant_min,
99     int64_t b_quant_max,
100     double out_scale_d,
101     int64_t out_zero_point_l,
102     int64_t out_quant_min,
103     int64_t out_quant_max,
104     Tensor& out) {
105   ET_CHECK_SAME_SHAPE_AND_DTYPE3(a, b, out);
106   ET_CHECK_MSG(
107       a_quant_min >= 0 && a_quant_max <= 255 && a_quant_min <= a_quant_max,
108       "invalid quant_min: %" PRId64 " or quant_max: %" PRId64
109       " for input tensor a. Min should be <= max and both should be in bounds [0,255]",
110       a_quant_min,
111       a_quant_max);
112   ET_CHECK_MSG(
113       b_quant_min >= 0 && b_quant_max <= 255 && b_quant_min <= b_quant_max,
114       "invalid quant_min: %" PRId64 " or quant_max: %" PRId64
115       " for input tensor b. Min should be <= max and both should be in bounds [0,255]",
116       b_quant_min,
117       b_quant_max);
118   ET_CHECK_MSG(
119       out_quant_min >= 0 && out_quant_max <= 255 &&
120           out_quant_min <= out_quant_max,
121       "invalid quant_min: %" PRId64 " or quant_max: %" PRId64
122       " for output tensor. Min should be <= max and both should be in bounds [0,255]",
123       out_quant_min,
124       out_quant_max);
125 
126   // downsize to maintain numerical consistency with fbgemm
127   float a_scale = static_cast<float>(a_scale_d);
128   float b_scale = static_cast<float>(b_scale_d);
129   float out_scale = static_cast<float>(out_scale_d);
130 
131   int32_t a_zero_point = static_cast<int32_t>(a_zero_point_l);
132   int32_t b_zero_point = static_cast<int32_t>(b_zero_point_l);
133   int32_t out_zero_point = static_cast<int32_t>(out_zero_point_l);
134 
135 #define ADD_TENSORS(ctype, dtype) \
136   case ScalarType::dtype:         \
137     add_tensors<ctype>(           \
138         a,                        \
139         a_scale,                  \
140         a_zero_point,             \
141         b,                        \
142         b_scale,                  \
143         b_zero_point,             \
144         out,                      \
145         out_scale,                \
146         out_zero_point,           \
147         out_quant_min,            \
148         out_quant_max);           \
149     break;
150 
151   switch (a.scalar_type()) {
152     ET_FORALL_INT_TYPES(ADD_TENSORS)
153     default:
154       ET_CHECK_MSG(
155           false,
156           "Unhandled dtype %" PRId8,
157           static_cast<int8_t>(a.scalar_type()));
158   }
159 
160 #undef ADD_TENSORS
161 
162   return out;
163 }
164 
quantized_add_out(KernelRuntimeContext & context,const Tensor & a,double a_scale,int64_t a_zero_point,int64_t a_quant_min,int64_t a_quant_max,const Tensor & b,double b_scale,int64_t b_zero_point,int64_t b_quant_min,int64_t b_quant_max,double out_scale,int64_t out_zero_point,int64_t out_quant_min,int64_t out_quant_max,Tensor & out)165 Tensor& quantized_add_out(
166     KernelRuntimeContext& context,
167     const Tensor& a,
168     double a_scale,
169     int64_t a_zero_point,
170     int64_t a_quant_min,
171     int64_t a_quant_max,
172     const Tensor& b,
173     double b_scale,
174     int64_t b_zero_point,
175     int64_t b_quant_min,
176     int64_t b_quant_max,
177     double out_scale,
178     int64_t out_zero_point,
179     int64_t out_quant_min,
180     int64_t out_quant_max,
181     Tensor& out) {
182   // TODO(larryliu): Add a context arg to the real op function and remove this
183   // wrapper
184   (void)context;
185   return quantized_add_out(
186       a,
187       a_scale,
188       a_zero_point,
189       a_quant_min,
190       a_quant_max,
191       b,
192       b_scale,
193       b_zero_point,
194       b_quant_min,
195       b_quant_max,
196       out_scale,
197       out_zero_point,
198       out_quant_min,
199       out_quant_max,
200       out);
201 }
202 
203 } // namespace native
204 } // namespace executor
205 } // namespace torch
206