xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/embeddingxb.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/quantized/cpu/embeddingxb.h>
10*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/kernel_includes.h>
11*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
12*523fa7a6SAndroid Build Coastguard Worker #include <cassert>
13*523fa7a6SAndroid Build Coastguard Worker #include <cinttypes>
14*523fa7a6SAndroid Build Coastguard Worker #include <cmath>
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker namespace torch {
17*523fa7a6SAndroid Build Coastguard Worker namespace executor {
18*523fa7a6SAndroid Build Coastguard Worker namespace native {
19*523fa7a6SAndroid Build Coastguard Worker 
20*523fa7a6SAndroid Build Coastguard Worker using Tensor = exec_aten::Tensor;
21*523fa7a6SAndroid Build Coastguard Worker using Scalar = exec_aten::Scalar;
22*523fa7a6SAndroid Build Coastguard Worker using ScalarType = exec_aten::ScalarType;
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker namespace {
25*523fa7a6SAndroid Build Coastguard Worker 
26*523fa7a6SAndroid Build Coastguard Worker static inline int32_t
weight_value(const unsigned char * w_data,int32_t index,int32_t weight_nbit)27*523fa7a6SAndroid Build Coastguard Worker weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) {
28*523fa7a6SAndroid Build Coastguard Worker   if (weight_nbit == 2) {
29*523fa7a6SAndroid Build Coastguard Worker     int32_t subbyte = index % 4;
30*523fa7a6SAndroid Build Coastguard Worker     index >>= 2;
31*523fa7a6SAndroid Build Coastguard Worker     switch (subbyte) {
32*523fa7a6SAndroid Build Coastguard Worker       case 0:
33*523fa7a6SAndroid Build Coastguard Worker         return (int32_t)(w_data[index] & 3) - 2;
34*523fa7a6SAndroid Build Coastguard Worker       case 1:
35*523fa7a6SAndroid Build Coastguard Worker         return (int32_t)((w_data[index] & 12) >> 2) - 2;
36*523fa7a6SAndroid Build Coastguard Worker       case 2:
37*523fa7a6SAndroid Build Coastguard Worker         return (int32_t)((w_data[index] & 48) >> 4) - 2;
38*523fa7a6SAndroid Build Coastguard Worker       case 3:
39*523fa7a6SAndroid Build Coastguard Worker         return (int32_t)((w_data[index] & 192) >> 6) - 2;
40*523fa7a6SAndroid Build Coastguard Worker     }
41*523fa7a6SAndroid Build Coastguard Worker   } else if (weight_nbit == 4) {
42*523fa7a6SAndroid Build Coastguard Worker     int32_t odd = index & 1;
43*523fa7a6SAndroid Build Coastguard Worker     index >>= 1;
44*523fa7a6SAndroid Build Coastguard Worker     if (odd) {
45*523fa7a6SAndroid Build Coastguard Worker       return (int32_t)(w_data[index] & 0x0F) - 8;
46*523fa7a6SAndroid Build Coastguard Worker     } else {
47*523fa7a6SAndroid Build Coastguard Worker       return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
48*523fa7a6SAndroid Build Coastguard Worker     }
49*523fa7a6SAndroid Build Coastguard Worker   }
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(false, "invalid weight_nbit");
52*523fa7a6SAndroid Build Coastguard Worker }
53*523fa7a6SAndroid Build Coastguard Worker 
get_embedding_dim(int32_t packed_dim,int32_t weight_nbit)54*523fa7a6SAndroid Build Coastguard Worker static inline int32_t get_embedding_dim(
55*523fa7a6SAndroid Build Coastguard Worker     int32_t packed_dim,
56*523fa7a6SAndroid Build Coastguard Worker     int32_t weight_nbit) {
57*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(8 % weight_nbit == 0, "invalid embedding dim");
58*523fa7a6SAndroid Build Coastguard Worker   int packed_values_per_byte = 8 / weight_nbit;
59*523fa7a6SAndroid Build Coastguard Worker   return packed_dim * packed_values_per_byte;
60*523fa7a6SAndroid Build Coastguard Worker }
61*523fa7a6SAndroid Build Coastguard Worker 
62*523fa7a6SAndroid Build Coastguard Worker /**
63*523fa7a6SAndroid Build Coastguard Worker  * Asserts that the parameters are valid.
64*523fa7a6SAndroid Build Coastguard Worker  */
check_embedding_xbit_args(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const int64_t weight_quant_min,const int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out,int weight_nbit)65*523fa7a6SAndroid Build Coastguard Worker void check_embedding_xbit_args(
66*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
67*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
68*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
69*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_min,
70*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_max,
71*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
72*523fa7a6SAndroid Build Coastguard Worker     exec_aten::optional<ScalarType> out_dtype,
73*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
74*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
75*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8");
76*523fa7a6SAndroid Build Coastguard Worker 
77*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
78*523fa7a6SAndroid Build Coastguard Worker       weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
79*523fa7a6SAndroid Build Coastguard Worker 
80*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
81*523fa7a6SAndroid Build Coastguard Worker       weight_scales.dim() == 1 || weight_scales.dim() == 2,
82*523fa7a6SAndroid Build Coastguard Worker       "weight_scales must be 1D or 2D but got() %zd dims",
83*523fa7a6SAndroid Build Coastguard Worker       weight_scales.dim());
84*523fa7a6SAndroid Build Coastguard Worker 
85*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
86*523fa7a6SAndroid Build Coastguard Worker       weight_scales.size(0) == weight.size(0),
87*523fa7a6SAndroid Build Coastguard Worker       "Number of scales must be == weight.size(0)=%zd"
88*523fa7a6SAndroid Build Coastguard Worker       ", but got %zd",
89*523fa7a6SAndroid Build Coastguard Worker       weight_scales.size(0),
90*523fa7a6SAndroid Build Coastguard Worker       weight.size(0));
91*523fa7a6SAndroid Build Coastguard Worker 
92*523fa7a6SAndroid Build Coastguard Worker   if (weight_scales.dim() == 2) {
93*523fa7a6SAndroid Build Coastguard Worker     auto num_groups = weight_scales.size(1);
94*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_MSG(
95*523fa7a6SAndroid Build Coastguard Worker         // each 8b uint8 column is packed_values_per_byte columns
96*523fa7a6SAndroid Build Coastguard Worker         get_embedding_dim(weight.size(1), weight_nbit) % num_groups == 0,
97*523fa7a6SAndroid Build Coastguard Worker         "Number of groups must divide weight.size(1)=%zd"
98*523fa7a6SAndroid Build Coastguard Worker         ", but got # of groups = %zd",
99*523fa7a6SAndroid Build Coastguard Worker         weight.size(1),
100*523fa7a6SAndroid Build Coastguard Worker         num_groups);
101*523fa7a6SAndroid Build Coastguard Worker   }
102*523fa7a6SAndroid Build Coastguard Worker 
103*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
104*523fa7a6SAndroid Build Coastguard Worker       weight.scalar_type() == ScalarType::Byte,
105*523fa7a6SAndroid Build Coastguard Worker       "weight.scalar_type() %" PRId8 " is not supported:",
106*523fa7a6SAndroid Build Coastguard Worker       static_cast<int8_t>(weight.scalar_type()));
107*523fa7a6SAndroid Build Coastguard Worker 
108*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
109*523fa7a6SAndroid Build Coastguard Worker       out.scalar_type() == ScalarType::Float ||
110*523fa7a6SAndroid Build Coastguard Worker           out.scalar_type() == ScalarType::Half,
111*523fa7a6SAndroid Build Coastguard Worker       "out.scalar_type() %" PRId8 " is not supported:",
112*523fa7a6SAndroid Build Coastguard Worker       static_cast<int8_t>(out.scalar_type()));
113*523fa7a6SAndroid Build Coastguard Worker 
114*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
115*523fa7a6SAndroid Build Coastguard Worker       weight_scales.scalar_type() == ScalarType::Float ||
116*523fa7a6SAndroid Build Coastguard Worker           weight_scales.scalar_type() == ScalarType::Half,
117*523fa7a6SAndroid Build Coastguard Worker       "weight_scales.scalar_type() %" PRId8 " is not supported:",
118*523fa7a6SAndroid Build Coastguard Worker       static_cast<int8_t>(weight_scales.scalar_type()));
119*523fa7a6SAndroid Build Coastguard Worker 
120*523fa7a6SAndroid Build Coastguard Worker   if (opt_weight_zero_points.has_value()) {
121*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_MSG(
122*523fa7a6SAndroid Build Coastguard Worker         opt_weight_zero_points.value().dim() == weight_scales.dim(),
123*523fa7a6SAndroid Build Coastguard Worker         "weight_zero_points's rank match that of weight_scales. "
124*523fa7a6SAndroid Build Coastguard Worker         "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
125*523fa7a6SAndroid Build Coastguard Worker         static_cast<int8_t>(opt_weight_zero_points.value().dim()),
126*523fa7a6SAndroid Build Coastguard Worker         static_cast<int8_t>(weight_scales.dim()));
127*523fa7a6SAndroid Build Coastguard Worker 
128*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_MSG(
129*523fa7a6SAndroid Build Coastguard Worker         opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
130*523fa7a6SAndroid Build Coastguard Worker         "weight zero points scalar type %" PRId8
131*523fa7a6SAndroid Build Coastguard Worker         " does not match out.scalar_type()",
132*523fa7a6SAndroid Build Coastguard Worker         static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));
133*523fa7a6SAndroid Build Coastguard Worker 
134*523fa7a6SAndroid Build Coastguard Worker     for (int32_t i = 0; i < weight_scales.dim(); ++i) {
135*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_MSG(
136*523fa7a6SAndroid Build Coastguard Worker           opt_weight_zero_points.value().size(i) == weight_scales.size(i),
137*523fa7a6SAndroid Build Coastguard Worker           "Dimension size misatch at dim %" PRIi32
138*523fa7a6SAndroid Build Coastguard Worker           "Weight_zero_point size = %zd"
139*523fa7a6SAndroid Build Coastguard Worker           ", weight_scales size = %zd.",
140*523fa7a6SAndroid Build Coastguard Worker           i,
141*523fa7a6SAndroid Build Coastguard Worker           opt_weight_zero_points.value().size(i),
142*523fa7a6SAndroid Build Coastguard Worker           weight_scales.size(i));
143*523fa7a6SAndroid Build Coastguard Worker     }
144*523fa7a6SAndroid Build Coastguard Worker   }
145*523fa7a6SAndroid Build Coastguard Worker 
146*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
147*523fa7a6SAndroid Build Coastguard Worker       indices.scalar_type() == ScalarType::Long,
148*523fa7a6SAndroid Build Coastguard Worker       "indices.scalar_type() %" PRId8 " is not Long only Long is supported:",
149*523fa7a6SAndroid Build Coastguard Worker       static_cast<int8_t>(indices.scalar_type()));
150*523fa7a6SAndroid Build Coastguard Worker 
151*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
152*523fa7a6SAndroid Build Coastguard Worker       weight_quant_min <= weight_quant_max,
153*523fa7a6SAndroid Build Coastguard Worker       "weight quant min: %" PRId64
154*523fa7a6SAndroid Build Coastguard Worker       " is greater than weight quant max: %" PRId64,
155*523fa7a6SAndroid Build Coastguard Worker       weight_quant_min,
156*523fa7a6SAndroid Build Coastguard Worker       weight_quant_max);
157*523fa7a6SAndroid Build Coastguard Worker 
158*523fa7a6SAndroid Build Coastguard Worker   if (out_dtype.has_value()) {
159*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_MSG(
160*523fa7a6SAndroid Build Coastguard Worker         out.scalar_type() == out_dtype.value(),
161*523fa7a6SAndroid Build Coastguard Worker         "output_dtype must match the dtype of the out tensor");
162*523fa7a6SAndroid Build Coastguard Worker   }
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker 
165*523fa7a6SAndroid Build Coastguard Worker /**
166*523fa7a6SAndroid Build Coastguard Worker  * Retrieves the embeddings specified by indices, dequantizes them, and stores
167*523fa7a6SAndroid Build Coastguard Worker  * them in out. Weight will always be uint8
168*523fa7a6SAndroid Build Coastguard Worker  */
169*523fa7a6SAndroid Build Coastguard Worker template <typename CTYPE_PARAMS, typename CTYPE_OUT>
embedding_xbit_per_channel(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const Tensor & indices,Tensor & out,int weight_nbit)170*523fa7a6SAndroid Build Coastguard Worker void embedding_xbit_per_channel(
171*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
172*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
173*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
174*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
175*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
176*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
177*523fa7a6SAndroid Build Coastguard Worker   auto embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
178*523fa7a6SAndroid Build Coastguard Worker 
179*523fa7a6SAndroid Build Coastguard Worker   int32_t num_groups_per_channel = 1;
180*523fa7a6SAndroid Build Coastguard Worker   if (weight_scales.dim() == 2) {
181*523fa7a6SAndroid Build Coastguard Worker     num_groups_per_channel = weight_scales.size(1);
182*523fa7a6SAndroid Build Coastguard Worker   }
183*523fa7a6SAndroid Build Coastguard Worker   int32_t group_size = embedding_dim / num_groups_per_channel;
184*523fa7a6SAndroid Build Coastguard Worker 
185*523fa7a6SAndroid Build Coastguard Worker   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
186*523fa7a6SAndroid Build Coastguard Worker   const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
187*523fa7a6SAndroid Build Coastguard Worker 
188*523fa7a6SAndroid Build Coastguard Worker   const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
189*523fa7a6SAndroid Build Coastguard Worker   const CTYPE_PARAMS* zero_points = nullptr;
190*523fa7a6SAndroid Build Coastguard Worker   if (opt_weight_zero_points.has_value()) {
191*523fa7a6SAndroid Build Coastguard Worker     zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
192*523fa7a6SAndroid Build Coastguard Worker   }
193*523fa7a6SAndroid Build Coastguard Worker 
194*523fa7a6SAndroid Build Coastguard Worker   for (int i = 0; i < indices.numel(); i++) {
195*523fa7a6SAndroid Build Coastguard Worker     int64_t index = indices_ptr[i];
196*523fa7a6SAndroid Build Coastguard Worker     // If using groupwise embedding
197*523fa7a6SAndroid Build Coastguard Worker     int32_t qparams_index = index * num_groups_per_channel;
198*523fa7a6SAndroid Build Coastguard Worker     CTYPE_PARAMS zp = 0.0;
199*523fa7a6SAndroid Build Coastguard Worker     const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
200*523fa7a6SAndroid Build Coastguard Worker     const CTYPE_PARAMS* zero_points_ptr = nullptr;
201*523fa7a6SAndroid Build Coastguard Worker     if (opt_weight_zero_points.has_value()) {
202*523fa7a6SAndroid Build Coastguard Worker       zero_points_ptr = zero_points + qparams_index;
203*523fa7a6SAndroid Build Coastguard Worker     }
204*523fa7a6SAndroid Build Coastguard Worker 
205*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* w_data =
206*523fa7a6SAndroid Build Coastguard Worker         weight.const_data_ptr<uint8_t>() + weight.size(1) * index;
207*523fa7a6SAndroid Build Coastguard Worker 
208*523fa7a6SAndroid Build Coastguard Worker     for (int j = 0; j < embedding_dim; ++j) {
209*523fa7a6SAndroid Build Coastguard Worker       int32_t group_id = j / group_size;
210*523fa7a6SAndroid Build Coastguard Worker       const CTYPE_PARAMS scale = scale_ptr[group_id];
211*523fa7a6SAndroid Build Coastguard Worker       if (opt_weight_zero_points.has_value()) {
212*523fa7a6SAndroid Build Coastguard Worker         zp = zero_points_ptr[group_id];
213*523fa7a6SAndroid Build Coastguard Worker       }
214*523fa7a6SAndroid Build Coastguard Worker       out_data[j] = static_cast<CTYPE_OUT>(
215*523fa7a6SAndroid Build Coastguard Worker           (static_cast<float>(weight_value(w_data, j, weight_nbit)) -
216*523fa7a6SAndroid Build Coastguard Worker            static_cast<float>(zp)) *
217*523fa7a6SAndroid Build Coastguard Worker           static_cast<float>(scale));
218*523fa7a6SAndroid Build Coastguard Worker     }
219*523fa7a6SAndroid Build Coastguard Worker     out_data += embedding_dim;
220*523fa7a6SAndroid Build Coastguard Worker   }
221*523fa7a6SAndroid Build Coastguard Worker }
222*523fa7a6SAndroid Build Coastguard Worker 
resize_out_tensor(const Tensor & weight,const Tensor & indices,Tensor & out,int weight_nbit)223*523fa7a6SAndroid Build Coastguard Worker void resize_out_tensor(
224*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
225*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
226*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
227*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
228*523fa7a6SAndroid Build Coastguard Worker   exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
229*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < indices.dim(); i++) {
230*523fa7a6SAndroid Build Coastguard Worker     expected_output_size[i] = indices.size(i);
231*523fa7a6SAndroid Build Coastguard Worker   }
232*523fa7a6SAndroid Build Coastguard Worker   const size_t embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
233*523fa7a6SAndroid Build Coastguard Worker   expected_output_size[out.dim() - 1] = embedding_dim;
234*523fa7a6SAndroid Build Coastguard Worker 
235*523fa7a6SAndroid Build Coastguard Worker   exec_aten::ArrayRef<exec_aten::SizesType> output_size{
236*523fa7a6SAndroid Build Coastguard Worker       expected_output_size, static_cast<size_t>(out.dim())};
237*523fa7a6SAndroid Build Coastguard Worker 
238*523fa7a6SAndroid Build Coastguard Worker   torch::executor::Error err = resize_tensor(out, output_size);
239*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
240*523fa7a6SAndroid Build Coastguard Worker       err == torch::executor::Error::Ok,
241*523fa7a6SAndroid Build Coastguard Worker       "Failed to resize out Tensor in quantized_embedding_xbit_out");
242*523fa7a6SAndroid Build Coastguard Worker }
243*523fa7a6SAndroid Build Coastguard Worker 
244*523fa7a6SAndroid Build Coastguard Worker } // namespace
245*523fa7a6SAndroid Build Coastguard Worker 
246*523fa7a6SAndroid Build Coastguard Worker /**
247*523fa7a6SAndroid Build Coastguard Worker  * Retrieves the embeddings specified by indices, dequantizes them, and stores
248*523fa7a6SAndroid Build Coastguard Worker  * them in out. The weight is quantized per channel, with a scale and zero_point
249*523fa7a6SAndroid Build Coastguard Worker  * for each embedding.
250*523fa7a6SAndroid Build Coastguard Worker  *
251*523fa7a6SAndroid Build Coastguard Worker  * Corresponds as the out variant to torch.ops.quantized.embedding_xbit
252*523fa7a6SAndroid Build Coastguard Worker  *
253*523fa7a6SAndroid Build Coastguard Worker  * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
254*523fa7a6SAndroid Build Coastguard Worker  * metadata that is passed around which can be useful for pattern matching. See
255*523fa7a6SAndroid Build Coastguard Worker  * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
256*523fa7a6SAndroid Build Coastguard Worker  * info.
257*523fa7a6SAndroid Build Coastguard Worker  */
quantized_embedding_xbit_out(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const int64_t weight_quant_min,const int64_t weight_quant_max,const Tensor & indices,Tensor & out,int weight_nbit)258*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out(
259*523fa7a6SAndroid Build Coastguard Worker     // TODO Evaluate whether this name is appropriate for an operator that takes
260*523fa7a6SAndroid Build Coastguard Worker     // non quant input and returns fp output
261*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
262*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
263*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
264*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_min,
265*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_max,
266*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
267*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
268*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
269*523fa7a6SAndroid Build Coastguard Worker   ScalarType out_type = out.scalar_type();
270*523fa7a6SAndroid Build Coastguard Worker 
271*523fa7a6SAndroid Build Coastguard Worker   // TODO (jakeszwe): improve these to account for the size of out in relation
272*523fa7a6SAndroid Build Coastguard Worker   // to weight and indices accounting for a possible batch dimension
273*523fa7a6SAndroid Build Coastguard Worker   check_embedding_xbit_args(
274*523fa7a6SAndroid Build Coastguard Worker       weight,
275*523fa7a6SAndroid Build Coastguard Worker       weight_scales,
276*523fa7a6SAndroid Build Coastguard Worker       opt_weight_zero_points,
277*523fa7a6SAndroid Build Coastguard Worker       weight_quant_min,
278*523fa7a6SAndroid Build Coastguard Worker       weight_quant_max,
279*523fa7a6SAndroid Build Coastguard Worker       indices,
280*523fa7a6SAndroid Build Coastguard Worker       out_type,
281*523fa7a6SAndroid Build Coastguard Worker       out,
282*523fa7a6SAndroid Build Coastguard Worker       weight_nbit);
283*523fa7a6SAndroid Build Coastguard Worker 
284*523fa7a6SAndroid Build Coastguard Worker   constexpr auto name = "quantized_decomposed::embedding_xbit.out";
285*523fa7a6SAndroid Build Coastguard Worker   ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
286*523fa7a6SAndroid Build Coastguard Worker     embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
287*523fa7a6SAndroid Build Coastguard Worker         weight,
288*523fa7a6SAndroid Build Coastguard Worker         weight_scales,
289*523fa7a6SAndroid Build Coastguard Worker         opt_weight_zero_points,
290*523fa7a6SAndroid Build Coastguard Worker         indices,
291*523fa7a6SAndroid Build Coastguard Worker         out,
292*523fa7a6SAndroid Build Coastguard Worker         weight_nbit);
293*523fa7a6SAndroid Build Coastguard Worker   });
294*523fa7a6SAndroid Build Coastguard Worker 
295*523fa7a6SAndroid Build Coastguard Worker   return out;
296*523fa7a6SAndroid Build Coastguard Worker }
297*523fa7a6SAndroid Build Coastguard Worker 
quantized_embedding_xbit_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,Tensor & out,int weight_nbit)298*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out(
299*523fa7a6SAndroid Build Coastguard Worker     KernelRuntimeContext& context,
300*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
301*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
302*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
303*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_min,
304*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_max,
305*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
306*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
307*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
308*523fa7a6SAndroid Build Coastguard Worker   // TODO(larryliu): Add a context arg to the real op function and remove this
309*523fa7a6SAndroid Build Coastguard Worker   // wrapper
310*523fa7a6SAndroid Build Coastguard Worker   (void)context;
311*523fa7a6SAndroid Build Coastguard Worker   resize_out_tensor(weight, indices, out, weight_nbit);
312*523fa7a6SAndroid Build Coastguard Worker   return quantized_embedding_xbit_out(
313*523fa7a6SAndroid Build Coastguard Worker       weight,
314*523fa7a6SAndroid Build Coastguard Worker       weight_scales,
315*523fa7a6SAndroid Build Coastguard Worker       opt_weight_zero_points,
316*523fa7a6SAndroid Build Coastguard Worker       weight_quant_min,
317*523fa7a6SAndroid Build Coastguard Worker       weight_quant_max,
318*523fa7a6SAndroid Build Coastguard Worker       indices,
319*523fa7a6SAndroid Build Coastguard Worker       out,
320*523fa7a6SAndroid Build Coastguard Worker       weight_nbit);
321*523fa7a6SAndroid Build Coastguard Worker }
322*523fa7a6SAndroid Build Coastguard Worker 
quantized_embedding_xbit_dtype_out(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const int64_t weight_quant_min,const int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out,int weight_nbit)323*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out(
324*523fa7a6SAndroid Build Coastguard Worker     // TODO Evaluate whether this name is appropriate for an operator that takes
325*523fa7a6SAndroid Build Coastguard Worker     // non quant input and returns fp output
326*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
327*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
328*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
329*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_min,
330*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_max,
331*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
332*523fa7a6SAndroid Build Coastguard Worker     exec_aten::optional<ScalarType> out_dtype,
333*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
334*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
335*523fa7a6SAndroid Build Coastguard Worker   // TODO (jakeszwe): improve these to account for the size of out in relation
336*523fa7a6SAndroid Build Coastguard Worker   // to weight and indices accounting for a possible batch dimension
337*523fa7a6SAndroid Build Coastguard Worker   check_embedding_xbit_args(
338*523fa7a6SAndroid Build Coastguard Worker       weight,
339*523fa7a6SAndroid Build Coastguard Worker       weight_scales,
340*523fa7a6SAndroid Build Coastguard Worker       opt_weight_zero_points,
341*523fa7a6SAndroid Build Coastguard Worker       weight_quant_min,
342*523fa7a6SAndroid Build Coastguard Worker       weight_quant_max,
343*523fa7a6SAndroid Build Coastguard Worker       indices,
344*523fa7a6SAndroid Build Coastguard Worker       out_dtype,
345*523fa7a6SAndroid Build Coastguard Worker       out,
346*523fa7a6SAndroid Build Coastguard Worker       weight_nbit);
347*523fa7a6SAndroid Build Coastguard Worker 
348*523fa7a6SAndroid Build Coastguard Worker   ScalarType params_type = weight_scales.scalar_type();
349*523fa7a6SAndroid Build Coastguard Worker   ScalarType out_type = out.scalar_type();
350*523fa7a6SAndroid Build Coastguard Worker 
351*523fa7a6SAndroid Build Coastguard Worker   constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out";
352*523fa7a6SAndroid Build Coastguard Worker   ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
353*523fa7a6SAndroid Build Coastguard Worker     ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
354*523fa7a6SAndroid Build Coastguard Worker       embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
355*523fa7a6SAndroid Build Coastguard Worker           weight,
356*523fa7a6SAndroid Build Coastguard Worker           weight_scales,
357*523fa7a6SAndroid Build Coastguard Worker           opt_weight_zero_points,
358*523fa7a6SAndroid Build Coastguard Worker           indices,
359*523fa7a6SAndroid Build Coastguard Worker           out,
360*523fa7a6SAndroid Build Coastguard Worker           weight_nbit);
361*523fa7a6SAndroid Build Coastguard Worker     });
362*523fa7a6SAndroid Build Coastguard Worker   });
363*523fa7a6SAndroid Build Coastguard Worker 
364*523fa7a6SAndroid Build Coastguard Worker   return out;
365*523fa7a6SAndroid Build Coastguard Worker }
366*523fa7a6SAndroid Build Coastguard Worker 
quantized_embedding_xbit_dtype_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out,int weight_nbit)367*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out(
368*523fa7a6SAndroid Build Coastguard Worker     KernelRuntimeContext& context,
369*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
370*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
371*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
372*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_min,
373*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_max,
374*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
375*523fa7a6SAndroid Build Coastguard Worker     exec_aten::optional<ScalarType> out_dtype,
376*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
377*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit) {
378*523fa7a6SAndroid Build Coastguard Worker   // TODO(larryliu): Add a context arg to the real op function and remove this
379*523fa7a6SAndroid Build Coastguard Worker   // wrapper
380*523fa7a6SAndroid Build Coastguard Worker   (void)context;
381*523fa7a6SAndroid Build Coastguard Worker   resize_out_tensor(weight, indices, out, weight_nbit);
382*523fa7a6SAndroid Build Coastguard Worker   return quantized_embedding_xbit_dtype_out(
383*523fa7a6SAndroid Build Coastguard Worker       weight,
384*523fa7a6SAndroid Build Coastguard Worker       weight_scales,
385*523fa7a6SAndroid Build Coastguard Worker       opt_weight_zero_points,
386*523fa7a6SAndroid Build Coastguard Worker       weight_quant_min,
387*523fa7a6SAndroid Build Coastguard Worker       weight_quant_max,
388*523fa7a6SAndroid Build Coastguard Worker       indices,
389*523fa7a6SAndroid Build Coastguard Worker       out_dtype,
390*523fa7a6SAndroid Build Coastguard Worker       out,
391*523fa7a6SAndroid Build Coastguard Worker       weight_nbit);
392*523fa7a6SAndroid Build Coastguard Worker }
393*523fa7a6SAndroid Build Coastguard Worker 
394*523fa7a6SAndroid Build Coastguard Worker } // namespace native
395*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
396*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
397