/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/quantized/cpu/embeddingxb.h>
10*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/kernel_includes.h>
11*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
12*523fa7a6SAndroid Build Coastguard Worker #include <cassert>
13*523fa7a6SAndroid Build Coastguard Worker #include <cinttypes>
14*523fa7a6SAndroid Build Coastguard Worker #include <cmath>
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker namespace torch {
17*523fa7a6SAndroid Build Coastguard Worker namespace executor {
18*523fa7a6SAndroid Build Coastguard Worker namespace native {
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker using Tensor = exec_aten::Tensor;
21*523fa7a6SAndroid Build Coastguard Worker using Scalar = exec_aten::Scalar;
22*523fa7a6SAndroid Build Coastguard Worker using ScalarType = exec_aten::ScalarType;
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker namespace {
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Worker static inline int32_t
weight_value(const unsigned char * w_data,int32_t index,int32_t weight_nbit)27*523fa7a6SAndroid Build Coastguard Worker weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) {
28*523fa7a6SAndroid Build Coastguard Worker if (weight_nbit == 2) {
29*523fa7a6SAndroid Build Coastguard Worker int32_t subbyte = index % 4;
30*523fa7a6SAndroid Build Coastguard Worker index >>= 2;
31*523fa7a6SAndroid Build Coastguard Worker switch (subbyte) {
32*523fa7a6SAndroid Build Coastguard Worker case 0:
33*523fa7a6SAndroid Build Coastguard Worker return (int32_t)(w_data[index] & 3) - 2;
34*523fa7a6SAndroid Build Coastguard Worker case 1:
35*523fa7a6SAndroid Build Coastguard Worker return (int32_t)((w_data[index] & 12) >> 2) - 2;
36*523fa7a6SAndroid Build Coastguard Worker case 2:
37*523fa7a6SAndroid Build Coastguard Worker return (int32_t)((w_data[index] & 48) >> 4) - 2;
38*523fa7a6SAndroid Build Coastguard Worker case 3:
39*523fa7a6SAndroid Build Coastguard Worker return (int32_t)((w_data[index] & 192) >> 6) - 2;
40*523fa7a6SAndroid Build Coastguard Worker }
41*523fa7a6SAndroid Build Coastguard Worker } else if (weight_nbit == 4) {
42*523fa7a6SAndroid Build Coastguard Worker int32_t odd = index & 1;
43*523fa7a6SAndroid Build Coastguard Worker index >>= 1;
44*523fa7a6SAndroid Build Coastguard Worker if (odd) {
45*523fa7a6SAndroid Build Coastguard Worker return (int32_t)(w_data[index] & 0x0F) - 8;
46*523fa7a6SAndroid Build Coastguard Worker } else {
47*523fa7a6SAndroid Build Coastguard Worker return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
48*523fa7a6SAndroid Build Coastguard Worker }
49*523fa7a6SAndroid Build Coastguard Worker }
50*523fa7a6SAndroid Build Coastguard Worker
51*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(false, "invalid weight_nbit");
52*523fa7a6SAndroid Build Coastguard Worker }
53*523fa7a6SAndroid Build Coastguard Worker
get_embedding_dim(int32_t packed_dim,int32_t weight_nbit)54*523fa7a6SAndroid Build Coastguard Worker static inline int32_t get_embedding_dim(
55*523fa7a6SAndroid Build Coastguard Worker int32_t packed_dim,
56*523fa7a6SAndroid Build Coastguard Worker int32_t weight_nbit) {
57*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(8 % weight_nbit == 0, "invalid embedding dim");
58*523fa7a6SAndroid Build Coastguard Worker int packed_values_per_byte = 8 / weight_nbit;
59*523fa7a6SAndroid Build Coastguard Worker return packed_dim * packed_values_per_byte;
60*523fa7a6SAndroid Build Coastguard Worker }
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker /**
63*523fa7a6SAndroid Build Coastguard Worker * Asserts that the parameters are valid.
64*523fa7a6SAndroid Build Coastguard Worker */
check_embedding_xbit_args(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const int64_t weight_quant_min,const int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out,int weight_nbit)65*523fa7a6SAndroid Build Coastguard Worker void check_embedding_xbit_args(
66*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight,
67*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales,
68*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points,
69*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_min,
70*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_max,
71*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices,
72*523fa7a6SAndroid Build Coastguard Worker exec_aten::optional<ScalarType> out_dtype,
73*523fa7a6SAndroid Build Coastguard Worker Tensor& out,
74*523fa7a6SAndroid Build Coastguard Worker int weight_nbit) {
75*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8");
76*523fa7a6SAndroid Build Coastguard Worker
77*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
78*523fa7a6SAndroid Build Coastguard Worker weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
79*523fa7a6SAndroid Build Coastguard Worker
80*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
81*523fa7a6SAndroid Build Coastguard Worker weight_scales.dim() == 1 || weight_scales.dim() == 2,
82*523fa7a6SAndroid Build Coastguard Worker "weight_scales must be 1D or 2D but got() %zd dims",
83*523fa7a6SAndroid Build Coastguard Worker weight_scales.dim());
84*523fa7a6SAndroid Build Coastguard Worker
85*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
86*523fa7a6SAndroid Build Coastguard Worker weight_scales.size(0) == weight.size(0),
87*523fa7a6SAndroid Build Coastguard Worker "Number of scales must be == weight.size(0)=%zd"
88*523fa7a6SAndroid Build Coastguard Worker ", but got %zd",
89*523fa7a6SAndroid Build Coastguard Worker weight_scales.size(0),
90*523fa7a6SAndroid Build Coastguard Worker weight.size(0));
91*523fa7a6SAndroid Build Coastguard Worker
92*523fa7a6SAndroid Build Coastguard Worker if (weight_scales.dim() == 2) {
93*523fa7a6SAndroid Build Coastguard Worker auto num_groups = weight_scales.size(1);
94*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
95*523fa7a6SAndroid Build Coastguard Worker // each 8b uint8 column is packed_values_per_byte columns
96*523fa7a6SAndroid Build Coastguard Worker get_embedding_dim(weight.size(1), weight_nbit) % num_groups == 0,
97*523fa7a6SAndroid Build Coastguard Worker "Number of groups must divide weight.size(1)=%zd"
98*523fa7a6SAndroid Build Coastguard Worker ", but got # of groups = %zd",
99*523fa7a6SAndroid Build Coastguard Worker weight.size(1),
100*523fa7a6SAndroid Build Coastguard Worker num_groups);
101*523fa7a6SAndroid Build Coastguard Worker }
102*523fa7a6SAndroid Build Coastguard Worker
103*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
104*523fa7a6SAndroid Build Coastguard Worker weight.scalar_type() == ScalarType::Byte,
105*523fa7a6SAndroid Build Coastguard Worker "weight.scalar_type() %" PRId8 " is not supported:",
106*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(weight.scalar_type()));
107*523fa7a6SAndroid Build Coastguard Worker
108*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
109*523fa7a6SAndroid Build Coastguard Worker out.scalar_type() == ScalarType::Float ||
110*523fa7a6SAndroid Build Coastguard Worker out.scalar_type() == ScalarType::Half,
111*523fa7a6SAndroid Build Coastguard Worker "out.scalar_type() %" PRId8 " is not supported:",
112*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(out.scalar_type()));
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
115*523fa7a6SAndroid Build Coastguard Worker weight_scales.scalar_type() == ScalarType::Float ||
116*523fa7a6SAndroid Build Coastguard Worker weight_scales.scalar_type() == ScalarType::Half,
117*523fa7a6SAndroid Build Coastguard Worker "weight_scales.scalar_type() %" PRId8 " is not supported:",
118*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(weight_scales.scalar_type()));
119*523fa7a6SAndroid Build Coastguard Worker
120*523fa7a6SAndroid Build Coastguard Worker if (opt_weight_zero_points.has_value()) {
121*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
122*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points.value().dim() == weight_scales.dim(),
123*523fa7a6SAndroid Build Coastguard Worker "weight_zero_points's rank match that of weight_scales. "
124*523fa7a6SAndroid Build Coastguard Worker "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
125*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(opt_weight_zero_points.value().dim()),
126*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(weight_scales.dim()));
127*523fa7a6SAndroid Build Coastguard Worker
128*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
129*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
130*523fa7a6SAndroid Build Coastguard Worker "weight zero points scalar type %" PRId8
131*523fa7a6SAndroid Build Coastguard Worker " does not match out.scalar_type()",
132*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));
133*523fa7a6SAndroid Build Coastguard Worker
134*523fa7a6SAndroid Build Coastguard Worker for (int32_t i = 0; i < weight_scales.dim(); ++i) {
135*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
136*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points.value().size(i) == weight_scales.size(i),
137*523fa7a6SAndroid Build Coastguard Worker "Dimension size misatch at dim %" PRIi32
138*523fa7a6SAndroid Build Coastguard Worker "Weight_zero_point size = %zd"
139*523fa7a6SAndroid Build Coastguard Worker ", weight_scales size = %zd.",
140*523fa7a6SAndroid Build Coastguard Worker i,
141*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points.value().size(i),
142*523fa7a6SAndroid Build Coastguard Worker weight_scales.size(i));
143*523fa7a6SAndroid Build Coastguard Worker }
144*523fa7a6SAndroid Build Coastguard Worker }
145*523fa7a6SAndroid Build Coastguard Worker
146*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
147*523fa7a6SAndroid Build Coastguard Worker indices.scalar_type() == ScalarType::Long,
148*523fa7a6SAndroid Build Coastguard Worker "indices.scalar_type() %" PRId8 " is not Long only Long is supported:",
149*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(indices.scalar_type()));
150*523fa7a6SAndroid Build Coastguard Worker
151*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
152*523fa7a6SAndroid Build Coastguard Worker weight_quant_min <= weight_quant_max,
153*523fa7a6SAndroid Build Coastguard Worker "weight quant min: %" PRId64
154*523fa7a6SAndroid Build Coastguard Worker " is greater than weight quant max: %" PRId64,
155*523fa7a6SAndroid Build Coastguard Worker weight_quant_min,
156*523fa7a6SAndroid Build Coastguard Worker weight_quant_max);
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker if (out_dtype.has_value()) {
159*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
160*523fa7a6SAndroid Build Coastguard Worker out.scalar_type() == out_dtype.value(),
161*523fa7a6SAndroid Build Coastguard Worker "output_dtype must match the dtype of the out tensor");
162*523fa7a6SAndroid Build Coastguard Worker }
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker
165*523fa7a6SAndroid Build Coastguard Worker /**
166*523fa7a6SAndroid Build Coastguard Worker * Retrieves the embeddings specified by indices, dequantizes them, and stores
167*523fa7a6SAndroid Build Coastguard Worker * them in out. Weight will always be uint8
168*523fa7a6SAndroid Build Coastguard Worker */
169*523fa7a6SAndroid Build Coastguard Worker template <typename CTYPE_PARAMS, typename CTYPE_OUT>
embedding_xbit_per_channel(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const Tensor & indices,Tensor & out,int weight_nbit)170*523fa7a6SAndroid Build Coastguard Worker void embedding_xbit_per_channel(
171*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight,
172*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales,
173*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points,
174*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices,
175*523fa7a6SAndroid Build Coastguard Worker Tensor& out,
176*523fa7a6SAndroid Build Coastguard Worker int weight_nbit) {
177*523fa7a6SAndroid Build Coastguard Worker auto embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
178*523fa7a6SAndroid Build Coastguard Worker
179*523fa7a6SAndroid Build Coastguard Worker int32_t num_groups_per_channel = 1;
180*523fa7a6SAndroid Build Coastguard Worker if (weight_scales.dim() == 2) {
181*523fa7a6SAndroid Build Coastguard Worker num_groups_per_channel = weight_scales.size(1);
182*523fa7a6SAndroid Build Coastguard Worker }
183*523fa7a6SAndroid Build Coastguard Worker int32_t group_size = embedding_dim / num_groups_per_channel;
184*523fa7a6SAndroid Build Coastguard Worker
185*523fa7a6SAndroid Build Coastguard Worker CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
186*523fa7a6SAndroid Build Coastguard Worker const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
187*523fa7a6SAndroid Build Coastguard Worker
188*523fa7a6SAndroid Build Coastguard Worker const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
189*523fa7a6SAndroid Build Coastguard Worker const CTYPE_PARAMS* zero_points = nullptr;
190*523fa7a6SAndroid Build Coastguard Worker if (opt_weight_zero_points.has_value()) {
191*523fa7a6SAndroid Build Coastguard Worker zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
192*523fa7a6SAndroid Build Coastguard Worker }
193*523fa7a6SAndroid Build Coastguard Worker
194*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < indices.numel(); i++) {
195*523fa7a6SAndroid Build Coastguard Worker int64_t index = indices_ptr[i];
196*523fa7a6SAndroid Build Coastguard Worker // If using groupwise embedding
197*523fa7a6SAndroid Build Coastguard Worker int32_t qparams_index = index * num_groups_per_channel;
198*523fa7a6SAndroid Build Coastguard Worker CTYPE_PARAMS zp = 0.0;
199*523fa7a6SAndroid Build Coastguard Worker const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
200*523fa7a6SAndroid Build Coastguard Worker const CTYPE_PARAMS* zero_points_ptr = nullptr;
201*523fa7a6SAndroid Build Coastguard Worker if (opt_weight_zero_points.has_value()) {
202*523fa7a6SAndroid Build Coastguard Worker zero_points_ptr = zero_points + qparams_index;
203*523fa7a6SAndroid Build Coastguard Worker }
204*523fa7a6SAndroid Build Coastguard Worker
205*523fa7a6SAndroid Build Coastguard Worker const uint8_t* w_data =
206*523fa7a6SAndroid Build Coastguard Worker weight.const_data_ptr<uint8_t>() + weight.size(1) * index;
207*523fa7a6SAndroid Build Coastguard Worker
208*523fa7a6SAndroid Build Coastguard Worker for (int j = 0; j < embedding_dim; ++j) {
209*523fa7a6SAndroid Build Coastguard Worker int32_t group_id = j / group_size;
210*523fa7a6SAndroid Build Coastguard Worker const CTYPE_PARAMS scale = scale_ptr[group_id];
211*523fa7a6SAndroid Build Coastguard Worker if (opt_weight_zero_points.has_value()) {
212*523fa7a6SAndroid Build Coastguard Worker zp = zero_points_ptr[group_id];
213*523fa7a6SAndroid Build Coastguard Worker }
214*523fa7a6SAndroid Build Coastguard Worker out_data[j] = static_cast<CTYPE_OUT>(
215*523fa7a6SAndroid Build Coastguard Worker (static_cast<float>(weight_value(w_data, j, weight_nbit)) -
216*523fa7a6SAndroid Build Coastguard Worker static_cast<float>(zp)) *
217*523fa7a6SAndroid Build Coastguard Worker static_cast<float>(scale));
218*523fa7a6SAndroid Build Coastguard Worker }
219*523fa7a6SAndroid Build Coastguard Worker out_data += embedding_dim;
220*523fa7a6SAndroid Build Coastguard Worker }
221*523fa7a6SAndroid Build Coastguard Worker }
222*523fa7a6SAndroid Build Coastguard Worker
resize_out_tensor(const Tensor & weight,const Tensor & indices,Tensor & out,int weight_nbit)223*523fa7a6SAndroid Build Coastguard Worker void resize_out_tensor(
224*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight,
225*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices,
226*523fa7a6SAndroid Build Coastguard Worker Tensor& out,
227*523fa7a6SAndroid Build Coastguard Worker int weight_nbit) {
228*523fa7a6SAndroid Build Coastguard Worker exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
229*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < indices.dim(); i++) {
230*523fa7a6SAndroid Build Coastguard Worker expected_output_size[i] = indices.size(i);
231*523fa7a6SAndroid Build Coastguard Worker }
232*523fa7a6SAndroid Build Coastguard Worker const size_t embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
233*523fa7a6SAndroid Build Coastguard Worker expected_output_size[out.dim() - 1] = embedding_dim;
234*523fa7a6SAndroid Build Coastguard Worker
235*523fa7a6SAndroid Build Coastguard Worker exec_aten::ArrayRef<exec_aten::SizesType> output_size{
236*523fa7a6SAndroid Build Coastguard Worker expected_output_size, static_cast<size_t>(out.dim())};
237*523fa7a6SAndroid Build Coastguard Worker
238*523fa7a6SAndroid Build Coastguard Worker torch::executor::Error err = resize_tensor(out, output_size);
239*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
240*523fa7a6SAndroid Build Coastguard Worker err == torch::executor::Error::Ok,
241*523fa7a6SAndroid Build Coastguard Worker "Failed to resize out Tensor in quantized_embedding_xbit_out");
242*523fa7a6SAndroid Build Coastguard Worker }
243*523fa7a6SAndroid Build Coastguard Worker
244*523fa7a6SAndroid Build Coastguard Worker } // namespace
245*523fa7a6SAndroid Build Coastguard Worker
246*523fa7a6SAndroid Build Coastguard Worker /**
247*523fa7a6SAndroid Build Coastguard Worker * Retrieves the embeddings specified by indices, dequantizes them, and stores
248*523fa7a6SAndroid Build Coastguard Worker * them in out. The weight is quantized per channel, with a scale and zero_point
249*523fa7a6SAndroid Build Coastguard Worker * for each embedding.
250*523fa7a6SAndroid Build Coastguard Worker *
251*523fa7a6SAndroid Build Coastguard Worker * Corresponds as the out variant to torch.ops.quantized.embedding_xbit
252*523fa7a6SAndroid Build Coastguard Worker *
253*523fa7a6SAndroid Build Coastguard Worker * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
254*523fa7a6SAndroid Build Coastguard Worker * metadata that is passed around which can be useful for pattern matching. See
255*523fa7a6SAndroid Build Coastguard Worker * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
256*523fa7a6SAndroid Build Coastguard Worker * info.
257*523fa7a6SAndroid Build Coastguard Worker */
quantized_embedding_xbit_out(const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const int64_t weight_quant_min,const int64_t weight_quant_max,const Tensor & indices,Tensor & out,int weight_nbit)258*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out(
259*523fa7a6SAndroid Build Coastguard Worker // TODO Evaluate whether this name is appropriate for an operator that takes
260*523fa7a6SAndroid Build Coastguard Worker // non quant input and returns fp output
261*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight,
262*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales,
263*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points,
264*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_min,
265*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_max,
266*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices,
267*523fa7a6SAndroid Build Coastguard Worker Tensor& out,
268*523fa7a6SAndroid Build Coastguard Worker int weight_nbit) {
269*523fa7a6SAndroid Build Coastguard Worker ScalarType out_type = out.scalar_type();
270*523fa7a6SAndroid Build Coastguard Worker
271*523fa7a6SAndroid Build Coastguard Worker // TODO (jakeszwe): improve these to account for the size of out in relation
272*523fa7a6SAndroid Build Coastguard Worker // to weight and indices accounting for a possible batch dimension
273*523fa7a6SAndroid Build Coastguard Worker check_embedding_xbit_args(
274*523fa7a6SAndroid Build Coastguard Worker weight,
275*523fa7a6SAndroid Build Coastguard Worker weight_scales,
276*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points,
277*523fa7a6SAndroid Build Coastguard Worker weight_quant_min,
278*523fa7a6SAndroid Build Coastguard Worker weight_quant_max,
279*523fa7a6SAndroid Build Coastguard Worker indices,
280*523fa7a6SAndroid Build Coastguard Worker out_type,
281*523fa7a6SAndroid Build Coastguard Worker out,
282*523fa7a6SAndroid Build Coastguard Worker weight_nbit);
283*523fa7a6SAndroid Build Coastguard Worker
284*523fa7a6SAndroid Build Coastguard Worker constexpr auto name = "quantized_decomposed::embedding_xbit.out";
285*523fa7a6SAndroid Build Coastguard Worker ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
286*523fa7a6SAndroid Build Coastguard Worker embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
287*523fa7a6SAndroid Build Coastguard Worker weight,
288*523fa7a6SAndroid Build Coastguard Worker weight_scales,
289*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points,
290*523fa7a6SAndroid Build Coastguard Worker indices,
291*523fa7a6SAndroid Build Coastguard Worker out,
292*523fa7a6SAndroid Build Coastguard Worker weight_nbit);
293*523fa7a6SAndroid Build Coastguard Worker });
294*523fa7a6SAndroid Build Coastguard Worker
295*523fa7a6SAndroid Build Coastguard Worker return out;
296*523fa7a6SAndroid Build Coastguard Worker }
297*523fa7a6SAndroid Build Coastguard Worker
quantized_embedding_xbit_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,Tensor & out,int weight_nbit)298*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out(
299*523fa7a6SAndroid Build Coastguard Worker KernelRuntimeContext& context,
300*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight,
301*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales,
302*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points,
303*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_min,
304*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_max,
305*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices,
306*523fa7a6SAndroid Build Coastguard Worker Tensor& out,
307*523fa7a6SAndroid Build Coastguard Worker int weight_nbit) {
308*523fa7a6SAndroid Build Coastguard Worker // TODO(larryliu): Add a context arg to the real op function and remove this
309*523fa7a6SAndroid Build Coastguard Worker // wrapper
310*523fa7a6SAndroid Build Coastguard Worker (void)context;
311*523fa7a6SAndroid Build Coastguard Worker resize_out_tensor(weight, indices, out, weight_nbit);
312*523fa7a6SAndroid Build Coastguard Worker return quantized_embedding_xbit_out(
313*523fa7a6SAndroid Build Coastguard Worker weight,
314*523fa7a6SAndroid Build Coastguard Worker weight_scales,
315*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points,
316*523fa7a6SAndroid Build Coastguard Worker weight_quant_min,
317*523fa7a6SAndroid Build Coastguard Worker weight_quant_max,
318*523fa7a6SAndroid Build Coastguard Worker indices,
319*523fa7a6SAndroid Build Coastguard Worker out,
320*523fa7a6SAndroid Build Coastguard Worker weight_nbit);
321*523fa7a6SAndroid Build Coastguard Worker }
322*523fa7a6SAndroid Build Coastguard Worker
/**
 * dtype_out variant of the x-bit embedding lookup.
 *
 * Validates the arguments (including that `out_dtype`, when given, equals the
 * dtype of `out`), then dispatches the dequantization kernel on the dtype of
 * weight_scales and the dtype of `out` independently; each may be Float or
 * Half. This overload does not resize `out`.
 *
 * @return `out`, filled with the dequantized embeddings.
 */
