xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/concatenation_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <limits>
19 #include <type_traits>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
25 #include "tensorflow/lite/kernels/test_util.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 namespace {
30 
31 using ::testing::ElementsAreArray;
32 
33 class BaseConcatenationOpModel : public SingleOpModel {
34  public:
35   // TODO(ahentz): Also test different activation types, axis, input
36   // dimensions.
BaseConcatenationOpModel()37   BaseConcatenationOpModel() {}
BaseConcatenationOpModel(const std::vector<TensorData> & input_template,int axis,int num_inputs,const TensorData & output_template)38   BaseConcatenationOpModel(const std::vector<TensorData>& input_template,
39                            int axis, int num_inputs,
40                            const TensorData& output_template) {
41     std::vector<std::vector<int>> all_input_shapes;
42     CHECK_EQ(input_template.size(), num_inputs);
43     for (int i = 0; i < num_inputs; ++i) {
44       all_input_shapes.push_back(input_template[i].shape);
45       AddInput(input_template[i]);
46     }
47     output_ = AddOutput({output_template.type, /*shape=*/{},
48                          output_template.min, output_template.max});
49     SetBuiltinOp(
50         BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
51         CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
52             .Union());
53     BuildInterpreter(all_input_shapes);
54   }
BaseConcatenationOpModel(const TensorData & input_template,int axis,int num_inputs)55   BaseConcatenationOpModel(const TensorData& input_template, int axis,
56                            int num_inputs)
57       : BaseConcatenationOpModel(
58             std::vector<TensorData>(num_inputs, input_template), axis,
59             num_inputs, input_template) {}
60 
61  protected:
62   int output_;
63 };
64 
// Float32 variant: inputs and outputs are plain floats.
class ConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;
  // Copies `data` into the input tensor at `index`.
  void SetInput(int index, std::initializer_list<float> data) {
    PopulateTensor(index, data);
  }
  // Reads back the op's output tensor as floats.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
