xref: /aosp_15_r20/external/executorch/kernels/quantized/test/op_dequantize_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/runtime/platform/runtime.h>
15 #include <executorch/test/utils/DeathTest.h>
16 
17 #include <gtest/gtest.h>
18 #include <limits>
19 
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::optional;
23 using exec_aten::Scalar;
24 using exec_aten::ScalarType;
25 using exec_aten::Tensor;
26 using torch::executor::native::dequantize_per_channel_out;
27 using torch::executor::native::dequantize_per_tensor_out;
28 using torch::executor::native::dequantize_per_tensor_tensor_args_out;
29 using torch::executor::testing::TensorFactory;
30 
31 /// A generic smoke test that works for any dtype that supports ones() and
32 /// zeros().
33 template <ScalarType DTYPE>
test_dtype()34 void test_dtype() {
35   TensorFactory<DTYPE> tf;
36 
37   Tensor input = tf.full({3, 5}, 100);
38   double scale = 0.5;
39   int64_t zero_point = 30;
40   int64_t quant_min = 0;
41   int64_t quant_max = 255;
42 
43   TensorFactory<ScalarType::Float> tfo;
44   Tensor out = tfo.zeros({3, 5});
45   // (100 - 30) * 0.5
46   Tensor expected = tfo.full({3, 5}, 35);
47   dequantize_per_tensor_out(
48       input,
49       scale,
50       zero_point,
51       quant_min,
52       quant_max,
53       DTYPE,
54       optional<ScalarType>(),
55       out);
56 
57   EXPECT_TENSOR_EQ(out, expected);
58 }
59 
/// Runs the per-tensor dequantize smoke test (test_dtype) once for each
/// integer input dtype listed below.
TEST(OpDequantizeOutTest, AllDtypesSupported) {
  et_pal_init();

  // 8-bit input dtypes.
  test_dtype<ScalarType::Byte>();
  test_dtype<ScalarType::Char>();

  // 16-bit input dtypes.
  test_dtype<ScalarType::Short>();
  test_dtype<ScalarType::Bits16>();
  test_dtype<ScalarType::UInt16>();

  // 32-bit input dtype.
  test_dtype<ScalarType::Int>();
}
69 
/// Verifies per-tensor dequantization with a non-integer scale (0.45), so the
/// dequantized value (input - zero_point) * scale has a fractional part.
TEST(OpDequantizeOutTest, NonWholeNumbers) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tf;

  Tensor input = tf.full({3, 5}, 100);
  double scale = 0.45;
  int64_t zero_point = 30;
  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Float> tfo;
  Tensor out = tfo.zeros({3, 5});
  // (100 - 30) * 0.45 = 31.5
  Tensor expected = tfo.full({3, 5}, 31.5);
  dequantize_per_tensor_out(
      input,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);

  EXPECT_TENSOR_EQ(out, expected);
}
96 
/// Verifies the tensor-argument overload of per-tensor dequantization, where
/// scale (Double) and zero_point (Long) are passed as single-element tensors
/// instead of scalars. Uses the same values as NonWholeNumbers, so the
/// expected output is identical.
TEST(OpDequantizeOutTest, TensorArgOverload) {
  et_pal_init();
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_byte.full({3, 5}, 100);
  Tensor scale = tf_double.make({1}, {0.45});
  Tensor zero_point = tf_long.make({1}, {30});
  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Float> tfo;
  Tensor out = tfo.zeros({3, 5});
  // (100 - 30) * 0.45 = 31.5
  Tensor expected = tfo.full({3, 5}, 31.5);
  dequantize_per_tensor_tensor_args_out(
      input,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      optional<ScalarType>(),
      out);

  EXPECT_TENSOR_EQ(out, expected);
}
125 
TEST(OpDequantizeOutTest,DequantizePerChannel)126 TEST(OpDequantizeOutTest, DequantizePerChannel) {
127   et_pal_init();
128   TensorFactory<ScalarType::Byte> tf_byte;
129   TensorFactory<ScalarType::Double> tf_double;
130   TensorFactory<ScalarType::Long> tf_long;
131 
132   Tensor input = tf_byte.full({3, 2}, 100);
133   Tensor scale = tf_double.make({2}, {0.5, 1});
134   Tensor zero_point = tf_long.make({2}, {30, 60});
135   int64_t quant_min = 0;
136   int64_t quant_max = 255;
137 
138   TensorFactory<ScalarType::Float> tfo;
139   Tensor out = tfo.zeros({3, 2});
140   // (100 - 30) * 0.5
141   // (100 - 60) * 1
142   Tensor expected = tfo.make({3, 2}, {35, 40, 35, 40, 35, 40});
143   dequantize_per_channel_out(
144       input,
145       scale,
146       zero_point,
147       /*axis=*/1,
148       quant_min,
149       quant_max,
150       ScalarType::Byte,
151       optional<ScalarType>(),
152       out);
153 
154   EXPECT_TENSOR_EQ(out, expected);
155 
156   // Test with a different axis
157   out = tfo.zeros({3, 2});
158   scale = tf_double.make({3}, {0.5, 0.75, 1});
159   zero_point = tf_long.make({3}, {30, 50, 60});
160   // (100 - 30) * 0.5
161   // (100 - 50) * 0.75
162   // (100 - 60) * 1
163   expected = tfo.make({3, 2}, {35, 35, 37.5, 37.5, 40, 40});
164   dequantize_per_channel_out(
165       input,
166       scale,
167       zero_point,
168       /*axis=*/0,
169       quant_min,
170       quant_max,
171       ScalarType::Byte,
172       optional<ScalarType>(),
173       out);
174 
175   EXPECT_TENSOR_EQ(out, expected);
176 
177   // Test with a different axis
178   out = tfo.zeros({3});
179   input = tf_byte.make({3}, {100, 100, 100});
180   scale = tf_double.make({3}, {0.5, 0.75, 1});
181   zero_point = tf_long.make({3}, {30, 50, 60});
182   // (100 - 30) * 0.5
183   // (100 - 50) * 0.75
184   // (100 - 60) * 1
185   expected = tfo.make({3}, {35, 37.5, 40});
186   dequantize_per_channel_out(
187       input,
188       scale,
189       zero_point,
190       /*axis=*/0,
191       quant_min,
192       quant_max,
193       ScalarType::Byte,
194       optional<ScalarType>(),
195       out);
196   EXPECT_TENSOR_EQ(out, expected);
197 }
198