xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_embedding2b.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/kernels/quantized/cpu/embeddingxb.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <algorithm>
12 #include <cinttypes>
13 #include <cmath>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 using Scalar = exec_aten::Scalar;
21 using ScalarType = exec_aten::ScalarType;
22 
/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and zero_point
 * for each embedding.
 *
 * Corresponds as the out variant to torch.ops.quantized.embedding_2bit
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
 * metadata that is passed around which can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
 */
Tensor& quantized_embedding_2bit_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  // Thin wrapper: all lookup/dequantization logic lives in the shared x-bit
  // implementation (embeddingxb.h); the trailing 2 selects 2-bit weights.
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      2);
}
55 
quantized_embedding_2bit_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,Tensor & out)56 Tensor& quantized_embedding_2bit_out(
57     KernelRuntimeContext& context,
58     const Tensor& weight,
59     const Tensor& weight_scales,
60     const exec_aten::optional<Tensor>& opt_weight_zero_points,
61     int64_t weight_quant_min,
62     int64_t weight_quant_max,
63     const Tensor& indices,
64     Tensor& out) {
65   return quantized_embedding_xbit_out(
66       context,
67       weight,
68       weight_scales,
69       opt_weight_zero_points,
70       weight_quant_min,
71       weight_quant_max,
72       indices,
73       out,
74       2);
75 }
76 
/**
 * Like quantized_embedding_2bit_out, but additionally accepts an optional
 * out_dtype that is forwarded to the shared x-bit implementation
 * (embeddingxb.h). Delegates with a bit width of 2; this wrapper performs no
 * computation of its own.
 */
Tensor& quantized_embedding_2bit_dtype_out(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      2);
}
97 
quantized_embedding_2bit_dtype_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out)98 Tensor& quantized_embedding_2bit_dtype_out(
99     KernelRuntimeContext& context,
100     const Tensor& weight,
101     const Tensor& weight_scales,
102     const exec_aten::optional<Tensor>& opt_weight_zero_points,
103     int64_t weight_quant_min,
104     int64_t weight_quant_max,
105     const Tensor& indices,
106     exec_aten::optional<ScalarType> out_dtype,
107     Tensor& out) {
108   return quantized_embedding_xbit_dtype_out(
109       context,
110       weight,
111       weight_scales,
112       opt_weight_zero_points,
113       weight_quant_min,
114       weight_quant_max,
115       indices,
116       out_dtype,
117       out,
118       2);
119 }
120 
121 } // namespace native
122 } // namespace executor
123 } // namespace torch
124