1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/quantized/cpu/embeddingxb.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <algorithm>
12 #include <cinttypes>
13 #include <cmath>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
19 using Tensor = exec_aten::Tensor;
20 using Scalar = exec_aten::Scalar;
21 using ScalarType = exec_aten::ScalarType;
22
23 /**
24 * Retrieves the embeddings specified by indices, dequantizes them, and stores
25 * them in out. The weight is quantized per channel, with a scale and zero_point
26 * for each embedding.
27 *
28 * Corresponds as the out variant to torch.ops.quantized.embedding_2bit
29 *
30 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
31 * metadata that is passed around which can be useful for pattern matching. See
32 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
33 * info.
34 */
// TODO Evaluate whether this name is appropriate for an operator that takes
// non quant input and returns fp output
/**
 * Context-free out variant. Delegates to the shared x-bit kernel with
 * weight_nbit = 2 (see the file-level comment above for semantics of
 * quant_min/quant_max, which are metadata only).
 */
Tensor& quantized_embedding_2bit_out(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    // NOTE: top-level const dropped from the quant bounds to match the
    // signatures of the sibling overloads below; this does not change the
    // function's type or ABI.
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  // All 2-bit entry points funnel into the generic x-bit implementation;
  // the trailing argument selects the bit width.
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      2);
}
55
quantized_embedding_2bit_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,Tensor & out)56 Tensor& quantized_embedding_2bit_out(
57 KernelRuntimeContext& context,
58 const Tensor& weight,
59 const Tensor& weight_scales,
60 const exec_aten::optional<Tensor>& opt_weight_zero_points,
61 int64_t weight_quant_min,
62 int64_t weight_quant_max,
63 const Tensor& indices,
64 Tensor& out) {
65 return quantized_embedding_xbit_out(
66 context,
67 weight,
68 weight_scales,
69 opt_weight_zero_points,
70 weight_quant_min,
71 weight_quant_max,
72 indices,
73 out,
74 2);
75 }
76
/**
 * Context-free dtype variant: like quantized_embedding_2bit_out, but lets the
 * caller request an explicit output ScalarType via out_dtype. Delegates to
 * the shared x-bit dtype kernel with weight_nbit = 2.
 */
Tensor& quantized_embedding_2bit_dtype_out(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // Trailing argument selects the 2-bit width in the generic implementation.
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      2);
}
97
quantized_embedding_2bit_dtype_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out)98 Tensor& quantized_embedding_2bit_dtype_out(
99 KernelRuntimeContext& context,
100 const Tensor& weight,
101 const Tensor& weight_scales,
102 const exec_aten::optional<Tensor>& opt_weight_zero_points,
103 int64_t weight_quant_min,
104 int64_t weight_quant_max,
105 const Tensor& indices,
106 exec_aten::optional<ScalarType> out_dtype,
107 Tensor& out) {
108 return quantized_embedding_xbit_dtype_out(
109 context,
110 weight,
111 weight_scales,
112 opt_weight_zero_points,
113 weight_quant_min,
114 weight_quant_max,
115 indices,
116 out_dtype,
117 out,
118 2);
119 }
120
121 } // namespace native
122 } // namespace executor
123 } // namespace torch
124