//
// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <schema_generated.h>

#include <doctest/doctest.h>

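// Test helpers for the TfLite LOGICAL_AND and LOGICAL_OR binary operators.
// CreateLogicalBinaryTfLiteModel() serialises a single-operator FlatBuffer model;
// LogicalBinaryTest() runs that model on the plain TfLite runtime and on the
// Arm NN delegate, then compares both results against the expected output and
// against each other.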
namespace
{

std::vector<char> CreateLogicalBinaryTfLiteModel(tflite::BuiltinOperator logicalOperatorCode,
                                                 tflite::TensorType tensorType,
                                                 const std::vector<int32_t>& input0TensorShape,
                                                 const std::vector<int32_t>& input1TensorShape,
                                                 const std::vector<int32_t>& outputTensorShape,
                                                 float quantScale = 1.0f,
                                                 int quantOffset  = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

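    // Buffer 0 is the conventional empty buffer; buffers 1-3 back the two input
    // tensors and the output tensor, which reference them by index below.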
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));

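    // One set of per-tensor quantization parameters (min/max left unset) shared by
    // both inputs and the output.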
    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

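    // Two input tensors and one output tensor, referencing buffers 1-3 and sharing
    // the element type and quantization parameters.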
    std::array<flatbuffers::Offset<Tensor>, 3> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
                                                                      input0TensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input_0"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
                                                                      input1TensorShape.size()),
                              tensorType,
                              2,
                              flatBufferBuilder.CreateString("input_1"),
                              quantizationParameters);
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              tensorType,
                              3,
                              flatBufferBuilder.CreateString("output"),
                              quantizationParameters);

    // Create the operator, selecting the builtin options that match the requested
    // logical operator; any other operator code falls through with BuiltinOptions_NONE.
    tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
    flatbuffers::Offset<void> operatorBuiltinOptions = 0;
    switch (logicalOperatorCode)
    {
        case BuiltinOperator_LOGICAL_AND:
        {
            operatorBuiltinOptionsType = BuiltinOptions_LogicalAndOptions;
            operatorBuiltinOptions = CreateLogicalAndOptions(flatBufferBuilder).Union();
            break;
        }
        case BuiltinOperator_LOGICAL_OR:
        {
            operatorBuiltinOptionsType = BuiltinOptions_LogicalOrOptions;
            operatorBuiltinOptions = CreateLogicalOrOptions(flatBufferBuilder).Union();
            break;
        }
        default:
            break;
    }
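    // Wire the operator: it consumes tensors 0 and 1 and produces tensor 2.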
    const std::vector<int32_t> operatorInputs{ 0, 1 };
    const std::vector<int32_t> operatorOutputs{ 2 };
    flatbuffers::Offset<Operator> logicalBinaryOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType,
                       operatorBuiltinOptions);

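    // A single subgraph whose graph inputs and outputs are the operator's own tensors.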
    const std::vector<int> subgraphInputs{ 0, 1 };
    const std::vector<int> subgraphOutputs{ 2 };
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&logicalBinaryOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: Logical Binary Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, logicalOperatorCode);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

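// Example usage of CreateLogicalBinaryTfLiteModel() (illustrative sketch; the
// shapes below are made up for the example): build a serialised LOGICAL_OR model
// for two 2x2 boolean tensors, leaving the quantization defaults in place. The
// returned buffer is what LogicalBinaryTest() below feeds to the interpreters.
//
//     std::vector<char> model = CreateLogicalBinaryTfLiteModel(
//         tflite::BuiltinOperator_LOGICAL_OR,
//         tflite::TensorType_BOOL,
//         { 1, 2, 2, 1 },   // input 0 shape
//         { 1, 2, 2, 1 },   // input 1 shape
//         { 1, 2, 2, 1 });  // output shape
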
void LogicalBinaryTest(tflite::BuiltinOperator logicalOperatorCode,
                       tflite::TensorType tensorType,
                       std::vector<armnn::BackendId>& backends,
                       std::vector<int32_t>& input0Shape,
                       std::vector<int32_t>& input1Shape,
                       std::vector<int32_t>& expectedOutputShape,
                       std::vector<bool>& input0Values,
                       std::vector<bool>& input1Values,
                       std::vector<bool>& expectedOutputValues,
                       float quantScale = 1.0f,
                       int quantOffset  = 0)
{
    using namespace delegateTestInterpreter;
    std::vector<char> modelBuffer = CreateLogicalBinaryTfLiteModel(logicalOperatorCode,
                                                                   tensorType,
                                                                   input0Shape,
                                                                   input1Shape,
                                                                   expectedOutputShape,
                                                                   quantScale,
                                                                   quantOffset);

    // Set up an interpreter with just the TfLite runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor(input0Values, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor(input1Values, 1) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<bool>    tfLiteOutputValues = tfLiteInterpreter.GetOutputResult(0);
    std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);

    // Set up an interpreter with the Arm NN delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor(input0Values, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor(input1Values, 1) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<bool>    armnnOutputValues = armnnInterpreter.GetOutputResult(0);
    std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);

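    // Check that both runtimes produce the expected shape, then compare the value
    // sets pairwise: expected vs Arm NN, expected vs TfLite, and TfLite vs Arm NN.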
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);

    armnnDelegate::CompareData(expectedOutputValues, armnnOutputValues, expectedOutputValues.size());
    armnnDelegate::CompareData(expectedOutputValues, tfLiteOutputValues, expectedOutputValues.size());
    armnnDelegate::CompareData(tfLiteOutputValues,   armnnOutputValues,  expectedOutputValues.size());

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}

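// Example usage of LogicalBinaryTest() (illustrative sketch): a doctest case in a
// test file could exercise LOGICAL_AND on, e.g., the reference backend like this.
// The shapes and values are made up for the example; the expected output is the
// element-wise AND of the two inputs.
//
//     TEST_CASE("LogicalBinaryAndBoolExample")
//     {
//         std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
//         std::vector<int32_t> input0Shape { 1, 2, 2, 1 };
//         std::vector<int32_t> input1Shape { 1, 2, 2, 1 };
//         std::vector<int32_t> outputShape { 1, 2, 2, 1 };
//         std::vector<bool> input0Values { false, false, true, true };
//         std::vector<bool> input1Values { false, true, false, true };
//         std::vector<bool> expectedOutputValues { false, false, false, true };
//
//         LogicalBinaryTest(tflite::BuiltinOperator_LOGICAL_AND,
//                           tflite::TensorType_BOOL,
//                           backends,
//                           input0Shape, input1Shape, outputShape,
//                           input0Values, input1Values, expectedOutputValues);
//     }
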
} // anonymous namespace