1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace{
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker template<typename T>
CreateDetectionPostProcessNetwork(const armnn::TensorInfo & boxEncodingsInfo,const armnn::TensorInfo & scoresInfo,const armnn::TensorInfo & anchorsInfo,const std::vector<T> & anchors,bool useRegularNms)18*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateDetectionPostProcessNetwork(const armnn::TensorInfo& boxEncodingsInfo,
19*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& scoresInfo,
20*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& anchorsInfo,
21*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& anchors,
22*89c4ff92SAndroid Build Coastguard Worker bool useRegularNms)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
25*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
26*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
27*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker armnn::DetectionPostProcessDescriptor desc;
30*89c4ff92SAndroid Build Coastguard Worker desc.m_UseRegularNms = useRegularNms;
31*89c4ff92SAndroid Build Coastguard Worker desc.m_MaxDetections = 3;
32*89c4ff92SAndroid Build Coastguard Worker desc.m_MaxClassesPerDetection = 1;
33*89c4ff92SAndroid Build Coastguard Worker desc.m_DetectionsPerClass =1;
34*89c4ff92SAndroid Build Coastguard Worker desc.m_NmsScoreThreshold = 0.0;
35*89c4ff92SAndroid Build Coastguard Worker desc.m_NmsIouThreshold = 0.5;
36*89c4ff92SAndroid Build Coastguard Worker desc.m_NumClasses = 2;
37*89c4ff92SAndroid Build Coastguard Worker desc.m_ScaleY = 10.0;
38*89c4ff92SAndroid Build Coastguard Worker desc.m_ScaleX = 10.0;
39*89c4ff92SAndroid Build Coastguard Worker desc.m_ScaleH = 5.0;
40*89c4ff92SAndroid Build Coastguard Worker desc.m_ScaleW = 5.0;
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr net(armnn::INetwork::Create());
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* boxesLayer = net->AddInputLayer(0);
45*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* scoresLayer = net->AddInputLayer(1);
46*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor anchorsTensor(anchorsInfo, anchors.data());
47*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* detectionLayer = net->AddDetectionPostProcessLayer(desc, anchorsTensor,
48*89c4ff92SAndroid Build Coastguard Worker "DetectionPostProcess");
49*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* detectionBoxesLayer = net->AddOutputLayer(0, "detectionBoxes");
50*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* detectionClassesLayer = net->AddOutputLayer(1, "detectionClasses");
51*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* detectionScoresLayer = net->AddOutputLayer(2, "detectionScores");
52*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* numDetectionLayer = net->AddOutputLayer(3, "numDetection");
53*89c4ff92SAndroid Build Coastguard Worker Connect(boxesLayer, detectionLayer, boxEncodingsInfo, 0, 0);
54*89c4ff92SAndroid Build Coastguard Worker Connect(scoresLayer, detectionLayer, scoresInfo, 0, 1);
55*89c4ff92SAndroid Build Coastguard Worker Connect(detectionLayer, detectionBoxesLayer, detectionBoxesInfo, 0, 0);
56*89c4ff92SAndroid Build Coastguard Worker Connect(detectionLayer, detectionClassesLayer, detectionClassesInfo, 1, 0);
57*89c4ff92SAndroid Build Coastguard Worker Connect(detectionLayer, detectionScoresLayer, detectionScoresInfo, 2, 0);
58*89c4ff92SAndroid Build Coastguard Worker Connect(detectionLayer, numDetectionLayer, numDetectionInfo, 3, 0);
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker return net;
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DetectionPostProcessEndToEnd(const std::vector<BackendId> & backends,bool useRegularNms,const std::vector<T> & boxEncodings,const std::vector<T> & scores,const std::vector<T> & anchors,const std::vector<float> & expectedDetectionBoxes,const std::vector<float> & expectedDetectionClasses,const std::vector<float> & expectedDetectionScores,const std::vector<float> & expectedNumDetections,float boxScale=1.0f,int32_t boxOffset=0,float scoreScale=1.0f,int32_t scoreOffset=0,float anchorScale=1.0f,int32_t anchorOffset=0)64*89c4ff92SAndroid Build Coastguard Worker void DetectionPostProcessEndToEnd(const std::vector<BackendId>& backends, bool useRegularNms,
65*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& boxEncodings,
66*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& scores,
67*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& anchors,
68*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& expectedDetectionBoxes,
69*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& expectedDetectionClasses,
70*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& expectedDetectionScores,
71*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& expectedNumDetections,
72*89c4ff92SAndroid Build Coastguard Worker float boxScale = 1.0f,
73*89c4ff92SAndroid Build Coastguard Worker int32_t boxOffset = 0,
74*89c4ff92SAndroid Build Coastguard Worker float scoreScale = 1.0f,
75*89c4ff92SAndroid Build Coastguard Worker int32_t scoreOffset = 0,
76*89c4ff92SAndroid Build Coastguard Worker float anchorScale = 1.0f,
77*89c4ff92SAndroid Build Coastguard Worker int32_t anchorOffset = 0)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, ArmnnType);
80*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo scoresInfo({ 1, 6, 3}, ArmnnType);
81*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo anchorsInfo({ 6, 4 }, ArmnnType);
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker boxEncodingsInfo.SetQuantizationScale(boxScale);
84*89c4ff92SAndroid Build Coastguard Worker boxEncodingsInfo.SetQuantizationOffset(boxOffset);
85*89c4ff92SAndroid Build Coastguard Worker boxEncodingsInfo.SetConstant(true);
86*89c4ff92SAndroid Build Coastguard Worker scoresInfo.SetQuantizationScale(scoreScale);
87*89c4ff92SAndroid Build Coastguard Worker scoresInfo.SetQuantizationOffset(scoreOffset);
88*89c4ff92SAndroid Build Coastguard Worker scoresInfo.SetConstant(true);
89*89c4ff92SAndroid Build Coastguard Worker anchorsInfo.SetQuantizationScale(anchorScale);
90*89c4ff92SAndroid Build Coastguard Worker anchorsInfo.SetQuantizationOffset(anchorOffset);
91*89c4ff92SAndroid Build Coastguard Worker anchorsInfo.SetConstant(true);
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
94*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr net = CreateDetectionPostProcessNetwork<T>(boxEncodingsInfo, scoresInfo,
95*89c4ff92SAndroid Build Coastguard Worker anchorsInfo, anchors, useRegularNms);
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0, boxEncodings }, { 1, scores }};
100*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<float>> expectedOutputData = {{ 0, expectedDetectionBoxes },
101*89c4ff92SAndroid Build Coastguard Worker { 1, expectedDetectionClasses },
102*89c4ff92SAndroid Build Coastguard Worker { 2, expectedDetectionScores },
103*89c4ff92SAndroid Build Coastguard Worker { 3, expectedNumDetections }};
104*89c4ff92SAndroid Build Coastguard Worker
105*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, armnn::DataType::Float32>(
106*89c4ff92SAndroid Build Coastguard Worker move(net), inputTensorData, expectedOutputData, backends);
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DetectionPostProcessRegularNmsEndToEnd(const std::vector<BackendId> & backends,const std::vector<T> & boxEncodings,const std::vector<T> & scores,const std::vector<T> & anchors,float boxScale=1.0f,int32_t boxOffset=0,float scoreScale=1.0f,int32_t scoreOffset=0,float anchorScale=1.0f,int32_t anchorOffset=0)110*89c4ff92SAndroid Build Coastguard Worker void DetectionPostProcessRegularNmsEndToEnd(const std::vector<BackendId>& backends,
111*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& boxEncodings,
112*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& scores,
113*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& anchors,
114*89c4ff92SAndroid Build Coastguard Worker float boxScale = 1.0f,
115*89c4ff92SAndroid Build Coastguard Worker int32_t boxOffset = 0,
116*89c4ff92SAndroid Build Coastguard Worker float scoreScale = 1.0f,
117*89c4ff92SAndroid Build Coastguard Worker int32_t scoreOffset = 0,
118*89c4ff92SAndroid Build Coastguard Worker float anchorScale = 1.0f,
119*89c4ff92SAndroid Build Coastguard Worker int32_t anchorOffset = 0)
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedDetectionBoxes({
122*89c4ff92SAndroid Build Coastguard Worker 0.0f, 10.0f, 1.0f, 11.0f,
123*89c4ff92SAndroid Build Coastguard Worker 0.0f, 10.0f, 1.0f, 11.0f,
124*89c4ff92SAndroid Build Coastguard Worker 0.0f, 0.0f, 0.0f, 0.0f
125*89c4ff92SAndroid Build Coastguard Worker });
126*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
127*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
128*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedNumDetections({ 2.0f });
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker DetectionPostProcessEndToEnd<ArmnnType>(backends, true, boxEncodings, scores, anchors,
131*89c4ff92SAndroid Build Coastguard Worker expectedDetectionBoxes, expectedDetectionClasses,
132*89c4ff92SAndroid Build Coastguard Worker expectedDetectionScores, expectedNumDetections,
133*89c4ff92SAndroid Build Coastguard Worker boxScale, boxOffset, scoreScale, scoreOffset,
134*89c4ff92SAndroid Build Coastguard Worker anchorScale, anchorOffset);
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker };
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker
139*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DetectionPostProcessFastNmsEndToEnd(const std::vector<BackendId> & backends,const std::vector<T> & boxEncodings,const std::vector<T> & scores,const std::vector<T> & anchors,float boxScale=1.0f,int32_t boxOffset=0,float scoreScale=1.0f,int32_t scoreOffset=0,float anchorScale=1.0f,int32_t anchorOffset=0)140*89c4ff92SAndroid Build Coastguard Worker void DetectionPostProcessFastNmsEndToEnd(const std::vector<BackendId>& backends,
141*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& boxEncodings,
142*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& scores,
143*89c4ff92SAndroid Build Coastguard Worker const std::vector<T>& anchors,
144*89c4ff92SAndroid Build Coastguard Worker float boxScale = 1.0f,
145*89c4ff92SAndroid Build Coastguard Worker int32_t boxOffset = 0,
146*89c4ff92SAndroid Build Coastguard Worker float scoreScale = 1.0f,
147*89c4ff92SAndroid Build Coastguard Worker int32_t scoreOffset = 0,
148*89c4ff92SAndroid Build Coastguard Worker float anchorScale = 1.0f,
149*89c4ff92SAndroid Build Coastguard Worker int32_t anchorOffset = 0)
150*89c4ff92SAndroid Build Coastguard Worker {
151*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedDetectionBoxes({
152*89c4ff92SAndroid Build Coastguard Worker 0.0f, 10.0f, 1.0f, 11.0f,
153*89c4ff92SAndroid Build Coastguard Worker 0.0f, 0.0f, 1.0f, 1.0f,
154*89c4ff92SAndroid Build Coastguard Worker 0.0f, 100.0f, 1.0f, 101.0f
155*89c4ff92SAndroid Build Coastguard Worker });
156*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
157*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
158*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedNumDetections({ 3.0f });
159*89c4ff92SAndroid Build Coastguard Worker
160*89c4ff92SAndroid Build Coastguard Worker DetectionPostProcessEndToEnd<ArmnnType>(backends, false, boxEncodings, scores, anchors,
161*89c4ff92SAndroid Build Coastguard Worker expectedDetectionBoxes, expectedDetectionClasses,
162*89c4ff92SAndroid Build Coastguard Worker expectedDetectionScores, expectedNumDetections,
163*89c4ff92SAndroid Build Coastguard Worker boxScale, boxOffset, scoreScale, scoreOffset,
164*89c4ff92SAndroid Build Coastguard Worker anchorScale, anchorOffset);
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker };
167*89c4ff92SAndroid Build Coastguard Worker
168*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
169