xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/src/SSDResultDecoder.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "SSDResultDecoder.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <cassert>
9*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
10*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
11*89c4ff92SAndroid Build Coastguard Worker #include <stdexcept>
12*89c4ff92SAndroid Build Coastguard Worker namespace od
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker 
Decode(const common::InferenceResults<float> & networkResults,const common::Size & outputFrameSize,const common::Size & resizedFrameSize,const std::vector<std::string> & labels)15*89c4ff92SAndroid Build Coastguard Worker DetectedObjects SSDResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
16*89c4ff92SAndroid Build Coastguard Worker     const common::Size& outputFrameSize,
17*89c4ff92SAndroid Build Coastguard Worker     const common::Size& resizedFrameSize,
18*89c4ff92SAndroid Build Coastguard Worker     const std::vector<std::string>& labels)
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker     // SSD network outputs 4 tensors: bounding boxes, labels, probabilities, number of detections.
21*89c4ff92SAndroid Build Coastguard Worker     if (networkResults.size() != 4)
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         throw std::runtime_error("Number of outputs from SSD model doesn't equal 4");
24*89c4ff92SAndroid Build Coastguard Worker     }
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     DetectedObjects detectedObjects;
27*89c4ff92SAndroid Build Coastguard Worker     const int numDetections = static_cast<int>(std::lround(networkResults[3][0]));
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     double longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30*89c4ff92SAndroid Build Coastguard Worker     double longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31*89c4ff92SAndroid Build Coastguard Worker     const double resizeFactor = longEdgeOutput/longEdgeInput;
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     for (int i=0; i<numDetections; ++i)
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         if (networkResults[2][i] > m_objectThreshold)
36*89c4ff92SAndroid Build Coastguard Worker         {
37*89c4ff92SAndroid Build Coastguard Worker             DetectedObject detectedObject;
38*89c4ff92SAndroid Build Coastguard Worker             detectedObject.SetScore(networkResults[2][i]);
39*89c4ff92SAndroid Build Coastguard Worker             auto classId = std::lround(networkResults[1][i]);
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker             if (classId < labels.size())
42*89c4ff92SAndroid Build Coastguard Worker             {
43*89c4ff92SAndroid Build Coastguard Worker                 detectedObject.SetLabel(labels[classId]);
44*89c4ff92SAndroid Build Coastguard Worker             }
45*89c4ff92SAndroid Build Coastguard Worker             else
46*89c4ff92SAndroid Build Coastguard Worker             {
47*89c4ff92SAndroid Build Coastguard Worker                 detectedObject.SetLabel(std::to_string(classId));
48*89c4ff92SAndroid Build Coastguard Worker             }
49*89c4ff92SAndroid Build Coastguard Worker             detectedObject.SetId(classId);
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker             // Convert SSD bbox outputs (ratios of image size) to pixel values.
52*89c4ff92SAndroid Build Coastguard Worker             double topLeftY = networkResults[0][i*4 + 0] * resizedFrameSize.m_Height;
53*89c4ff92SAndroid Build Coastguard Worker             double topLeftX = networkResults[0][i*4 + 1] * resizedFrameSize.m_Width;
54*89c4ff92SAndroid Build Coastguard Worker             double botRightY = networkResults[0][i*4 + 2] * resizedFrameSize.m_Height;
55*89c4ff92SAndroid Build Coastguard Worker             double botRightX = networkResults[0][i*4 + 3] * resizedFrameSize.m_Width;
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker             // Scale the coordinates to output frame size.
58*89c4ff92SAndroid Build Coastguard Worker             topLeftY *= resizeFactor;
59*89c4ff92SAndroid Build Coastguard Worker             topLeftX *= resizeFactor;
60*89c4ff92SAndroid Build Coastguard Worker             botRightY *= resizeFactor;
61*89c4ff92SAndroid Build Coastguard Worker             botRightX *= resizeFactor;
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker             assert(botRightX > topLeftX);
64*89c4ff92SAndroid Build Coastguard Worker             assert(botRightY > topLeftY);
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker             // Internal BoundingBox stores box top left x,y and width, height.
67*89c4ff92SAndroid Build Coastguard Worker             detectedObject.SetBoundingBox({static_cast<int>(std::round(topLeftX)),
68*89c4ff92SAndroid Build Coastguard Worker                                            static_cast<int>(std::round(topLeftY)),
69*89c4ff92SAndroid Build Coastguard Worker                                            static_cast<unsigned int>(botRightX - topLeftX),
70*89c4ff92SAndroid Build Coastguard Worker                                            static_cast<unsigned int>(botRightY - topLeftY)});
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker             detectedObjects.emplace_back(detectedObject);
73*89c4ff92SAndroid Build Coastguard Worker         }
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker     return detectedObjects;
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker 
SSDResultDecoder(float ObjectThreshold)78*89c4ff92SAndroid Build Coastguard Worker SSDResultDecoder::SSDResultDecoder(float ObjectThreshold) : m_objectThreshold(ObjectThreshold) {}
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker }// namespace od