73 
// Quantized variant: float inputs are quantized on the way in; the output can
// be read either raw (as the quantized integer type) or dequantized.
class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;

  // Quantizes `data` with input `index`'s scale/zero-point and stores it.
  template <typename T>
  void SetInput(int index, std::initializer_list<float> data) {
    QuantizeAndPopulate<T>(index, data);
  }
  // Raw quantized output values.
  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }
  // Output converted back to floats using the output tensor's quantization
  // parameters.
  template <typename T>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
                         GetZeroPoint(output_));
  }
};
92 
// Boolean variant of the concatenation test model.
class BoolConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;
  // Copies `data` into the input tensor at `index`.
  void SetInput(int index, std::initializer_list<bool> data) {
    PopulateTensor(index, data);
  }
  // Reads back the op's output tensor as bools.
  std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
};
101 
TEST(ConcatenationOpTest, ThreeDimensionalOneInput) {
  // Concatenating a single 3-D input returns it unchanged.
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
                             /*num_inputs=*/1);
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 4, 7}));
}
109 
TEST(ConcatenationOpTest, FiveDimensionalOneInput) {
  // Concatenating a single 5-D input returns it unchanged.
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2, 1, 3}}, /*axis=*/2,
                             /*num_inputs=*/1);
  model.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
                     10.0f, 11.0f, 12.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
}
119 
TEST(ConcatenationOpTest, FiveDimensionalTwoInput) {
  // Along axis 0 the two 5-D inputs are simply stacked back to back.
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2, 1, 3}}, /*axis=*/0,
                             /*num_inputs=*/2);
  model.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
                     10.0f, 11.0f, 12.0f});
  model.SetInput(1, {13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
                     21.0f, 22.0f, 23.0f, 24.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                        13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}));
}
133 
TEST(ConcatenationOpTest, FiveDimensionalTwoInputNegativeAxes) {
  // axis=-2 on a rank-5 tensor counts from the back (equivalent to axis 3).
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2, 1, 3}},
                             /*axis=*/-2,
                             /*num_inputs=*/2);
  model.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
                     10.0f, 11.0f, 12.0f});
  model.SetInput(1, {13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
                     21.0f, 22.0f, 23.0f, 24.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({1, 2, 3, 13, 14, 15, 4,  5,  6,  16, 17, 18,
                                7, 8, 9, 19, 20, 21, 10, 11, 12, 22, 23, 24}));
}
146 
TEST(ConcatenationOpTest, FiveDimensionalTwoInputQuantizedUint8) {
  // Range [-12.7, 12.8] gives scale 0.1 and zero point 127 for uint8.
  QuantizedConcatenationOpModel model(
      {TensorType_UINT8, {2, 1, 2, 1, 3}, -12.7, 12.8},
      /*axis=*/0,
      /*num_inputs=*/2);

  model.SetInput<uint8_t>(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
                              9.0f, 10.0f, 11.0f, 12.0f});
  model.SetInput<uint8_t>(1, {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f,
                              9.1f, 10.1f, 11.1f, 12.1f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Dequantized output matches the float inputs within quantization error.
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 2.0f,  3.0f,  4.0f,  5.0f, 6.0f,  7.0f,  8.0f,
                  9.0f, 10.0f, 11.0f, 12.0f, 1.1f, 2.1f,  3.1f,  4.1f,
                  5.1f, 6.1f,  7.1f,  8.1f,  9.1f, 10.1f, 11.1f, 12.1f,
              })));
  // Raw quantized values: e.g. 1.0 -> 127 + 10 = 137.
  EXPECT_THAT(
      model.GetOutput<uint8_t>(),
      ElementsAreArray({
          137, 147, 157, 167, 177, 187, 197, 207, 217, 227, 237, 247, 138,  //
          148, 158, 168, 178, 188, 198, 208, 218, 228, 238, 248,
      }));
}
171 
TEST(ConcatenationOpTest, ThreeDimensionalTwoInputsDifferentShapes) {
  // Inputs only have to agree on the dimensions other than the concatenation
  // axis (here axis 1: 2x1x2 and 2x3x2).
  ConcatenationOpModel model(
      {{TensorType_FLOAT32, {2, 1, 2}}, {TensorType_FLOAT32, {2, 3, 2}}},
      /*axis=*/1, /*num_inputs=*/2, TensorType_FLOAT32);
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0, 7.0f, 8.0f, 9.0f,
                     10.0f, 11.0f, 12.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(
                  {1, 3, 1, 2, 3, 4, 5, 6, 4, 7, 7, 8, 9, 10, 11, 12}));
}
183 
#ifdef GTEST_HAS_DEATH_TEST
TEST(ConcatenationOpTest, ThreeDimensionalTwoInputsDifferentShapesWrongAxis) {
  // Concatenating along axis 0 requires the other dims to match; 2x1x2 vs
  // 2x3x2 differ on dim 1, so tensor allocation must abort.
  EXPECT_DEATH(
      ConcatenationOpModel model(
          {{TensorType_FLOAT32, {2, 1, 2}}, {TensorType_FLOAT32, {2, 3, 2}}},
          /*axis=*/0, /*num_inputs=*/2, TensorType_FLOAT32),
      "Cannot allocate tensors");
}
#endif
193 
TEST(ConcatenationOpTest, OneTrivialInput) {
  // Degenerate case: a single 1-element input.
  ConcatenationOpModel model({TensorType_FLOAT32, {1}}, /*axis=*/0,
                             /*num_inputs=*/1);
  model.SetInput(0, {5.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ::testing::ElementsAre(5));
}
201 
TEST(ConcatenationOpTest, TwoDimensionalOneInput) {
  // A single 2-D input passes through unchanged.
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 3}}, /*axis=*/0,
                             /*num_inputs=*/1);
  model.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
209 
TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxes) {
  // Concatenate the same pair of 2x3 tensors along each axis, with both
  // positive and negative axis values (negative axes count from the back).
  auto first = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  auto second = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};

  ConcatenationOpModel axis0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0,
                             /*num_inputs=*/2);
  axis0.SetInput(0, first);
  axis0.SetInput(1, second);
  ASSERT_EQ(axis0.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis0.GetOutput(),
              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));

  // axis=-2 is equivalent to axis=0 for a rank-2 tensor.
  ConcatenationOpModel axis0_negative({TensorType_FLOAT32, {2, 3}},
                                      /*axis=*/-2,
                                      /*num_inputs=*/2);
  axis0_negative.SetInput(0, first);
  axis0_negative.SetInput(1, second);
  ASSERT_EQ(axis0_negative.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis0_negative.GetOutput(),
              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));

  ConcatenationOpModel axis1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1,
                             /*num_inputs=*/2);
  axis1.SetInput(0, first);
  axis1.SetInput(1, second);
  ASSERT_EQ(axis1.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis1.GetOutput(),
              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));

  // axis=-1 is equivalent to axis=1 for a rank-2 tensor.
  ConcatenationOpModel axis1_negative({TensorType_FLOAT32, {2, 3}},
                                      /*axis=*/-1,
                                      /*num_inputs=*/2);
  axis1_negative.SetInput(0, first);
  axis1_negative.SetInput(1, second);
  ASSERT_EQ(axis1_negative.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis1_negative.GetOutput(),
              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
247 
TEST(ConcatenationOpTest, FourInputs) {
  // Four 2x1x2 inputs concatenated along the innermost axis interleave per
  // row of the output.
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
                             /*num_inputs=*/4);
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              }));
}
262 
TEST(ConcatenationOpTest, FourInputsQuantizedUint8) {
  // All four inputs share the range [-12.7, 12.8] (scale 0.1, zero point
  // 127).
  QuantizedConcatenationOpModel model(
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
      /*axis=*/2,
      /*num_inputs=*/4);

  model.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput<uint8_t>(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput<uint8_t>(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput<uint8_t>(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAreArray({
                  137, 157, 138, 158, 139, 159, 140, 160,  //
                  167, 197, 168, 198, 169, 199, 170, 200,  //
              }));
}
284 
// Typed fixture: maps the C++ element type under test to the matching
// quantized tensor type (int16_t -> INT16, otherwise INT8).
template <typename Type>
struct ConcatenationOpTestTyped : public testing::Test {
  using TestType = Type;

  enum TensorType tensor_type =
      (std::is_same<Type, int16_t>::value ? TensorType_INT16 : TensorType_INT8);
};
292 
293 using TestTypes = testing::Types<int8_t, int16_t>;
294 TYPED_TEST_CASE(ConcatenationOpTestTyped, TestTypes);
295 
TYPED_TEST(ConcatenationOpTestTyped, FourInputsQuantizedInt8) {
  using TestType = typename TestFixture::TestType;

  // Range [12.8 * kMin, 12.8 * kMax], where kMax = max/(max+1) for the
  // signed quantized type under test.
  const float kMin = -1;
  const float kMax =
      std::numeric_limits<TestType>::max() /
      static_cast<float>(std::numeric_limits<TestType>::max() + 1);

  QuantizedConcatenationOpModel model(
      {TestFixture::tensor_type, {2, 1, 2}, 12.8f * kMin, 12.8f * kMax},
      /*axis=*/2,
      /*num_inputs=*/4);

  model.SetInput<TestType>(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput<TestType>(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput<TestType>(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput<TestType>(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<TestType>(),
              ElementsAreArray(ArrayFloatNear({
                  1, 3, 1.1, 3.1, 1.2, 3.2, 1.3, 3.3,  //
                  4, 7, 4.1, 7.1, 4.2, 7.2, 4.3, 7.3   //
              })));
}
320 
TEST(ConcatenationOpTest, FourInputsQuantizedMixedRange) {
  // Each input carries its own quantization range; values are requantized
  // into the output's [-12.7, 12.8] range.
  QuantizedConcatenationOpModel model(
      {{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
       {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
      /*axis=*/2, /*num_inputs=*/4,
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});

  model.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput<uint8_t>(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput<uint8_t>(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput<uint8_t>(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAreArray({
                  137, 157, 138, 158, 139, 159, 140, 160,  //
                  167, 197, 168, 198, 169, 199, 170, 200,  //
              }));
}
345 
TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) {
  // The output range [-1, 1] is narrower than every input range, so the
  // requantized values saturate at the uint8 limits 0 and 255.
  QuantizedConcatenationOpModel model(
      {{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
       {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
      /*axis=*/2, /*num_inputs=*/4,
      {TensorType_UINT8, {2, 1, 2}, -1., 1.});

  model.SetInput<uint8_t>(0, {1.0f, -3.0f, -4.0f, -7.0f});
  model.SetInput<uint8_t>(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput<uint8_t>(2, {1.2f, -3.2f, -4.2f, 7.2f});
  model.SetInput<uint8_t>(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f,   //
                      -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f,  //
                  },
                  4e-3)));
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAreArray({
                  255, 0, 255, 255, 255, 0, 255, 255,  //
                  0, 0, 255, 255, 0, 255, 255, 255,    //
              }));
}
372 
TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) {
  // With range [0, 255] the uint8 quantization has scale 1 and zero point 0,
  // so values map through unchanged.
  QuantizedConcatenationOpModel model(
      {TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/1,
      /*num_inputs=*/1);
  model.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f})));
}
383 
TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) {
  // Degenerate single-element input with identity quantization ([0, 255]).
  QuantizedConcatenationOpModel model(
      {TensorType_UINT8, {1}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/0,
      /*num_inputs=*/1);
  model.SetInput<uint8_t>(0, {5.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<uint8_t>(), ::testing::ElementsAre(5));
}
393 
TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) {
  // 2-D single input with identity quantization ([0, 255]).
  QuantizedConcatenationOpModel model(
      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/0,
      /*num_inputs=*/1);
  model.SetInput<uint8_t>(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
403 
TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
  // Same axis coverage as TwoInputsTwoAxesNegativeAxes, but through the
  // uint8 model with identity quantization ([0, 255]).
  auto first = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  auto second = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};

  QuantizedConcatenationOpModel axis0(
      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/0,
      /*num_inputs=*/2);
  axis0.SetInput<uint8_t>(0, first);
  axis0.SetInput<uint8_t>(1, second);
  ASSERT_EQ(axis0.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis0.GetOutput<uint8_t>(),
              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));

  // axis=-2 is equivalent to axis=0 for a rank-2 tensor.
  QuantizedConcatenationOpModel axis0_negative(
      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/-2,
      /*num_inputs=*/2);
  axis0_negative.SetInput<uint8_t>(0, first);
  axis0_negative.SetInput<uint8_t>(1, second);
  ASSERT_EQ(axis0_negative.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis0_negative.GetOutput<uint8_t>(),
              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));

  QuantizedConcatenationOpModel axis1(
      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/1,
      /*num_inputs=*/2);
  axis1.SetInput<uint8_t>(0, first);
  axis1.SetInput<uint8_t>(1, second);
  ASSERT_EQ(axis1.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis1.GetOutput<uint8_t>(),
              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));

  // axis=-1 is equivalent to axis=1 for a rank-2 tensor.
  QuantizedConcatenationOpModel axis1_negative(
      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
      /*axis=*/-1,
      /*num_inputs=*/2);
  axis1_negative.SetInput<uint8_t>(0, first);
  axis1_negative.SetInput<uint8_t>(1, second);
  ASSERT_EQ(axis1_negative.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis1_negative.GetOutput<uint8_t>(),
              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
449 
TEST(ConcatenationOpTest, BoolTypeOneInput) {
  // A single boolean input passes through unchanged.
  BoolConcatenationOpModel model({TensorType_BOOL, {2, 1, 2}}, /*axis=*/1,
                                 /*num_inputs=*/1);
  model.SetInput(0, {true, false, false, true});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
}
457 
TEST(ConcatenationOpTest, BoolTypeTwoInputs) {
  // Boolean inputs of different sizes concatenated along axis 1.
  BoolConcatenationOpModel model(
      {{TensorType_BOOL, {2, 1, 2}}, {TensorType_BOOL, {2, 3, 2}}},
      /*axis=*/1, /*num_inputs=*/2, TensorType_BOOL);
  model.SetInput(0, {false, false, false, false});
  model.SetInput(1, {true, true, true, true, true, true, true, true, true,
                     true, true, true});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({false, false, true, true, true, true, true, true,
                        false, false, true, true, true, true, true, true}));
}
471 
// Selects which inputs are given kTfLitePersistentRo allocation in
// PersistentConcatenationOpModel.
enum class TestInputType {
  kPersistentRo = 0,     // every input is persistent
  kOnePersistentRo = 1,  // only the first input is persistent
  kDefault = 2,          // no persistent inputs
};

// One configuration of the persistent-tensor tests.
struct PersistentTestCase {
  TestInputType test_type;
  TensorType tensor_type;
  bool is_quantized = false;
};
483 
// Model that exercises CONCATENATION when some or all inputs use
// kTfLitePersistentRo allocation. Persistent inputs must be marked, resized,
// and populated between BuildInterpreter(allocate_and_delegate=false) and
// AllocateAndDelegate(), so the setup order in the constructor matters.
template <typename T>
class PersistentConcatenationOpModel : public SingleOpModel {
 public:
  PersistentConcatenationOpModel(const std::vector<TensorData>& input_template,
                                 int axis, const TensorData& output_template,
                                 PersistentTestCase test_case,
                                 std::vector<std::vector<T>> input_data_list)
      : input_data_list_(input_data_list), test_case_(test_case) {
    const int num_inputs = input_data_list.size();
    std::vector<std::vector<int>> all_input_shapes;
    CHECK_EQ(input_template.size(), num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      int id;
      all_input_shapes.push_back(input_template[i].shape);
      id = AddInput(input_template[i]);
      concat_inputs_.push_back(id);
    }
    output_ = AddOutput(output_template);
    SetBuiltinOp(
        BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
        CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
            .Union());
    // Defer allocation/delegation so persistent inputs can be configured
    // first.
    BuildInterpreter(all_input_shapes, /*num_threads=*/-1,
                     /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/true,
                     /*allocate_and_delegate=*/false);

    // How many leading inputs become persistent (see TestInputType).
    int num_persistent_inputs = 0;
    if (test_case_.test_type == TestInputType::kPersistentRo) {
      num_persistent_inputs = num_inputs;
    } else if (test_case_.test_type == TestInputType::kOnePersistentRo) {
      num_persistent_inputs = 1;
    }

    // Persistent inputs are filled up-front; the rest are filled later via
    // PopulateInputTensors().
    for (int i = 0; i < num_persistent_inputs; ++i) {
      interpreter_->tensor(concat_inputs_[i])->allocation_type =
          kTfLitePersistentRo;
      std::vector<T>& input_data = input_data_list[i];
      interpreter_->ResizeInputTensorStrict(concat_inputs_[i],
                                            input_template[i].shape);
      if (test_case.is_quantized) {
        QuantizeAndPopulate<int8_t>(concat_inputs_[i], FloatVector(input_data));
      } else {
        PopulateTensor(concat_inputs_[i], input_data);
      }
    }
    AllocateAndDelegate(true);
  }

  // Widens a vector of T to float (needed by the quantization helpers).
  std::vector<float> FloatVector(std::vector<T> data) {
    std::vector<float> ret;
    for (T t : data) {
      ret.push_back(static_cast<float>(t));
    }
    return ret;
  }

  // Fills every non-persistent input with its test data. No-op when all
  // inputs were already populated as persistent in the constructor.
  void PopulateInputTensors() {
    int start = -1;
    if (test_case_.test_type == TestInputType::kDefault) {
      start = 0;
    } else if (test_case_.test_type == TestInputType::kOnePersistentRo) {
      start = 1;
    }
    if (start < 0) {
      return;
    }
    for (int i = start; i < input_data_list_.size(); ++i) {
      if (test_case_.is_quantized) {
        QuantizeAndPopulate<int8_t>(concat_inputs_[i],
                                    FloatVector(input_data_list_[i]));
      } else {
        std::vector<T> v(input_data_list_[i]);
        PopulateTensor(concat_inputs_[i], v);
      }
    }
  }

  // True if the output tensor ended up with persistent-read-only allocation.
  bool IsPersistentOutput() {
    const TfLiteTensor* tensor = interpreter_->tensor(output_);
    return tensor->allocation_type == kTfLitePersistentRo;
  }

  // Returns the output as floats, dequantizing when the test is quantized.
  std::vector<float> GetOutput() {
    if (test_case_.is_quantized) {
      return Dequantize<int8_t>(ExtractVector<int8_t>(output_),
                                GetScale(output_), GetZeroPoint(output_));
    }
    return FloatVector(ExtractVector<T>(output_));
  }

 protected:
  int output_;                                  // output tensor index
  std::vector<std::vector<T>> input_data_list_; // per-input test data
  PersistentTestCase test_case_;                // active configuration
  std::vector<int> concat_inputs_;              // input tensor indices
};
581 
// Typed fixture producing the test-case list for element type T: one case
// with default (non-persistent) inputs and one with all-persistent inputs.
template <typename T>
class ConcatenationOpPersistentModelTest : public ::testing::Test {
 public:
  static std::vector<PersistentTestCase> Range(bool is_quantized = false) {
    // Pick the tensor type matching T; quantized runs always use INT8.
    TensorType tensor_type = TensorType_FLOAT32;
    if (std::is_same<T, int32_t>::value) {
      tensor_type = TensorType_INT32;
    }
    if (is_quantized) {
      tensor_type = TensorType_INT8;
    }
    return {{TestInputType::kDefault, tensor_type, is_quantized},
            {TestInputType::kPersistentRo, tensor_type, is_quantized}};
  }
};
597 
// Run the persistent-tensor tests for both float and int32 element types.
using DataTypes = ::testing::Types<float, int32_t>;
TYPED_TEST_SUITE(ConcatenationOpPersistentModelTest, DataTypes);
600 
TYPED_TEST(ConcatenationOpPersistentModelTest, PersistentTest) {
  for (PersistentTestCase test_case :
       ConcatenationOpPersistentModelTest<TypeParam>::Range()) {
    std::vector<std::vector<TypeParam>> inputs = {{1, 2, 3, 4, 5, 6},
                                                  {7, 8, 9, 10, 11, 12}};
    std::vector<TensorData> input_template = {{test_case.tensor_type, {2, 3}},
                                              {test_case.tensor_type, {2, 3}}};
    TensorData output_template = {test_case.tensor_type, {4, 3}};
    PersistentConcatenationOpModel<TypeParam> model(
        input_template, /*axis=*/0, output_template, test_case, inputs);
    model.PopulateInputTensors();
    ASSERT_EQ(model.Invoke(), kTfLiteOk);
    // The output is expected to be persistent exactly when every input is.
    ASSERT_EQ(model.IsPersistentOutput(),
              test_case.test_type == TestInputType::kPersistentRo);
    EXPECT_THAT(
        model.GetOutput(),
        ElementsAreArray(ArrayFloatNear(
            {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0})));
  }
}
622 
TYPED_TEST(ConcatenationOpPersistentModelTest, QuantizedPersistentTest) {
  const bool is_quantized = true;
  for (PersistentTestCase test_case :
       ConcatenationOpPersistentModelTest<TypeParam>::Range(is_quantized)) {
    std::vector<std::vector<TypeParam>> inputs = {{1, 2, 3, 4, 5, 6},
                                                  {7, 8, 9, 10, 11, 12}};
    // All tensors share the same explicit quantization: range [0, 12].
    float scale = 12.0 / 255.0;
    int zero_point = -128;
    std::vector<TensorData> input_template = {
        {test_case.tensor_type, {2, 3}, 0.0, 12.0, scale, zero_point},
        {test_case.tensor_type, {2, 3}, 0.0, 12.0, scale, zero_point},
    };
    TensorData output_template = {
        test_case.tensor_type, {4, 3}, 0.0, 12.0, scale, zero_point};
    PersistentConcatenationOpModel<TypeParam> model(
        input_template, /*axis=*/0, output_template, test_case, inputs);
    model.PopulateInputTensors();
    ASSERT_EQ(model.Invoke(), kTfLiteOk);
    // The output is expected to be persistent exactly when every input is.
    ASSERT_EQ(model.IsPersistentOutput(),
              test_case.test_type == TestInputType::kPersistentRo);
    EXPECT_THAT(
        model.GetOutput(),
        ElementsAreArray(ArrayFloatNear(
            {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0},
            1e-1)));
  }
}
651 
652 }  // namespace
653 }  // namespace tflite
654