1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/test/utils/DeathTest.h>
15
16 #include <gtest/gtest.h>
17 #include <limits>
18
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using executorch::runtime::KernelRuntimeContext;
25 using torch::executor::native::quantized_embedding_4bit_out;
26
27 using torch::executor::testing::TensorFactory;
28
// Checks quantized_embedding_4bit_out for per-row (one scale / zero point per
// row) and groupwise (group size 2) quantization, through both the plain and
// the KernelRuntimeContext overloads.
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tfb;
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  int64_t quant_min = -8;
  int64_t quant_max = 7;

  // Per-row quantization: one scale and one zero point for each of the three
  // rows of qweight. Expected outputs below equal (q - zero_point) * scale.
  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
  Tensor weight_zero_points = tf.make({3}, {1, -5, 0});

  // qweight packs two 4-bit values per byte: the high nibble holds the first
  // value and the low nibble the second, each stored as (q + 8).
  // e.g. 89 == 0x59 -> (5 - 8, 9 - 8) == (-3, 1). Unpacked rows:
  //   row 0: -3,  1,  6,  7
  //   row 1:  2, -5, -4,  0
  //   row 2: -8,  3, -1,  6
  Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});

  // Gather rows 0, 2 and 1, in that order.
  Tensor indices = tfl.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 4});
  Tensor expected = tf.make(
      {3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0});

  quantized_embedding_4bit_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // Same inputs through the overload that takes a KernelRuntimeContext.
  out = tf.zeros({3, 4});
  auto context = KernelRuntimeContext();
  torch::executor::native::quantized_embedding_4bit_out(
      context,
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // Groupwise quantization, groupsize = 2: each row of four values uses two
  // (scale, zero_point) pairs, one for columns 0-1 and one for columns 2-3.
  weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
  weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});
  /*
  Dequantized table before the gather:
  fp_weight = [-2.0, 0.0, 11.0, 12.0,
               3.0, -7.5, -12.0, -4.0,
               -12.5, 15.0, 0.0, 21.0]
  */

  out = tf.zeros({3, 4});
  expected = tf.make(
      {3, 4},
      {-2.0, 0.0, 11.0, 12.0, -12.5, 15.0, 0.0, 21.0, 3.0, -7.5, -12.0, -4.0});

  quantized_embedding_4bit_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // The groupwise path must also agree when invoked through the context
  // overload; previously only the per-row path was checked this way.
  out = tf.zeros({3, 4});
  torch::executor::native::quantized_embedding_4bit_out(
      context,
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);
}
103
// Expects quantized_embedding_4bit_out to abort when given one-dimensional
// scales / zero points whose length (4) differs from the number of rows in
// the packed weight table (3).
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();

  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Long> long_factory;

  // Packed 4-bit table: 3 rows x 2 bytes (two 4-bit values per byte).
  Tensor packed_weight =
      byte_factory.make({3, 2}, {89, 239, 163, 72, 11, 126});
  Tensor row_ids = long_factory.make({3}, {0, 2, 1});

  // Four scales / zero points for a table that only has three rows.
  Tensor scales = float_factory.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor zero_points = float_factory.make({4}, {1, 5, 7, 5});

  const int64_t lo = -8;
  const int64_t hi = 7;

  Tensor result = float_factory.zeros({3, 4});

  // NOTE(review): presumably rejected because scales.size(0) != the row
  // count of packed_weight; the empty pattern accepts any death message.
  ET_EXPECT_DEATH(
      quantized_embedding_4bit_out(
          packed_weight, scales, zero_points, lo, hi, row_ids, result),
      "");
}
130
// Expects quantized_embedding_4bit_out to abort when given one-dimensional
// scales / zero points whose length (2) is smaller than the number of rows
// in the packed weight table (3).
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();

  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Long> long_factory;

  // Packed 4-bit table: 3 rows x 2 bytes (two 4-bit values per byte).
  Tensor packed_weight =
      byte_factory.make({3, 2}, {89, 239, 163, 72, 11, 126});
  Tensor row_ids = long_factory.make({3}, {0, 2, 1});

  // Only two scales / zero points for a table with three rows.
  Tensor scales = float_factory.make({2}, {0.5, 1.0});
  Tensor zero_points = float_factory.make({2}, {1, 5});

  const int64_t lo = -8;
  const int64_t hi = 7;

  Tensor result = float_factory.zeros({3, 4});

  // NOTE(review): presumably rejected because scales.size(0) != the row
  // count of packed_weight; the empty pattern accepts any death message.
  ET_EXPECT_DEATH(
      quantized_embedding_4bit_out(
          packed_weight, scales, zero_points, lo, hi, row_ids, result),
      "");
}
157