/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24*c217d954SCole Faust #include "arm_compute/runtime/CPP/functions/CPPBoxWithNonMaximaSuppressionLimit.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/core/CPP/kernels/CPPBoxWithNonMaximaSuppressionLimitKernel.h"
27*c217d954SCole Faust #include "arm_compute/runtime/Scheduler.h"
28*c217d954SCole Faust 
29*c217d954SCole Faust #include "src/common/utils/Log.h"
30*c217d954SCole Faust 
31*c217d954SCole Faust namespace arm_compute
32*c217d954SCole Faust {
33*c217d954SCole Faust namespace
34*c217d954SCole Faust {
dequantize_tensor(const ITensor * input,ITensor * output)35*c217d954SCole Faust void dequantize_tensor(const ITensor *input, ITensor *output)
36*c217d954SCole Faust {
37*c217d954SCole Faust     const UniformQuantizationInfo qinfo     = input->info()->quantization_info().uniform();
38*c217d954SCole Faust     const DataType                data_type = input->info()->data_type();
39*c217d954SCole Faust 
40*c217d954SCole Faust     Window window;
41*c217d954SCole Faust     window.use_tensor_dimensions(input->info()->tensor_shape());
42*c217d954SCole Faust     Iterator input_it(input, window);
43*c217d954SCole Faust     Iterator output_it(output, window);
44*c217d954SCole Faust 
45*c217d954SCole Faust     switch(data_type)
46*c217d954SCole Faust     {
47*c217d954SCole Faust         case DataType::QASYMM8:
48*c217d954SCole Faust             execute_window_loop(window, [&](const Coordinates &)
49*c217d954SCole Faust             {
50*c217d954SCole Faust                 *reinterpret_cast<float *>(output_it.ptr()) = dequantize(*reinterpret_cast<const uint8_t *>(input_it.ptr()), qinfo.scale, qinfo.offset);
51*c217d954SCole Faust             },
52*c217d954SCole Faust             input_it, output_it);
53*c217d954SCole Faust             break;
54*c217d954SCole Faust         case DataType::QASYMM8_SIGNED:
55*c217d954SCole Faust             execute_window_loop(window, [&](const Coordinates &)
56*c217d954SCole Faust             {
57*c217d954SCole Faust                 *reinterpret_cast<float *>(output_it.ptr()) = dequantize_qasymm8_signed(*reinterpret_cast<const int8_t *>(input_it.ptr()), qinfo);
58*c217d954SCole Faust             },
59*c217d954SCole Faust             input_it, output_it);
60*c217d954SCole Faust             break;
61*c217d954SCole Faust         case DataType::QASYMM16:
62*c217d954SCole Faust             execute_window_loop(window, [&](const Coordinates &)
63*c217d954SCole Faust             {
64*c217d954SCole Faust                 *reinterpret_cast<float *>(output_it.ptr()) = dequantize(*reinterpret_cast<const uint16_t *>(input_it.ptr()), qinfo.scale, qinfo.offset);
65*c217d954SCole Faust             },
66*c217d954SCole Faust             input_it, output_it);
67*c217d954SCole Faust             break;
68*c217d954SCole Faust         default:
69*c217d954SCole Faust             ARM_COMPUTE_ERROR("Unsupported data type");
70*c217d954SCole Faust     }
71*c217d954SCole Faust }
72*c217d954SCole Faust 
quantize_tensor(const ITensor * input,ITensor * output)73*c217d954SCole Faust void quantize_tensor(const ITensor *input, ITensor *output)
74*c217d954SCole Faust {
75*c217d954SCole Faust     const UniformQuantizationInfo qinfo     = output->info()->quantization_info().uniform();
76*c217d954SCole Faust     const DataType                data_type = output->info()->data_type();
77*c217d954SCole Faust 
78*c217d954SCole Faust     Window window;
79*c217d954SCole Faust     window.use_tensor_dimensions(input->info()->tensor_shape());
80*c217d954SCole Faust     Iterator input_it(input, window);
81*c217d954SCole Faust     Iterator output_it(output, window);
82*c217d954SCole Faust 
83*c217d954SCole Faust     switch(data_type)
84*c217d954SCole Faust     {
85*c217d954SCole Faust         case DataType::QASYMM8:
86*c217d954SCole Faust             execute_window_loop(window, [&](const Coordinates &)
87*c217d954SCole Faust             {
88*c217d954SCole Faust                 *reinterpret_cast<uint8_t *>(output_it.ptr()) = quantize_qasymm8(*reinterpret_cast<const float *>(input_it.ptr()), qinfo);
89*c217d954SCole Faust             },
90*c217d954SCole Faust             input_it, output_it);
91*c217d954SCole Faust             break;
92*c217d954SCole Faust         case DataType::QASYMM8_SIGNED:
93*c217d954SCole Faust             execute_window_loop(window, [&](const Coordinates &)
94*c217d954SCole Faust             {
95*c217d954SCole Faust                 *reinterpret_cast<int8_t *>(output_it.ptr()) = quantize_qasymm8_signed(*reinterpret_cast<const float *>(input_it.ptr()), qinfo);
96*c217d954SCole Faust             },
97*c217d954SCole Faust             input_it, output_it);
98*c217d954SCole Faust             break;
99*c217d954SCole Faust         case DataType::QASYMM16:
100*c217d954SCole Faust             execute_window_loop(window, [&](const Coordinates &)
101*c217d954SCole Faust             {
102*c217d954SCole Faust                 *reinterpret_cast<uint16_t *>(output_it.ptr()) = quantize_qasymm16(*reinterpret_cast<const float *>(input_it.ptr()), qinfo);
103*c217d954SCole Faust             },
104*c217d954SCole Faust             input_it, output_it);
105*c217d954SCole Faust             break;
106*c217d954SCole Faust         default:
107*c217d954SCole Faust             ARM_COMPUTE_ERROR("Unsupported data type");
108*c217d954SCole Faust     }
109*c217d954SCole Faust }
110*c217d954SCole Faust } // namespace
111*c217d954SCole Faust 
CPPBoxWithNonMaximaSuppressionLimit(std::shared_ptr<IMemoryManager> memory_manager)112*c217d954SCole Faust CPPBoxWithNonMaximaSuppressionLimit::CPPBoxWithNonMaximaSuppressionLimit(std::shared_ptr<IMemoryManager> memory_manager)
113*c217d954SCole Faust     : _memory_group(std::move(memory_manager)),
114*c217d954SCole Faust       _box_with_nms_limit_kernel(),
115*c217d954SCole Faust       _scores_in(),
116*c217d954SCole Faust       _boxes_in(),
117*c217d954SCole Faust       _batch_splits_in(),
118*c217d954SCole Faust       _scores_out(),
119*c217d954SCole Faust       _boxes_out(),
120*c217d954SCole Faust       _classes(),
121*c217d954SCole Faust       _batch_splits_out(),
122*c217d954SCole Faust       _keeps(),
123*c217d954SCole Faust       _scores_in_f32(),
124*c217d954SCole Faust       _boxes_in_f32(),
125*c217d954SCole Faust       _batch_splits_in_f32(),
126*c217d954SCole Faust       _scores_out_f32(),
127*c217d954SCole Faust       _boxes_out_f32(),
128*c217d954SCole Faust       _classes_f32(),
129*c217d954SCole Faust       _batch_splits_out_f32(),
130*c217d954SCole Faust       _keeps_f32(),
131*c217d954SCole Faust       _is_qasymm8(false)
132*c217d954SCole Faust {
133*c217d954SCole Faust }
134*c217d954SCole Faust 
configure(const ITensor * scores_in,const ITensor * boxes_in,const ITensor * batch_splits_in,ITensor * scores_out,ITensor * boxes_out,ITensor * classes,ITensor * batch_splits_out,ITensor * keeps,ITensor * keeps_size,const BoxNMSLimitInfo info)135*c217d954SCole Faust void CPPBoxWithNonMaximaSuppressionLimit::configure(const ITensor *scores_in, const ITensor *boxes_in, const ITensor *batch_splits_in,
136*c217d954SCole Faust                                                     ITensor *scores_out, ITensor *boxes_out, ITensor *classes, ITensor *batch_splits_out,
137*c217d954SCole Faust                                                     ITensor *keeps, ITensor *keeps_size, const BoxNMSLimitInfo info)
138*c217d954SCole Faust {
139*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(scores_in, boxes_in, scores_out, boxes_out, classes);
140*c217d954SCole Faust     ARM_COMPUTE_LOG_PARAMS(scores_in, boxes_in, batch_splits_in, scores_out, boxes_out, classes, batch_splits_out, keeps, keeps_size, info);
141*c217d954SCole Faust 
142*c217d954SCole Faust     _is_qasymm8 = scores_in->info()->data_type() == DataType::QASYMM8 || scores_in->info()->data_type() == DataType::QASYMM8_SIGNED;
143*c217d954SCole Faust 
144*c217d954SCole Faust     _scores_in        = scores_in;
145*c217d954SCole Faust     _boxes_in         = boxes_in;
146*c217d954SCole Faust     _batch_splits_in  = batch_splits_in;
147*c217d954SCole Faust     _scores_out       = scores_out;
148*c217d954SCole Faust     _boxes_out        = boxes_out;
149*c217d954SCole Faust     _classes          = classes;
150*c217d954SCole Faust     _batch_splits_out = batch_splits_out;
151*c217d954SCole Faust     _keeps            = keeps;
152*c217d954SCole Faust 
153*c217d954SCole Faust     if(_is_qasymm8)
154*c217d954SCole Faust     {
155*c217d954SCole Faust         // Manage intermediate buffers
156*c217d954SCole Faust         _memory_group.manage(&_scores_in_f32);
157*c217d954SCole Faust         _memory_group.manage(&_boxes_in_f32);
158*c217d954SCole Faust         _memory_group.manage(&_scores_out_f32);
159*c217d954SCole Faust         _memory_group.manage(&_boxes_out_f32);
160*c217d954SCole Faust         _memory_group.manage(&_classes_f32);
161*c217d954SCole Faust         _scores_in_f32.allocator()->init(scores_in->info()->clone()->set_data_type(DataType::F32));
162*c217d954SCole Faust         _boxes_in_f32.allocator()->init(boxes_in->info()->clone()->set_data_type(DataType::F32));
163*c217d954SCole Faust         if(batch_splits_in != nullptr)
164*c217d954SCole Faust         {
165*c217d954SCole Faust             _memory_group.manage(&_batch_splits_in_f32);
166*c217d954SCole Faust             _batch_splits_in_f32.allocator()->init(batch_splits_in->info()->clone()->set_data_type(DataType::F32));
167*c217d954SCole Faust         }
168*c217d954SCole Faust         _scores_out_f32.allocator()->init(scores_out->info()->clone()->set_data_type(DataType::F32));
169*c217d954SCole Faust         _boxes_out_f32.allocator()->init(boxes_out->info()->clone()->set_data_type(DataType::F32));
170*c217d954SCole Faust         _classes_f32.allocator()->init(classes->info()->clone()->set_data_type(DataType::F32));
171*c217d954SCole Faust         if(batch_splits_out != nullptr)
172*c217d954SCole Faust         {
173*c217d954SCole Faust             _memory_group.manage(&_batch_splits_out_f32);
174*c217d954SCole Faust             _batch_splits_out_f32.allocator()->init(batch_splits_out->info()->clone()->set_data_type(DataType::F32));
175*c217d954SCole Faust         }
176*c217d954SCole Faust         if(keeps != nullptr)
177*c217d954SCole Faust         {
178*c217d954SCole Faust             _memory_group.manage(&_keeps_f32);
179*c217d954SCole Faust             _keeps_f32.allocator()->init(keeps->info()->clone()->set_data_type(DataType::F32));
180*c217d954SCole Faust         }
181*c217d954SCole Faust 
182*c217d954SCole Faust         _box_with_nms_limit_kernel.configure(&_scores_in_f32, &_boxes_in_f32, (batch_splits_in != nullptr) ? &_batch_splits_in_f32 : nullptr,
183*c217d954SCole Faust                                              &_scores_out_f32, &_boxes_out_f32, &_classes_f32,
184*c217d954SCole Faust                                              (batch_splits_out != nullptr) ? &_batch_splits_out_f32 : nullptr, (keeps != nullptr) ? &_keeps_f32 : nullptr,
185*c217d954SCole Faust                                              keeps_size, info);
186*c217d954SCole Faust     }
187*c217d954SCole Faust     else
188*c217d954SCole Faust     {
189*c217d954SCole Faust         _box_with_nms_limit_kernel.configure(scores_in, boxes_in, batch_splits_in, scores_out, boxes_out, classes, batch_splits_out, keeps, keeps_size, info);
190*c217d954SCole Faust     }
191*c217d954SCole Faust 
192*c217d954SCole Faust     if(_is_qasymm8)
193*c217d954SCole Faust     {
194*c217d954SCole Faust         _scores_in_f32.allocator()->allocate();
195*c217d954SCole Faust         _boxes_in_f32.allocator()->allocate();
196*c217d954SCole Faust         if(_batch_splits_in != nullptr)
197*c217d954SCole Faust         {
198*c217d954SCole Faust             _batch_splits_in_f32.allocator()->allocate();
199*c217d954SCole Faust         }
200*c217d954SCole Faust         _scores_out_f32.allocator()->allocate();
201*c217d954SCole Faust         _boxes_out_f32.allocator()->allocate();
202*c217d954SCole Faust         _classes_f32.allocator()->allocate();
203*c217d954SCole Faust         if(batch_splits_out != nullptr)
204*c217d954SCole Faust         {
205*c217d954SCole Faust             _batch_splits_out_f32.allocator()->allocate();
206*c217d954SCole Faust         }
207*c217d954SCole Faust         if(keeps != nullptr)
208*c217d954SCole Faust         {
209*c217d954SCole Faust             _keeps_f32.allocator()->allocate();
210*c217d954SCole Faust         }
211*c217d954SCole Faust     }
212*c217d954SCole Faust }
213*c217d954SCole Faust 
validate(const ITensorInfo * scores_in,const ITensorInfo * boxes_in,const ITensorInfo * batch_splits_in,const ITensorInfo * scores_out,const ITensorInfo * boxes_out,const ITensorInfo * classes,const ITensorInfo * batch_splits_out,const ITensorInfo * keeps,const ITensorInfo * keeps_size,const BoxNMSLimitInfo info)214*c217d954SCole Faust Status validate(const ITensorInfo *scores_in, const ITensorInfo *boxes_in, const ITensorInfo *batch_splits_in, const ITensorInfo *scores_out, const ITensorInfo *boxes_out, const ITensorInfo *classes,
215*c217d954SCole Faust                 const ITensorInfo *batch_splits_out, const ITensorInfo *keeps, const ITensorInfo *keeps_size, const BoxNMSLimitInfo info)
216*c217d954SCole Faust {
217*c217d954SCole Faust     ARM_COMPUTE_UNUSED(batch_splits_in, batch_splits_out, keeps, keeps_size, info);
218*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(scores_in, boxes_in, scores_out, boxes_out, classes);
219*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(scores_in, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
220*c217d954SCole Faust 
221*c217d954SCole Faust     const bool is_qasymm8 = scores_in->data_type() == DataType::QASYMM8 || scores_in->data_type() == DataType::QASYMM8_SIGNED;
222*c217d954SCole Faust     if(is_qasymm8)
223*c217d954SCole Faust     {
224*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(boxes_in, 1, DataType::QASYMM16);
225*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(boxes_in, boxes_out);
226*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(boxes_in, boxes_out);
227*c217d954SCole Faust         const UniformQuantizationInfo boxes_qinfo = boxes_in->quantization_info().uniform();
228*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(boxes_qinfo.scale != 0.125f);
229*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(boxes_qinfo.offset != 0);
230*c217d954SCole Faust     }
231*c217d954SCole Faust 
232*c217d954SCole Faust     return Status{};
233*c217d954SCole Faust }
234*c217d954SCole Faust 
run()235*c217d954SCole Faust void CPPBoxWithNonMaximaSuppressionLimit::run()
236*c217d954SCole Faust {
237*c217d954SCole Faust     // Acquire all the temporaries
238*c217d954SCole Faust     MemoryGroupResourceScope scope_mg(_memory_group);
239*c217d954SCole Faust 
240*c217d954SCole Faust     if(_is_qasymm8)
241*c217d954SCole Faust     {
242*c217d954SCole Faust         dequantize_tensor(_scores_in, &_scores_in_f32);
243*c217d954SCole Faust         dequantize_tensor(_boxes_in, &_boxes_in_f32);
244*c217d954SCole Faust         if(_batch_splits_in != nullptr)
245*c217d954SCole Faust         {
246*c217d954SCole Faust             dequantize_tensor(_batch_splits_in, &_batch_splits_in_f32);
247*c217d954SCole Faust         }
248*c217d954SCole Faust     }
249*c217d954SCole Faust 
250*c217d954SCole Faust     Scheduler::get().schedule(&_box_with_nms_limit_kernel, Window::DimY);
251*c217d954SCole Faust 
252*c217d954SCole Faust     if(_is_qasymm8)
253*c217d954SCole Faust     {
254*c217d954SCole Faust         quantize_tensor(&_scores_out_f32, _scores_out);
255*c217d954SCole Faust         quantize_tensor(&_boxes_out_f32, _boxes_out);
256*c217d954SCole Faust         quantize_tensor(&_classes_f32, _classes);
257*c217d954SCole Faust         if(_batch_splits_out != nullptr)
258*c217d954SCole Faust         {
259*c217d954SCole Faust             quantize_tensor(&_batch_splits_out_f32, _batch_splits_out);
260*c217d954SCole Faust         }
261*c217d954SCole Faust         if(_keeps != nullptr)
262*c217d954SCole Faust         {
263*c217d954SCole Faust             quantize_tensor(&_keeps_f32, _keeps);
264*c217d954SCole Faust         }
265*c217d954SCole Faust     }
266*c217d954SCole Faust }
267*c217d954SCole Faust } // namespace arm_compute
268