1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/runtime/platform/runtime.h>
15 #include <executorch/test/utils/DeathTest.h>
16
17 #include <gtest/gtest.h>
18 #include <limits>
19
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::optional;
23 using exec_aten::Scalar;
24 using exec_aten::ScalarType;
25 using exec_aten::Tensor;
26 using torch::executor::native::dequantize_per_channel_out;
27 using torch::executor::native::dequantize_per_tensor_out;
28 using torch::executor::native::dequantize_per_tensor_tensor_args_out;
29 using torch::executor::testing::TensorFactory;
30
31 /// A generic smoke test that works for any dtype that supports ones() and
32 /// zeros().
33 template <ScalarType DTYPE>
test_dtype()34 void test_dtype() {
35 TensorFactory<DTYPE> tf;
36
37 Tensor input = tf.full({3, 5}, 100);
38 double scale = 0.5;
39 int64_t zero_point = 30;
40 int64_t quant_min = 0;
41 int64_t quant_max = 255;
42
43 TensorFactory<ScalarType::Float> tfo;
44 Tensor out = tfo.zeros({3, 5});
45 // (100 - 30) * 0.5
46 Tensor expected = tfo.full({3, 5}, 35);
47 dequantize_per_tensor_out(
48 input,
49 scale,
50 zero_point,
51 quant_min,
52 quant_max,
53 DTYPE,
54 optional<ScalarType>(),
55 out);
56
57 EXPECT_TENSOR_EQ(out, expected);
58 }
59
TEST(OpDequantizeOutTest,AllDtypesSupported)60 TEST(OpDequantizeOutTest, AllDtypesSupported) {
61 et_pal_init();
62 test_dtype<ScalarType::Byte>();
63 test_dtype<ScalarType::Char>();
64 test_dtype<ScalarType::Short>();
65 test_dtype<ScalarType::Bits16>();
66 test_dtype<ScalarType::UInt16>();
67 test_dtype<ScalarType::Int>();
68 }
69
/// Per-tensor dequantization whose result is not a whole number:
/// a Byte input of 100 with scale 0.45 and zero_point 30 must dequantize to
/// (100 - 30) * 0.45 = 31.5.
TEST(OpDequantizeOutTest, NonWholeNumbers) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tf;

  Tensor input = tf.full({3, 5}, 100);
  const double scale = 0.45;
  const int64_t zero_point = 30;
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  TensorFactory<ScalarType::Float> tfo;
  Tensor out = tfo.zeros({3, 5});
  // (100 - 30) * 0.45 = 31.5
  Tensor expected = tfo.full({3, 5}, 31.5);
  dequantize_per_tensor_out(
      input,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);

  EXPECT_TENSOR_EQ(out, expected);
}
96
/// Same computation as NonWholeNumbers, but exercises the overload that takes
/// scale and zero_point as 1-element tensors (Double and Long respectively)
/// instead of scalars: (100 - 30) * 0.45 = 31.5.
TEST(OpDequantizeOutTest, TensorArgOverload) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_byte.full({3, 5}, 100);
  Tensor scale = tf_double.make({1}, {0.45});
  Tensor zero_point = tf_long.make({1}, {30});
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  TensorFactory<ScalarType::Float> tfo;
  Tensor out = tfo.zeros({3, 5});
  // (100 - 30) * 0.45 = 31.5
  Tensor expected = tfo.full({3, 5}, 31.5);
  dequantize_per_tensor_tensor_args_out(
      input,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);

  EXPECT_TENSOR_EQ(out, expected);
}
125
/// Per-channel dequantization: scale[i] and zero_point[i] are applied to the
/// slice at index i along `axis` (as the expected tensors below encode).
/// Covers axis=1 and axis=0 on a 3x2 input, then axis=0 on a 1-D input.
TEST(OpDequantizeOutTest, DequantizePerChannel) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_byte.full({3, 2}, 100);
  Tensor scale = tf_double.make({2}, {0.5, 1});
  Tensor zero_point = tf_long.make({2}, {30, 60});
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  TensorFactory<ScalarType::Float> tfo;
  Tensor out = tfo.zeros({3, 2});
  // axis=1 on a 3x2 input: one (scale, zero_point) pair per column.
  // column 0: (100 - 30) * 0.5 = 35
  // column 1: (100 - 60) * 1   = 40
  Tensor expected = tfo.make({3, 2}, {35, 40, 35, 40, 35, 40});
  dequantize_per_channel_out(
      input,
      scale,
      zero_point,
      /*axis=*/1,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // axis=0 on the same 3x2 input: one (scale, zero_point) pair per row.
  out = tfo.zeros({3, 2});
  scale = tf_double.make({3}, {0.5, 0.75, 1});
  zero_point = tf_long.make({3}, {30, 50, 60});
  // row 0: (100 - 30) * 0.5  = 35
  // row 1: (100 - 50) * 0.75 = 37.5
  // row 2: (100 - 60) * 1    = 40
  expected = tfo.make({3, 2}, {35, 35, 37.5, 37.5, 40, 40});
  dequantize_per_channel_out(
      input,
      scale,
      zero_point,
      /*axis=*/0,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // 1-D input with axis=0: one (scale, zero_point) pair per element.
  out = tfo.zeros({3});
  input = tf_byte.make({3}, {100, 100, 100});
  scale = tf_double.make({3}, {0.5, 0.75, 1});
  zero_point = tf_long.make({3}, {30, 50, 60});
  // element 0: (100 - 30) * 0.5  = 35
  // element 1: (100 - 50) * 0.75 = 37.5
  // element 2: (100 - 60) * 1    = 40
  expected = tfo.make({3}, {35, 37.5, 40});
  dequantize_per_channel_out(
      input,
      scale,
      zero_point,
      /*axis=*/0,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);
  EXPECT_TENSOR_EQ(out, expected);
}
198