//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <reference/workloads/DetectionPostProcess.hpp>

#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>

#include <doctest/doctest.h>

#include <algorithm>
#include <cstddef>
#include <vector>

13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("RefDetectionPostProcess")
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TopKSortTest")
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker     unsigned int k = 3;
18*89c4ff92SAndroid Build Coastguard Worker     unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
19*89c4ff92SAndroid Build Coastguard Worker     float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
20*89c4ff92SAndroid Build Coastguard Worker     armnn::TopKSort(k, indices, values, 8);
21*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[0] == 7);
22*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[1] == 1);
23*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[2] == 2);
24*89c4ff92SAndroid Build Coastguard Worker }
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullTopKSortTest")
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker     unsigned int k = 8;
29*89c4ff92SAndroid Build Coastguard Worker     unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
30*89c4ff92SAndroid Build Coastguard Worker     float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
31*89c4ff92SAndroid Build Coastguard Worker     armnn::TopKSort(k, indices, values, 8);
32*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[0] == 7);
33*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[1] == 1);
34*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[2] == 2);
35*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[3] == 3);
36*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[4] == 4);
37*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[5] == 5);
38*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[6] == 6);
39*89c4ff92SAndroid Build Coastguard Worker     CHECK(indices[7] == 0);
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("IouTest")
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker     float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
45*89c4ff92SAndroid Build Coastguard Worker     float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
46*89c4ff92SAndroid Build Coastguard Worker     float iou = armnn::IntersectionOverUnion(boxI, boxJ);
47*89c4ff92SAndroid Build Coastguard Worker     CHECK(iou == doctest::Approx(0.68).epsilon(0.001f));
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NmsFunction")
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> boxCorners({
53*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 1.0f,
54*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.1f, 1.0f, 1.1f,
55*89c4ff92SAndroid Build Coastguard Worker         0.0f, -0.1f, 1.0f, 0.9f,
56*89c4ff92SAndroid Build Coastguard Worker         0.0f, 10.0f, 1.0f, 11.0f,
57*89c4ff92SAndroid Build Coastguard Worker         0.0f, 10.1f, 1.0f, 11.1f,
58*89c4ff92SAndroid Build Coastguard Worker         0.0f, 100.0f, 1.0f, 101.0f
59*89c4ff92SAndroid Build Coastguard Worker     });
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> result =
64*89c4ff92SAndroid Build Coastguard Worker         armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     CHECK(result.size() == 3);
67*89c4ff92SAndroid Build Coastguard Worker     CHECK(result[0] == 3);
68*89c4ff92SAndroid Build Coastguard Worker     CHECK(result[1] == 0);
69*89c4ff92SAndroid Build Coastguard Worker     CHECK(result[2] == 5);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker 
DetectionPostProcessTestImpl(bool useRegularNms,const std::vector<float> & expectedDetectionBoxes,const std::vector<float> & expectedDetectionClasses,const std::vector<float> & expectedDetectionScores,const std::vector<float> & expectedNumDetections)72*89c4ff92SAndroid Build Coastguard Worker void DetectionPostProcessTestImpl(bool useRegularNms,
73*89c4ff92SAndroid Build Coastguard Worker                                   const std::vector<float>& expectedDetectionBoxes,
74*89c4ff92SAndroid Build Coastguard Worker                                   const std::vector<float>& expectedDetectionClasses,
75*89c4ff92SAndroid Build Coastguard Worker                                   const std::vector<float>& expectedDetectionScores,
76*89c4ff92SAndroid Build Coastguard Worker                                   const std::vector<float>& expectedNumDetections)
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
79*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::Float32);
80*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
83*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
84*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
85*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     armnn::DetectionPostProcessDescriptor desc;
88*89c4ff92SAndroid Build Coastguard Worker     desc.m_UseRegularNms = useRegularNms;
89*89c4ff92SAndroid Build Coastguard Worker     desc.m_MaxDetections = 3;
90*89c4ff92SAndroid Build Coastguard Worker     desc.m_MaxClassesPerDetection = 1;
91*89c4ff92SAndroid Build Coastguard Worker     desc.m_DetectionsPerClass =1;
92*89c4ff92SAndroid Build Coastguard Worker     desc.m_NmsScoreThreshold = 0.0;
93*89c4ff92SAndroid Build Coastguard Worker     desc.m_NmsIouThreshold = 0.5;
94*89c4ff92SAndroid Build Coastguard Worker     desc.m_NumClasses = 2;
95*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleY = 10.0;
96*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleX = 10.0;
97*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleH = 5.0;
98*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleW = 5.0;
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> boxEncodings({
101*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 0.0f, 0.0f,
102*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 0.0f, 0.0f,
103*89c4ff92SAndroid Build Coastguard Worker         0.0f, -1.0f, 0.0f, 0.0f,
104*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 0.0f, 0.0f,
105*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 0.0f, 0.0f,
106*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 0.0f, 0.0f
107*89c4ff92SAndroid Build Coastguard Worker     });
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> scores({
110*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.9f, 0.8f,
111*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.75f, 0.72f,
112*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.6f, 0.5f,
113*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.93f, 0.95f,
114*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.5f, 0.4f,
115*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.3f, 0.2f
116*89c4ff92SAndroid Build Coastguard Worker     });
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> anchors({
119*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 1.0f, 1.0f,
120*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 1.0f, 1.0f,
121*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 1.0f, 1.0f,
122*89c4ff92SAndroid Build Coastguard Worker         0.5f, 10.5f, 1.0f, 1.0f,
123*89c4ff92SAndroid Build Coastguard Worker         0.5f, 10.5f, 1.0f, 1.0f,
124*89c4ff92SAndroid Build Coastguard Worker         0.5f, 100.5f, 1.0f, 1.0f
125*89c4ff92SAndroid Build Coastguard Worker     });
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     auto boxEncodingsDecoder = armnn::MakeDecoder<float>(boxEncodingsInfo, boxEncodings.data());
128*89c4ff92SAndroid Build Coastguard Worker     auto scoresDecoder       = armnn::MakeDecoder<float>(scoresInfo, scores.data());
129*89c4ff92SAndroid Build Coastguard Worker     auto anchorsDecoder      = armnn::MakeDecoder<float>(anchorsInfo, anchors.data());
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
132*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
133*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
134*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> numDetections(1);
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker     armnn::DetectionPostProcess(boxEncodingsInfo,
137*89c4ff92SAndroid Build Coastguard Worker                                 scoresInfo,
138*89c4ff92SAndroid Build Coastguard Worker                                 anchorsInfo,
139*89c4ff92SAndroid Build Coastguard Worker                                 detectionBoxesInfo,
140*89c4ff92SAndroid Build Coastguard Worker                                 detectionClassesInfo,
141*89c4ff92SAndroid Build Coastguard Worker                                 detectionScoresInfo,
142*89c4ff92SAndroid Build Coastguard Worker                                 numDetectionInfo,
143*89c4ff92SAndroid Build Coastguard Worker                                 desc,
144*89c4ff92SAndroid Build Coastguard Worker                                 *boxEncodingsDecoder,
145*89c4ff92SAndroid Build Coastguard Worker                                 *scoresDecoder,
146*89c4ff92SAndroid Build Coastguard Worker                                 *anchorsDecoder,
147*89c4ff92SAndroid Build Coastguard Worker                                 detectionBoxes.data(),
148*89c4ff92SAndroid Build Coastguard Worker                                 detectionClasses.data(),
149*89c4ff92SAndroid Build Coastguard Worker                                 detectionScores.data(),
150*89c4ff92SAndroid Build Coastguard Worker                                 numDetections.data());
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(detectionBoxes.begin(),
153*89c4ff92SAndroid Build Coastguard Worker                                   detectionBoxes.end(),
154*89c4ff92SAndroid Build Coastguard Worker                                   expectedDetectionBoxes.begin(),
155*89c4ff92SAndroid Build Coastguard Worker                                   expectedDetectionBoxes.end()));
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(detectionScores.begin(), detectionScores.end(),
158*89c4ff92SAndroid Build Coastguard Worker         expectedDetectionScores.begin(), expectedDetectionScores.end()));
159*89c4ff92SAndroid Build Coastguard Worker 
160*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(detectionClasses.begin(), detectionClasses.end(),
161*89c4ff92SAndroid Build Coastguard Worker         expectedDetectionClasses.begin(), expectedDetectionClasses.end()));
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(numDetections.begin(), numDetections.end(),
164*89c4ff92SAndroid Build Coastguard Worker         expectedNumDetections.begin(), expectedNumDetections.end()));
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RegularNmsDetectionPostProcess")
168*89c4ff92SAndroid Build Coastguard Worker {
169*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedDetectionBoxes({
170*89c4ff92SAndroid Build Coastguard Worker         0.0f, 10.0f, 1.0f, 11.0f,
171*89c4ff92SAndroid Build Coastguard Worker         0.0f, 10.0f, 1.0f, 11.0f,
172*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 0.0f, 0.0f
173*89c4ff92SAndroid Build Coastguard Worker     });
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
176*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
177*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedNumDetections({ 2.0f });
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker     DetectionPostProcessTestImpl(true, expectedDetectionBoxes, expectedDetectionClasses,
180*89c4ff92SAndroid Build Coastguard Worker                                  expectedDetectionScores, expectedNumDetections);
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FastNmsDetectionPostProcess")
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedDetectionBoxes({
186*89c4ff92SAndroid Build Coastguard Worker         0.0f, 10.0f, 1.0f, 11.0f,
187*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 1.0f,
188*89c4ff92SAndroid Build Coastguard Worker         0.0f, 100.0f, 1.0f, 101.0f
189*89c4ff92SAndroid Build Coastguard Worker     });
190*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
191*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
192*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedNumDetections({ 3.0f });
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     DetectionPostProcessTestImpl(false, expectedDetectionBoxes, expectedDetectionClasses,
195*89c4ff92SAndroid Build Coastguard Worker                                  expectedDetectionScores, expectedNumDetections);
196*89c4ff92SAndroid Build Coastguard Worker }
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker }