xref: /aosp_15_r20/external/executorch/kernels/quantized/test/op_embedding4b_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/test/utils/DeathTest.h>
15 
16 #include <gtest/gtest.h>
17 #include <limits>
18 
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using executorch::runtime::KernelRuntimeContext;
25 using torch::executor::native::quantized_embedding_4bit_out;
26 
27 using torch::executor::testing::TensorFactory;
28 
TEST(OpQuantizedEmbedding4bTest,TestGroupWiseQuantizedEmbedding)29 TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) {
30   et_pal_init();
31   TensorFactory<ScalarType::Byte> tfb;
32   TensorFactory<ScalarType::Float> tf;
33   TensorFactory<ScalarType::Long> tfl;
34 
35   int64_t quant_min = -8;
36   int64_t quant_max = 7;
37 
38   Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
39   Tensor weight_zero_points = tf.make({3}, {1, -5, 0});
40 
41   // -3,  1,  6, 7,
42   //  2, -5, -4, 0,
43   // -8,  3, -1, 6,
44 
45   Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
46 
47   Tensor indices = tfl.make({3}, {0, 2, 1});
48 
49   Tensor out = tf.zeros({3, 4});
50   Tensor expected = tf.make(
51       {3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0});
52 
53   quantized_embedding_4bit_out(
54       qweight,
55       weight_scales,
56       weight_zero_points,
57       quant_min,
58       quant_max,
59       indices,
60       out);
61 
62   EXPECT_TENSOR_EQ(out, expected);
63 
64   out = tf.zeros({3, 4});
65   auto context = KernelRuntimeContext();
66   torch::executor::native::quantized_embedding_4bit_out(
67       context,
68       qweight,
69       weight_scales,
70       weight_zero_points,
71       quant_min,
72       quant_max,
73       indices,
74       out);
75 
76   EXPECT_TENSOR_EQ(out, expected);
77 
78   // Groupwise quantization. groupsize = 2
79   weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
80   weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});
81   /*
82   fp_weight = [-2.0,  0.0,  11.0, 12.0,
83                 3.0, -7.5, -12.0, -4.0,
84               -12.5, 15.0,   0.0, 21.0]
85   */
86 
87   out = tf.zeros({3, 4});
88   expected = tf.make(
89       {3, 4},
90       {-2.0, 0.0, 11.0, 12.0, -12.5, 15.0, 0.0, 21.0, 3.0, -7.5, -12.0, -4.0});
91 
92   quantized_embedding_4bit_out(
93       qweight,
94       weight_scales,
95       weight_zero_points,
96       quant_min,
97       quant_max,
98       indices,
99       out);
100 
101   EXPECT_TENSOR_EQ(out, expected);
102 }
103 
// Expects the kernel call to abort when scales / zero points have four entries
// while the packed weight has three rows.
// NOTE(review): the size mismatch is presumably what triggers the failure;
// the empty matcher string accepts any death message, so this is not pinned.
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> long_factory;

  const int64_t quant_min = -8;
  const int64_t quant_max = 7;

  // Same 3x2 packed weight and lookup order as the passing test.
  Tensor qweight = byte_factory.make({3, 2}, {89, 239, 163, 72, 11, 126});
  Tensor indices = long_factory.make({3}, {0, 2, 1});

  // One entry more than the number of weight rows.
  Tensor oversized_scales = float_factory.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor oversized_zero_points = float_factory.make({4}, {1, 5, 7, 5});

  Tensor out = float_factory.zeros({3, 4});
  ET_EXPECT_DEATH(
      quantized_embedding_4bit_out(
          qweight,
          oversized_scales,
          oversized_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
130 
// Expects the kernel call to abort when scales / zero points have only two
// entries while the packed weight has three rows.
// NOTE(review): the size mismatch is presumably what triggers the failure;
// the empty matcher string accepts any death message, so this is not pinned.
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> long_factory;

  const int64_t quant_min = -8;
  const int64_t quant_max = 7;

  // Same 3x2 packed weight and lookup order as the passing test.
  Tensor qweight = byte_factory.make({3, 2}, {89, 239, 163, 72, 11, 126});
  Tensor indices = long_factory.make({3}, {0, 2, 1});

  // One entry fewer than the number of weight rows.
  Tensor undersized_scales = float_factory.make({2}, {0.5, 1.0});
  Tensor undersized_zero_points = float_factory.make({2}, {1, 5});

  Tensor out = float_factory.zeros({3, 4});
  ET_EXPECT_DEATH(
      quantized_embedding_4bit_out(
          qweight,
          undersized_scales,
          undersized_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
157