1 // 2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "ParserHelper.hpp" 9 #include <GraphUtils.hpp> 10 11 #include <armnn/utility/PolymorphicDowncast.hpp> 12 #include <armnnUtils/QuantizeHelper.hpp> 13 14 TEST_SUITE("TensorflowLiteParser_DetectionPostProcess") 15 { 16 struct DetectionPostProcessFixture : ParserFlatbuffersFixture 17 { DetectionPostProcessFixtureDetectionPostProcessFixture18 explicit DetectionPostProcessFixture(const std::string& custom_options) 19 { 20 /* 21 The following values were used for the custom_options: 22 use_regular_nms = true 23 max_classes_per_detection = 1 24 detections_per_class = 1 25 nms_score_threshold = 0.0 26 nms_iou_threshold = 0.5 27 max_detections = 3 28 max_detections = 3 29 num_classes = 2 30 h_scale = 5 31 w_scale = 5 32 x_scale = 10 33 y_scale = 10 34 */ 35 m_JsonString = R"( 36 { 37 "version": 3, 38 "operator_codes": [{ 39 "builtin_code": "CUSTOM", 40 "custom_code": "TFLite_Detection_PostProcess" 41 }], 42 "subgraphs": [{ 43 "tensors": [{ 44 "shape": [1, 6, 4], 45 "type": "UINT8", 46 "buffer": 0, 47 "name": "box_encodings", 48 "quantization": { 49 "min": [0.0], 50 "max": [255.0], 51 "scale": [1.0], 52 "zero_point": [ 1 ] 53 } 54 }, 55 { 56 "shape": [1, 6, 3], 57 "type": "UINT8", 58 "buffer": 1, 59 "name": "scores", 60 "quantization": { 61 "min": [0.0], 62 "max": [255.0], 63 "scale": [0.01], 64 "zero_point": [0] 65 } 66 }, 67 { 68 "shape": [6, 4], 69 "type": "UINT8", 70 "buffer": 2, 71 "name": "anchors", 72 "quantization": { 73 "min": [0.0], 74 "max": [255.0], 75 "scale": [0.5], 76 "zero_point": [0] 77 } 78 }, 79 { 80 "type": "FLOAT32", 81 "buffer": 3, 82 "name": "detection_boxes", 83 "quantization": {} 84 }, 85 { 86 "type": "FLOAT32", 87 "buffer": 4, 88 "name": "detection_classes", 89 "quantization": {} 90 }, 91 { 92 "type": "FLOAT32", 93 "buffer": 5, 94 "name": "detection_scores", 95 "quantization": {} 96 }, 97 { 98 "type": "FLOAT32", 99 "buffer": 6, 100 "name": "num_detections", 101 "quantization": {} 102 } 103 ], 104 "inputs": [0, 1, 2], 105 "outputs": [3, 4, 5, 6], 106 "operators": [{ 107 "opcode_index": 0, 108 "inputs": [0, 1, 2], 109 "outputs": [3, 4, 5, 6], 110 "builtin_options_type": 0, 111 "custom_options": [)" + custom_options + R"(], 112 "custom_options_format": "FLEXBUFFERS" 113 }] 114 }], 115 "buffers": [{}, 116 {}, 117 { "data": [ 1, 1, 2, 2, 118 1, 1, 2, 2, 119 1, 1, 2, 2, 120 1, 21, 2, 2, 121 1, 21, 2, 2, 122 1, 201, 2, 2]}, 123 {}, 124 {}, 125 {}, 126 {}, 127 ] 128 } 129 )"; 130 } 131 }; 132 133 struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture 134 { 135 private: GenerateDescriptorParseDetectionPostProcessCustomOptions136 static armnn::DetectionPostProcessDescriptor GenerateDescriptor() 137 { 138 static armnn::DetectionPostProcessDescriptor descriptor; 139 descriptor.m_UseRegularNms = true; 140 descriptor.m_MaxDetections = 3u; 141 descriptor.m_MaxClassesPerDetection = 1u; 142 descriptor.m_DetectionsPerClass = 1u; 143 descriptor.m_NumClasses = 2u; 144 descriptor.m_NmsScoreThreshold = 0.0f; 145 descriptor.m_NmsIouThreshold = 0.5f; 146 descriptor.m_ScaleH = 5.0f; 147 descriptor.m_ScaleW = 5.0f; 148 descriptor.m_ScaleX = 10.0f; 149 descriptor.m_ScaleY = 10.0f; 150 151 return descriptor; 152 } 153 154 public: ParseDetectionPostProcessCustomOptionsParseDetectionPostProcessCustomOptions155 ParseDetectionPostProcessCustomOptions() 156 : DetectionPostProcessFixture( 157 GenerateDetectionPostProcessJsonString(GenerateDescriptor())) 158 {} 159 }; 160 161 TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "ParseDetectionPostProcess") 162 { 163 Setup(); 164 165 // Inputs 166 using UnquantizedContainer = std::vector<float>; 167 UnquantizedContainer boxEncodings = 168 { 169 0.0f, 0.0f, 0.0f, 0.0f, 170 0.0f, 1.0f, 0.0f, 0.0f, 171 0.0f, -1.0f, 0.0f, 0.0f, 172 0.0f, 0.0f, 0.0f, 0.0f, 173 0.0f, 1.0f, 0.0f, 0.0f, 174 0.0f, 0.0f, 0.0f, 0.0f 175 }; 176 177 UnquantizedContainer scores = 178 { 179 0.0f, 0.9f, 0.8f, 180 0.0f, 0.75f, 0.72f, 181 0.0f, 0.6f, 0.5f, 182 0.0f, 0.93f, 0.95f, 183 0.0f, 0.5f, 0.4f, 184 0.0f, 0.3f, 0.2f 185 }; 186 187 // Outputs 188 UnquantizedContainer detectionBoxes = 189 { 190 0.0f, 10.0f, 1.0f, 11.0f, 191 0.0f, 10.0f, 1.0f, 11.0f, 192 0.0f, 0.0f, 0.0f, 0.0f 193 }; 194 195 UnquantizedContainer detectionClasses = { 1.0f, 0.0f, 0.0f }; 196 UnquantizedContainer detectionScores = { 0.95f, 0.93f, 0.0f }; 197 198 UnquantizedContainer numDetections = { 2.0f }; 199 200 // Quantize inputs and outputs 201 using QuantizedContainer = std::vector<uint8_t>; 202 203 QuantizedContainer quantBoxEncodings = armnnUtils::QuantizedVector<uint8_t>(boxEncodings, 1.00f, 1); 204 QuantizedContainer quantScores = armnnUtils::QuantizedVector<uint8_t>(scores, 0.01f, 0); 205 206 std::map<std::string, QuantizedContainer> input = 207 { 208 { "box_encodings", quantBoxEncodings }, 209 { "scores", quantScores } 210 }; 211 212 std::map<std::string, UnquantizedContainer> output = 213 { 214 { "detection_boxes", detectionBoxes}, 215 { "detection_classes", detectionClasses}, 216 { "detection_scores", detectionScores}, 217 { "num_detections", numDetections} 218 }; 219 220 RunTest<armnn::DataType::QAsymmU8, armnn::DataType::Float32>(0, input, output); 221 } 222 223 TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "DetectionPostProcessGraphStructureTest") 224 { 225 /* 226 Inputs: box_encodings scores 227 \ / 228 DetectionPostProcess 229 / / \ \ 230 / / \ \ 231 Outputs: detection detection detection num_detections 232 boxes classes scores 233 */ 234 235 ReadStringToBinary(); 236 237 armnn::INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary); 238 239 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec()); 240 241 armnn::Graph& graph = GetGraphForTesting(optimized.get()); 242 243 // Check the number of layers in the graph 244 CHECK((graph.GetNumInputs() == 2)); 245 CHECK((graph.GetNumOutputs() == 4)); 246 CHECK((graph.GetNumLayers() == 7)); 247 248 // Input layers 249 armnn::Layer* boxEncodingLayer = GetFirstLayerWithName(graph, "box_encodings"); 250 CHECK((boxEncodingLayer->GetType() == armnn::LayerType::Input)); 251 CHECK(CheckNumberOfInputSlot(boxEncodingLayer, 0)); 252 CHECK(CheckNumberOfOutputSlot(boxEncodingLayer, 1)); 253 254 armnn::Layer* scoresLayer = GetFirstLayerWithName(graph, "scores"); 255 CHECK((scoresLayer->GetType() == armnn::LayerType::Input)); 256 CHECK(CheckNumberOfInputSlot(scoresLayer, 0)); 257 CHECK(CheckNumberOfOutputSlot(scoresLayer, 1)); 258 259 // DetectionPostProcess layer 260 armnn::Layer* detectionPostProcessLayer = GetFirstLayerWithName(graph, "DetectionPostProcess:0:0"); 261 CHECK((detectionPostProcessLayer->GetType() == armnn::LayerType::DetectionPostProcess)); 262 CHECK(CheckNumberOfInputSlot(detectionPostProcessLayer, 2)); 263 CHECK(CheckNumberOfOutputSlot(detectionPostProcessLayer, 4)); 264 265 // Output layers 266 armnn::Layer* detectionBoxesLayer = GetFirstLayerWithName(graph, "detection_boxes"); 267 CHECK((detectionBoxesLayer->GetType() == armnn::LayerType::Output)); 268 CHECK(CheckNumberOfInputSlot(detectionBoxesLayer, 1)); 269 CHECK(CheckNumberOfOutputSlot(detectionBoxesLayer, 0)); 270 271 armnn::Layer* detectionClassesLayer = GetFirstLayerWithName(graph, "detection_classes"); 272 CHECK((detectionClassesLayer->GetType() == armnn::LayerType::Output)); 273 CHECK(CheckNumberOfInputSlot(detectionClassesLayer, 1)); 274 CHECK(CheckNumberOfOutputSlot(detectionClassesLayer, 0)); 275 276 armnn::Layer* detectionScoresLayer = GetFirstLayerWithName(graph, "detection_scores"); 277 CHECK((detectionScoresLayer->GetType() == armnn::LayerType::Output)); 278 CHECK(CheckNumberOfInputSlot(detectionScoresLayer, 1)); 279 CHECK(CheckNumberOfOutputSlot(detectionScoresLayer, 0)); 280 281 armnn::Layer* numDetectionsLayer = GetFirstLayerWithName(graph, "num_detections"); 282 CHECK((numDetectionsLayer->GetType() == armnn::LayerType::Output)); 283 CHECK(CheckNumberOfInputSlot(numDetectionsLayer, 1)); 284 CHECK(CheckNumberOfOutputSlot(numDetectionsLayer, 0)); 285 286 // Check the connections 287 armnn::TensorInfo boxEncodingTensor(armnn::TensorShape({ 1, 6, 4 }), armnn::DataType::QAsymmU8, 1, 1); 288 armnn::TensorInfo scoresTensor(armnn::TensorShape({ 1, 6, 3 }), armnn::DataType::QAsymmU8, 289 0.00999999978f, 0); 290 291 armnn::TensorInfo detectionBoxesTensor(armnn::TensorShape({ 1, 3, 4 }), armnn::DataType::Float32); 292 armnn::TensorInfo detectionClassesTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32); 293 armnn::TensorInfo detectionScoresTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32); 294 armnn::TensorInfo numDetectionsTensor(armnn::TensorShape({ 1 } ), armnn::DataType::Float32); 295 296 CHECK(IsConnected(boxEncodingLayer, detectionPostProcessLayer, 0, 0, boxEncodingTensor)); 297 CHECK(IsConnected(scoresLayer, detectionPostProcessLayer, 0, 1, scoresTensor)); 298 CHECK(IsConnected(detectionPostProcessLayer, detectionBoxesLayer, 0, 0, detectionBoxesTensor)); 299 CHECK(IsConnected(detectionPostProcessLayer, detectionClassesLayer, 1, 0, detectionClassesTensor)); 300 CHECK(IsConnected(detectionPostProcessLayer, detectionScoresLayer, 2, 0, detectionScoresTensor)); 301 CHECK(IsConnected(detectionPostProcessLayer, numDetectionsLayer, 3, 0, numDetectionsTensor)); 302 } 303 304 } 305