1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2019-2020 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
#include "NonMaxSuppression.h"

#include "arm_compute/core/Types.h"
#include "tests/validation/Helpers.h"

#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>
28*c217d954SCole Faust
29*c217d954SCole Faust namespace arm_compute
30*c217d954SCole Faust {
31*c217d954SCole Faust namespace test
32*c217d954SCole Faust {
33*c217d954SCole Faust namespace validation
34*c217d954SCole Faust {
35*c217d954SCole Faust namespace reference
36*c217d954SCole Faust {
37*c217d954SCole Faust namespace
38*c217d954SCole Faust {
// A scored candidate: first = index of the box in the boxes tensor, second = its score.
using CandidateBox = std::pair<int, float>;
// The four coordinates of one box, in the order they are read from the boxes tensor.
using Box = std::tuple<float, float, float, float>;
41*c217d954SCole Faust
get_elem_by_coordinate(const SimpleTensor<float> & tensor,Coordinates coord)42*c217d954SCole Faust inline float get_elem_by_coordinate(const SimpleTensor<float> &tensor, Coordinates coord)
43*c217d954SCole Faust {
44*c217d954SCole Faust return *static_cast<const float *>(tensor(coord));
45*c217d954SCole Faust }
46*c217d954SCole Faust
get_box(const SimpleTensor<float> & boxes,size_t id)47*c217d954SCole Faust inline Box get_box(const SimpleTensor<float> &boxes, size_t id)
48*c217d954SCole Faust {
49*c217d954SCole Faust return std::make_tuple(
50*c217d954SCole Faust get_elem_by_coordinate(boxes, Coordinates(0, id)),
51*c217d954SCole Faust get_elem_by_coordinate(boxes, Coordinates(1, id)),
52*c217d954SCole Faust get_elem_by_coordinate(boxes, Coordinates(2, id)),
53*c217d954SCole Faust get_elem_by_coordinate(boxes, Coordinates(3, id)));
54*c217d954SCole Faust }
55*c217d954SCole Faust
56*c217d954SCole Faust // returns a pair (minX, minY)
get_min_yx(Box b)57*c217d954SCole Faust inline std::pair<float, float> get_min_yx(Box b)
58*c217d954SCole Faust {
59*c217d954SCole Faust return std::make_pair(
60*c217d954SCole Faust std::min<float>(std::get<0>(b), std::get<2>(b)),
61*c217d954SCole Faust std::min<float>(std::get<1>(b), std::get<3>(b)));
62*c217d954SCole Faust }
63*c217d954SCole Faust // returns a pair (maxX, maxY)
get_max_yx(Box b)64*c217d954SCole Faust inline std::pair<float, float> get_max_yx(Box b)
65*c217d954SCole Faust {
66*c217d954SCole Faust return std::make_pair(
67*c217d954SCole Faust std::max<float>(std::get<0>(b), std::get<2>(b)),
68*c217d954SCole Faust std::max<float>(std::get<1>(b), std::get<3>(b)));
69*c217d954SCole Faust }
70*c217d954SCole Faust
// Area of the axis-aligned rectangle spanned by `min` and `max`.
// No clamping is applied: if a component of `max` is below `min` the result can be
// zero or negative (callers treat non-positive areas as degenerate boxes).
inline float compute_size(const std::pair<float, float> &min, const std::pair<float, float> &max)
{
    const float extent_first  = max.first - min.first;
    const float extent_second = max.second - min.second;
    return extent_first * extent_second;
}
75*c217d954SCole Faust
// Despite its name, returns the intersection-over-union (IoU) of two axis-aligned boxes:
//   intersection / (b0_size + b1_size - intersection)
// b0_min/b0_max and b1_min/b1_max are the lower/upper corners of each box, and
// b0_size/b1_size their precomputed areas. The overlap along each component is clamped
// at zero, so disjoint boxes yield 0. The caller must ensure the union is non-zero.
inline float compute_intersection(const std::pair<float, float> &b0_min, const std::pair<float, float> &b0_max,
                                  const std::pair<float, float> &b1_min, const std::pair<float, float> &b1_max, float b0_size, float b1_size)
{
    const float overlap_first  = std::max<float>(std::min<float>(b0_max.first, b1_max.first) - std::max<float>(b0_min.first, b1_min.first), 0.0f);
    const float overlap_second = std::max<float>(std::min<float>(b0_max.second, b1_max.second) - std::max<float>(b0_min.second, b1_min.second), 0.0f);
    const float inter          = overlap_first * overlap_second;
    const float union_area     = b0_size + b1_size - inter;
    return inter / union_area;
}
85*c217d954SCole Faust
reject_box(Box b0,Box b1,float threshold)86*c217d954SCole Faust inline bool reject_box(Box b0, Box b1, float threshold)
87*c217d954SCole Faust {
88*c217d954SCole Faust const auto b0_min = get_min_yx(b0);
89*c217d954SCole Faust const auto b0_max = get_max_yx(b0);
90*c217d954SCole Faust const auto b1_min = get_min_yx(b1);
91*c217d954SCole Faust const auto b1_max = get_max_yx(b1);
92*c217d954SCole Faust const float b0_size = compute_size(b0_min, b0_max);
93*c217d954SCole Faust const float b1_size = compute_size(b1_min, b1_max);
94*c217d954SCole Faust if(b0_size <= 0.f || b1_size <= 0.f)
95*c217d954SCole Faust {
96*c217d954SCole Faust return false;
97*c217d954SCole Faust }
98*c217d954SCole Faust else
99*c217d954SCole Faust {
100*c217d954SCole Faust const float box_weight = compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size);
101*c217d954SCole Faust return box_weight > threshold;
102*c217d954SCole Faust }
103*c217d954SCole Faust }
104*c217d954SCole Faust
get_candidates(const SimpleTensor<float> & scores,float threshold)105*c217d954SCole Faust inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &scores, float threshold)
106*c217d954SCole Faust {
107*c217d954SCole Faust std::vector<CandidateBox> candidates_vector;
108*c217d954SCole Faust for(int i = 0; i < scores.num_elements(); ++i)
109*c217d954SCole Faust {
110*c217d954SCole Faust if(scores[i] >= threshold)
111*c217d954SCole Faust {
112*c217d954SCole Faust const auto cb = CandidateBox({ i, scores[i] });
113*c217d954SCole Faust candidates_vector.push_back(cb);
114*c217d954SCole Faust }
115*c217d954SCole Faust }
116*c217d954SCole Faust std::stable_sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1)
117*c217d954SCole Faust {
118*c217d954SCole Faust return bb0.second > bb1.second;
119*c217d954SCole Faust });
120*c217d954SCole Faust return candidates_vector;
121*c217d954SCole Faust }
122*c217d954SCole Faust
is_box_selected(const CandidateBox & cb,const SimpleTensor<float> & bboxes,std::vector<int> & selected_boxes,float threshold)123*c217d954SCole Faust inline bool is_box_selected(const CandidateBox &cb, const SimpleTensor<float> &bboxes, std::vector<int> &selected_boxes, float threshold)
124*c217d954SCole Faust {
125*c217d954SCole Faust for(int j = selected_boxes.size() - 1; j >= 0; --j)
126*c217d954SCole Faust {
127*c217d954SCole Faust const auto selected_box_jth = get_box(bboxes, selected_boxes[j]);
128*c217d954SCole Faust const auto candidate_box = get_box(bboxes, cb.first);
129*c217d954SCole Faust const bool candidate_rejected = reject_box(candidate_box, selected_box_jth, threshold);
130*c217d954SCole Faust if(candidate_rejected)
131*c217d954SCole Faust {
132*c217d954SCole Faust return false;
133*c217d954SCole Faust }
134*c217d954SCole Faust }
135*c217d954SCole Faust return true;
136*c217d954SCole Faust }
137*c217d954SCole Faust } // namespace
138*c217d954SCole Faust
// Reference greedy non-maximum suppression.
//
// bboxes          : boxes tensor; box i's 4 coordinates are read at (0..3, i), and the
//                   number of boxes is taken from shape().y().
// scores          : one score per box; only scores >= score_threshold are considered.
// indices         : output tensor, written in place and also returned by value. The first
//                   entries receive the kept box indices (highest score first) and the
//                   remaining slots up to max_output_size are filled with -1.
//                   NOTE(review): assumes `indices` holds at least max_output_size ints.
// max_output_size : maximum number of boxes to keep (also capped by the box count).
// nms_threshold   : a candidate is dropped if its IoU with a kept box exceeds this value.
SimpleTensor<int> non_max_suppression(const SimpleTensor<float> &bboxes, const SimpleTensor<float> &scores, SimpleTensor<int> &indices,
                                      unsigned int max_output_size, float score_threshold, float nms_threshold)
{
    const size_t box_count = bboxes.shape().y();
    const size_t limit     = std::min(static_cast<size_t>(max_output_size), box_count);

    std::vector<int> kept;
    for(const CandidateBox &candidate : get_candidates(scores, score_threshold))
    {
        if(kept.size() == limit)
        {
            break;
        }
        if(is_box_selected(candidate, bboxes, kept, nms_threshold))
        {
            kept.push_back(candidate.first);
        }
    }

    int *const out = indices.data();
    std::copy(kept.begin(), kept.end(), out);
    // Pad the unused tail of the output with -1.
    std::fill(out + kept.size(), out + max_output_size, -1);

    return indices;
}
166*c217d954SCole Faust } // namespace reference
167*c217d954SCole Faust } // namespace validation
168*c217d954SCole Faust } // namespace test
169*c217d954SCole Faust } // namespace arm_compute
170