xref: /aosp_15_r20/external/executorch/kernels/quantized/test/op_embedding2b_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/test/utils/DeathTest.h>
15 
16 #include <gtest/gtest.h>
17 #include <limits>
18 
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using executorch::runtime::KernelRuntimeContext;
25 using torch::executor::native::quantized_embedding_2bit_out;
26 
27 using torch::executor::testing::TensorFactory;
28 
TEST(OpQuantizedEmbedding2bTest,TestGroupWiseQuantizedEmbedding)29 TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) {
30   et_pal_init();
31   TensorFactory<ScalarType::Byte> tfb;
32   TensorFactory<ScalarType::Float> tf;
33   TensorFactory<ScalarType::Long> tfl;
34 
35   int64_t quant_min = -2;
36   int64_t quant_max = 1;
37 
38   Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
39   Tensor weight_zero_points = tf.make({3}, {1, -2, 0});
40 
41   // -2,  1,  0, 1, -> 0, 3, 2, 3 -> (reverse) 11 10 11 00 -> 236
42   //  0, -1, -2, 0, -> 2, 1, 0, 2 -> (reverse) 10 00 01 10 -> 134
43   // -2,  -1, 0, 1, -> 0, 1, 2, 3 -> (reverse) 11 10 01 00 -> 228
44 
45   Tensor qweight = tfb.make({3, 1}, {236, 134, 228});
46 
47   Tensor indices = tfl.make({3}, {0, 2, 1});
48 
49   Tensor out = tf.zeros({3, 4});
50   Tensor expected = tf.make(
51       {3, 4}, {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, 2.0, 1.0, 0.0, 2.0});
52 
53   quantized_embedding_2bit_out(
54       qweight,
55       weight_scales,
56       weight_zero_points,
57       quant_min,
58       quant_max,
59       indices,
60       out);
61 
62   EXPECT_TENSOR_EQ(out, expected);
63 
64   out = tf.zeros({3, 4});
65   auto context = KernelRuntimeContext();
66   torch::executor::native::quantized_embedding_2bit_out(
67       context,
68       qweight,
69       weight_scales,
70       weight_zero_points,
71       quant_min,
72       quant_max,
73       indices,
74       out);
75 
76   EXPECT_TENSOR_EQ(out, expected);
77 
78   // Groupwise quantization. groupsize = 2
79 
80   weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
81   weight_zero_points = tf.make({3, 2}, {1, -2, 0, 1, -2, -1});
82 
83   // -2,  1,  0, 1, -> 0, 3, 2, 3 -> (reverse) 11 10 11 00 -> 236
84   //  0, -1, -2, 0, -> 2, 1, 0, 2 -> (reverse) 10 00 01 10 -> 134
85   // -2,  -1, 0, 1, -> 0, 1, 2, 3 -> (reverse) 11 10 01 00 -> 228
86 
87   qweight = tfb.make({3, 1}, {236, 134, 228});
88 
89   indices = tfl.make({3}, {0, 2, 1});
90 
91   out = tf.zeros({3, 4});
92   expected = tf.make(
93       {3, 4}, {-1.5, 0.0, 2.0, 3.0, 0.0, 2.5, 3.0, 6.0, 0.0, -1.5, -6.0, -2.0});
94 
95   quantized_embedding_2bit_out(
96       qweight,
97       weight_scales,
98       weight_zero_points,
99       quant_min,
100       quant_max,
101       indices,
102       out);
103 
104   EXPECT_TENSOR_EQ(out, expected);
105 }
106 
// Expects quantized_embedding_2bit_out to abort when there are MORE per-row
// scales / zero points (4) than rows in the packed weight table (3).
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Long> long_factory;

  const int64_t qmin = -2;
  const int64_t qmax = 1;

  // Three rows of packed 2-bit values, a 3-element lookup, and a matching
  // output buffer; these are all mutually consistent.
  Tensor packed_weight = byte_factory.make({3, 1}, {236, 134, 228});
  Tensor lookup = long_factory.make({3}, {0, 2, 1});
  Tensor result = float_factory.zeros({3, 4});

  // The quantization parameters are the inconsistent part: 4 entries for a
  // 3-row table.
  Tensor scales = float_factory.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor zero_points = float_factory.make({4}, {1, -2, 1, 0});

  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          packed_weight, scales, zero_points, qmin, qmax, lookup, result),
      "");
}
133 }
134 
// Expects quantized_embedding_2bit_out to abort when there are FEWER per-row
// scales / zero points (2) than rows in the packed weight table (3).
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Long> long_factory;

  const int64_t qmin = -2;
  const int64_t qmax = 1;

  // Three rows of packed 2-bit values, a 3-element lookup, and a matching
  // output buffer; these are all mutually consistent.
  Tensor packed_weight = byte_factory.make({3, 1}, {236, 134, 228});
  Tensor lookup = long_factory.make({3}, {0, 2, 1});
  Tensor result = float_factory.zeros({3, 4});

  // The quantization parameters are the inconsistent part: only 2 entries
  // for a 3-row table.
  Tensor scales = float_factory.make({2}, {0.5, 1.0});
  Tensor zero_points = float_factory.make({2}, {1, -2});

  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          packed_weight, scales, zero_points, qmin, qmax, lookup, result),
      "");
}
161 }
162 
// Expects quantized_embedding_2bit_out to abort when the number of
// quantization groups per row does not evenly divide the embedding width.
// All other inputs are kept mutually consistent so that the group-count
// mismatch is the only thing that can trigger the death.
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath3) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tfb;
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  int64_t quant_min = -2;
  int64_t quant_max = 1;

  // 2 rows x 3 groups of quantization parameters.
  Tensor weight_scales = tf.make({2, 3}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  Tensor weight_zero_points = tf.make({2, 3}, {0, 0, 0, 0, 0, 0});
  // 2 rows x 2 bytes; each byte packs four 2-bit values, so each row holds 8
  // values. (Previously this was {2, 1}, i.e. only 4 values per row, which
  // also disagreed with the 8-wide output below.)
  Tensor qweight = tfb.make({2, 2}, {236, 134, 228, 236});
  // Both indices are in range for the 2-row table. (Previously this was
  // {0, 2}, which referenced a nonexistent third row.)
  Tensor indices = tfl.make({2}, {0, 1});
  Tensor out = tf.zeros({2, 8});

  // scales/zeros imply 3 groups, which does not divide the embedding
  // dimension implied by qweight (8).
  ET_EXPECT_DEATH(
      quantized_embedding_2bit_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
190 }
191