Tensor& quantized_embedding_xbit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_xbit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      weight_nbit);

  const ScalarType scales_dtype = weight_scales.scalar_type();
  const ScalarType result_dtype = out.scalar_type();

  constexpr auto op_name = "quantized_decomposed::embedding_xbit.dtype_out";
  // Outer switch selects the parameter element type, inner switch the output
  // element type.
  ET_SWITCH_TWO_TYPES(Float, Half, scales_dtype, ctx, op_name, PARAM_T, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, result_dtype, ctx, op_name, OUT_T, [&]() {
      embedding_xbit_per_channel<PARAM_T, OUT_T>(
          weight,
          weight_scales,
          opt_weight_zero_points,
          indices,
          out,
          weight_nbit);
    });
  });

  return out;
}
366*523fa7a6SAndroid Build Coastguard Worker
quantized_embedding_xbit_dtype_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out,int weight_nbit)367*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out(
368*523fa7a6SAndroid Build Coastguard Worker KernelRuntimeContext& context,
369*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight,
370*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales,
371*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points,
372*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_min,
373*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_max,
374*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices,
375*523fa7a6SAndroid Build Coastguard Worker exec_aten::optional<ScalarType> out_dtype,
376*523fa7a6SAndroid Build Coastguard Worker Tensor& out,
377*523fa7a6SAndroid Build Coastguard Worker int weight_nbit) {
378*523fa7a6SAndroid Build Coastguard Worker // TODO(larryliu): Add a context arg to the real op function and remove this
379*523fa7a6SAndroid Build Coastguard Worker // wrapper
380*523fa7a6SAndroid Build Coastguard Worker (void)context;
381*523fa7a6SAndroid Build Coastguard Worker resize_out_tensor(weight, indices, out, weight_nbit);
382*523fa7a6SAndroid Build Coastguard Worker return quantized_embedding_xbit_dtype_out(
383*523fa7a6SAndroid Build Coastguard Worker weight,
384*523fa7a6SAndroid Build Coastguard Worker weight_scales,
385*523fa7a6SAndroid Build Coastguard Worker opt_weight_zero_points,
386*523fa7a6SAndroid Build Coastguard Worker weight_quant_min,
387*523fa7a6SAndroid Build Coastguard Worker weight_quant_max,
388*523fa7a6SAndroid Build Coastguard Worker indices,
389*523fa7a6SAndroid Build Coastguard Worker out_dtype,
390*523fa7a6SAndroid Build Coastguard Worker out,
391*523fa7a6SAndroid Build Coastguard Worker weight_nbit);
392*523fa7a6SAndroid Build Coastguard Worker }
393*523fa7a6SAndroid Build Coastguard Worker
394*523fa7a6SAndroid Build Coastguard Worker } // namespace native
395*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
396*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
397