1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/test/utils/DeathTest.h>
15
16 #include <gtest/gtest.h>
17 #include <limits>
18
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using executorch::runtime::KernelRuntimeContext;
25 using torch::executor::native::quantized_embedding_2bit_out;
26
27 using torch::executor::testing::TensorFactory;
28
// Checks quantized_embedding_2bit_out against hand-computed dequantized rows:
// first with one scale/zero point per table row (through both the overload
// without a context and the one taking a KernelRuntimeContext), then with two
// quantization groups per row.
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const int64_t qmin = -2;
  const int64_t qmax = 1;

  // ---- Per-row quantization: one scale / zero point per table row ----
  Tensor row_scales = float_factory.make({3}, {0.5, 1.0, 1.5});
  Tensor row_zero_points = float_factory.make({3}, {1, -2, 0});

  // Each table row is a single byte holding four 2-bit values. A value q is
  // stored as q + 2, and the first value of a row sits in the lowest bits:
  //   row 0: {-2,  1,  0, 1} -> {0, 3, 2, 3} -> 11 10 11 00 -> 236
  //   row 1: { 0, -1, -2, 0} -> {2, 1, 0, 2} -> 10 00 01 10 -> 134
  //   row 2: {-2, -1,  0, 1} -> {0, 1, 2, 3} -> 11 10 01 00 -> 228
  Tensor packed_table = byte_factory.make({3, 1}, {236, 134, 228});

  // Rows are looked up in the order 0, 2, 1, so the expected output lists the
  // dequantized rows ((q - zero_point) * scale) in that same order.
  Tensor lookup = index_factory.make({3}, {0, 2, 1});

  Tensor result = float_factory.zeros({3, 4});
  Tensor per_row_expected = float_factory.make(
      {3, 4},
      {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, 2.0, 1.0, 0.0, 2.0});

  // Overload without a runtime context.
  quantized_embedding_2bit_out(
      packed_table, row_scales, row_zero_points, qmin, qmax, lookup, result);

  EXPECT_TENSOR_EQ(result, per_row_expected);

  // Same inputs through the overload that takes an explicit context; start
  // from a freshly zeroed output so the previous result cannot leak through.
  Tensor result_with_context = float_factory.zeros({3, 4});
  KernelRuntimeContext runtime_context{};
  torch::executor::native::quantized_embedding_2bit_out(
      runtime_context,
      packed_table,
      row_scales,
      row_zero_points,
      qmin,
      qmax,
      lookup,
      result_with_context);

  EXPECT_TENSOR_EQ(result_with_context, per_row_expected);

  // ---- Group-wise quantization: group size 2, i.e. two groups per row ----
  Tensor group_scales =
      float_factory.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
  Tensor group_zero_points =
      float_factory.make({3, 2}, {1, -2, 0, 1, -2, -1});

  // Same packed rows and lookup order as in the per-row case above.
  Tensor group_packed_table = byte_factory.make({3, 1}, {236, 134, 228});
  Tensor group_lookup = index_factory.make({3}, {0, 2, 1});

  Tensor group_result = float_factory.zeros({3, 4});
  Tensor group_expected = float_factory.make(
      {3, 4},
      {-1.5, 0.0, 2.0, 3.0, 0.0, 2.5, 3.0, 6.0, 0.0, -1.5, -6.0, -2.0});

  quantized_embedding_2bit_out(
      group_packed_table,
      group_scales,
      group_zero_points,
      qmin,
      qmax,
      group_lookup,
      group_result);

  EXPECT_TENSOR_EQ(group_result, group_expected);
}
106
// Death test: four scales / zero points are supplied for a table that has
// only three rows, and the kernel call is expected to abort.
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const int64_t qmin = -2;
  const int64_t qmax = 1;

  // One entry too many compared to the three rows of the packed table below.
  Tensor scales = float_factory.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor zero_points = float_factory.make({4}, {1, -2, 1, 0});

  Tensor packed_table = byte_factory.make({3, 1}, {236, 134, 228});
  Tensor lookup = index_factory.make({3}, {0, 2, 1});
  Tensor result = float_factory.zeros({3, 4});

  // The shape of the quantized values is incompatible with scales/zeros.
  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          packed_table, scales, zero_points, qmin, qmax, lookup, result),
      "");
}
134
// Death test: only two scales / zero points are supplied for a table that has
// three rows, and the kernel call is expected to abort.
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const int64_t qmin = -2;
  const int64_t qmax = 1;

  // One entry too few compared to the three rows of the packed table below.
  Tensor scales = float_factory.make({2}, {0.5, 1.0});
  Tensor zero_points = float_factory.make({2}, {1, -2});

  Tensor packed_table = byte_factory.make({3, 1}, {236, 134, 228});
  Tensor lookup = index_factory.make({3}, {0, 2, 1});
  Tensor result = float_factory.zeros({3, 4});

  // The shape of the quantized values is incompatible with scales/zeros.
  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          packed_table, scales, zero_points, qmin, qmax, lookup, result),
      "");
}
162
// Death test: the number of quantization groups implied by scales/zeros does
// not evenly divide the embedding dimension, and the kernel call is expected
// to abort.
//
// Every other input is kept consistent (in-range indices, output width equal
// to the unpacked row width) so that the group-count mismatch is the only
// invalid aspect of the call.
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath3) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tfb;
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  int64_t quant_min = -2;
  int64_t quant_max = 1;

  // Three groups per row.
  Tensor weight_scales = tf.make({2, 3}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  Tensor weight_zero_points = tf.make({2, 3}, {0, 0, 0, 0, 0, 0});

  // One packed byte per row, i.e. four 2-bit values per row.
  Tensor qweight = tfb.make({2, 1}, {236, 134});

  // Both indices address existing rows of the 2-row table.
  Tensor indices = tfl.make({2}, {0, 1});

  // Output width matches the four unpacked values per row.
  Tensor out = tf.zeros({2, 4});

  // scales/zeros imply 3 groups, which does not divide embed dimension from
  // qvals (4)
  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
191