xref: /aosp_15_r20/external/ComputeLibrary/tests/datasets/ScaleValidationDataset.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2020-2022 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef TESTS_DATASETS_SCALEVALIDATIONDATASET
25*c217d954SCole Faust #define TESTS_DATASETS_SCALEVALIDATIONDATASET
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "arm_compute/core/Types.h"
28*c217d954SCole Faust #include "tests/datasets/BorderModeDataset.h"
29*c217d954SCole Faust #include "tests/datasets/SamplingPolicyDataset.h"
30*c217d954SCole Faust #include "tests/datasets/ShapeDatasets.h"
31*c217d954SCole Faust 
32*c217d954SCole Faust namespace arm_compute
33*c217d954SCole Faust {
34*c217d954SCole Faust namespace test
35*c217d954SCole Faust {
36*c217d954SCole Faust namespace datasets
37*c217d954SCole Faust {
38*c217d954SCole Faust /** Class to generate boundary values for the given template parameters
39*c217d954SCole Faust  * including shapes with large differences between width and height.
40*c217d954SCole Faust  * element_per_iteration is the number of elements processed by one iteration
41*c217d954SCole Faust  * of an implementation. (E.g., if an iteration is based on a 16-byte vector
42*c217d954SCole Faust  * and size of one element is 1-byte, this value would be 16.).
43*c217d954SCole Faust  * iterations is the total number of complete iterations we want to test
44*c217d954SCole Faust  * for the effect of larger shapes.
45*c217d954SCole Faust  */
46*c217d954SCole Faust template <uint32_t channel, uint32_t batch, uint32_t element_per_iteration, uint32_t iterations>
47*c217d954SCole Faust class ScaleShapesBaseDataSet : public ShapeDataset
48*c217d954SCole Faust {
49*c217d954SCole Faust     static constexpr auto boundary_minus_one = element_per_iteration * iterations - 1;
50*c217d954SCole Faust     static constexpr auto boundary_plus_one  = element_per_iteration * iterations + 1;
51*c217d954SCole Faust     static constexpr auto small_size         = 3;
52*c217d954SCole Faust 
53*c217d954SCole Faust public:
54*c217d954SCole Faust     // These tensor shapes are NCHW layout, fixture will convert to NHWC.
ScaleShapesBaseDataSet()55*c217d954SCole Faust     ScaleShapesBaseDataSet()
56*c217d954SCole Faust         : ShapeDataset("Shape",
57*c217d954SCole Faust     {
58*c217d954SCole Faust         TensorShape{ small_size, boundary_minus_one, channel, batch },
59*c217d954SCole Faust                      TensorShape{ small_size, boundary_plus_one, channel, batch },
60*c217d954SCole Faust                      TensorShape{ boundary_minus_one, small_size, channel, batch },
61*c217d954SCole Faust                      TensorShape{ boundary_plus_one, small_size, channel, batch },
62*c217d954SCole Faust                      TensorShape{ boundary_minus_one, boundary_plus_one, channel, batch },
63*c217d954SCole Faust                      TensorShape{ boundary_plus_one, boundary_minus_one, channel, batch },
64*c217d954SCole Faust     })
65*c217d954SCole Faust     {
66*c217d954SCole Faust     }
67*c217d954SCole Faust };
68*c217d954SCole Faust 
69*c217d954SCole Faust /** For the single vector, only larger value (+1) than boundary
70*c217d954SCole Faust  * since smaller value (-1) could cause some invalid shapes like
71*c217d954SCole Faust  * - invalid zero size
72*c217d954SCole Faust  * - size 1 which isn't compatible with scale with aligned corners.
73*c217d954SCole Faust  */
74*c217d954SCole Faust template <uint32_t channel, uint32_t batch, uint32_t element_per_iteration>
75*c217d954SCole Faust class ScaleShapesBaseDataSet<channel, batch, element_per_iteration, 1> : public ShapeDataset
76*c217d954SCole Faust {
77*c217d954SCole Faust     static constexpr auto small_size        = 3;
78*c217d954SCole Faust     static constexpr auto boundary_plus_one = element_per_iteration + 1;
79*c217d954SCole Faust 
80*c217d954SCole Faust public:
81*c217d954SCole Faust     // These tensor shapes are NCHW layout, fixture will convert to NHWC.
ScaleShapesBaseDataSet()82*c217d954SCole Faust     ScaleShapesBaseDataSet()
83*c217d954SCole Faust         : ShapeDataset("Shape",
84*c217d954SCole Faust     {
85*c217d954SCole Faust         TensorShape{ small_size, boundary_plus_one, channel, batch },
86*c217d954SCole Faust                      TensorShape{ boundary_plus_one, small_size, channel, batch },
87*c217d954SCole Faust     })
88*c217d954SCole Faust     {
89*c217d954SCole Faust     }
90*c217d954SCole Faust };
91*c217d954SCole Faust 
92*c217d954SCole Faust /** For the shapes smaller than one vector, only pre-defined tiny shapes
93*c217d954SCole Faust  * are tested (3x2, 2x3) as smaller shapes are more likely to cause
94*c217d954SCole Faust  * issues and easier to debug.
95*c217d954SCole Faust  */
96*c217d954SCole Faust template <uint32_t channel, uint32_t batch, uint32_t element_per_iteration>
97*c217d954SCole Faust class ScaleShapesBaseDataSet<channel, batch, element_per_iteration, 0> : public ShapeDataset
98*c217d954SCole Faust {
99*c217d954SCole Faust     static constexpr auto small_size                 = 3;
100*c217d954SCole Faust     static constexpr auto zero_vector_boundary_value = 2;
101*c217d954SCole Faust 
102*c217d954SCole Faust public:
103*c217d954SCole Faust     // These tensor shapes are NCHW layout, fixture will convert to NHWC.
ScaleShapesBaseDataSet()104*c217d954SCole Faust     ScaleShapesBaseDataSet()
105*c217d954SCole Faust         : ShapeDataset("Shape",
106*c217d954SCole Faust     {
107*c217d954SCole Faust         TensorShape{ small_size, zero_vector_boundary_value, channel, batch },
108*c217d954SCole Faust                      TensorShape{ zero_vector_boundary_value, small_size, channel, batch },
109*c217d954SCole Faust     })
110*c217d954SCole Faust     {
111*c217d954SCole Faust     }
112*c217d954SCole Faust };
113*c217d954SCole Faust 
114*c217d954SCole Faust /** Interpolation policy test set */
115*c217d954SCole Faust const auto ScaleInterpolationPolicySet = framework::dataset::make("InterpolationPolicy",
116*c217d954SCole Faust {
117*c217d954SCole Faust     InterpolationPolicy::NEAREST_NEIGHBOR,
118*c217d954SCole Faust     InterpolationPolicy::BILINEAR,
119*c217d954SCole Faust });
120*c217d954SCole Faust 
121*c217d954SCole Faust /** Scale data types */
122*c217d954SCole Faust const auto ScaleDataLayouts = framework::dataset::make("DataLayout",
123*c217d954SCole Faust {
124*c217d954SCole Faust     DataLayout::NCHW,
125*c217d954SCole Faust     DataLayout::NHWC,
126*c217d954SCole Faust });
127*c217d954SCole Faust 
128*c217d954SCole Faust /** Sampling policy data set */
129*c217d954SCole Faust const auto ScaleSamplingPolicySet = combine(datasets::SamplingPolicies(),
130*c217d954SCole Faust                                             framework::dataset::make("AlignCorners", { false }));
131*c217d954SCole Faust 
132*c217d954SCole Faust /** Sampling policy data set for Aligned Corners which only allows TOP_LEFT policy.*/
133*c217d954SCole Faust const auto ScaleAlignCornersSamplingPolicySet = combine(framework::dataset::make("SamplingPolicy",
134*c217d954SCole Faust {
135*c217d954SCole Faust     SamplingPolicy::TOP_LEFT,
136*c217d954SCole Faust }),
137*c217d954SCole Faust framework::dataset::make("AlignCorners", { true }));
138*c217d954SCole Faust 
/** Generated shapes: used by precommit and nightly for CPU tests.
 *
 * Concatenates, in this order:
 * - 2D shapes (channel 1, batch 1) with 0 vector iterations
 * - 2D shapes (channel 1, batch 1) with 2 vector iterations
 * - 3D shapes (channel 3, batch 1) with 1 vector iteration
 * - 4D shapes (channel 40, batch 3) with 0 vector iterations
 *
 * element_per_iteration is forwarded as the third template argument of
 * ScaleShapesBaseDataSet.
 */
#define SCALE_SHAPE_DATASET(element_per_iteration) \
    concat( \
        concat( \
            concat(ScaleShapesBaseDataSet<1, 1, (element_per_iteration), 0>(), \
                   ScaleShapesBaseDataSet<1, 1, (element_per_iteration), 2>()), \
            ScaleShapesBaseDataSet<3, 1, (element_per_iteration), 1>()), \
        ScaleShapesBaseDataSet<40, 3, (element_per_iteration), 0>())
149*c217d954SCole Faust 
// To keep OpenCL precommit time short, the OpenCL shape set is split into a
// precommit part and a nightly part.
/** Generated shapes giving essential coverage. Used by CL precommit and nightly.
 *
 * Concatenates, in this order:
 * - 3D shapes (channel 3, batch 1) with 1 vector iteration
 * - 4D shapes (channel 3, batch 3) with 1 vector iteration
 *
 * element_per_iteration is forwarded as the third template argument of
 * ScaleShapesBaseDataSet.
 */
#define SCALE_PRECOMMIT_SHAPE_DATASET(element_per_iteration) \
    concat(ScaleShapesBaseDataSet<3, 1, (element_per_iteration), 1>(), \
           ScaleShapesBaseDataSet<3, 3, (element_per_iteration), 1>())
157*c217d954SCole Faust 
/** Generated shapes adding more small and varied shapes. Used by CL nightly.
 *
 * Concatenates, in this order:
 * - 2D shapes (channel 1, batch 1) with 0 vector iterations
 * - 2D shapes (channel 1, batch 1) with 1 vector iteration
 * - 3D shapes (channel 3, batch 1) with 0 vector iterations
 *   (1 vector iteration is covered by SCALE_PRECOMMIT_SHAPE_DATASET)
 * - 4D shapes (channel 3, batch 3) with 0 vector iterations
 *
 * element_per_iteration is forwarded as the third template argument of
 * ScaleShapesBaseDataSet.
 */
#define SCALE_NIGHTLY_SHAPE_DATASET(element_per_iteration) \
    concat( \
        concat( \
            concat(ScaleShapesBaseDataSet<1, 1, (element_per_iteration), 0>(), \
                   ScaleShapesBaseDataSet<1, 1, (element_per_iteration), 1>()), \
            ScaleShapesBaseDataSet<3, 1, (element_per_iteration), 0>()), \
        ScaleShapesBaseDataSet<3, 3, (element_per_iteration), 0>())
168*c217d954SCole Faust 
/** Assemble the dataset for non-quantized data types from the given shapes.
 *
 * Combines, in this order: shape, ScaleDataLayouts, ScaleInterpolationPolicySet,
 * datasets::BorderModes() and the given sampling policy set.
 */
#define ASSEMBLE_DATASET(shape, sampling_policy_set) \
    combine( \
        combine( \
            combine( \
                combine((shape), ScaleDataLayouts), \
                ScaleInterpolationPolicySet), \
            datasets::BorderModes()), \
        sampling_policy_set)
175*c217d954SCole Faust 
/** Assemble the dynamic fusion dataset from the given shapes.
 *
 * Combines, in this order: shape, a DataLayout set holding only NHWC,
 * ScaleInterpolationPolicySet and the given sampling policy set.
 * No border mode set is included.
 */
#define ASSEMBLE_DATASET_DYNAMIC_FUSION(shape, sampling_policy_set) \
    combine( \
        combine( \
            combine((shape), framework::dataset::make("DataLayout", { DataLayout::NHWC })), \
            ScaleInterpolationPolicySet), \
        sampling_policy_set)
180*c217d954SCole Faust 
/** Assemble the S8 dataset from the given shapes.
 *
 * Combines, in this order: shape, DataLayout NHWC, InterpolationPolicy BILINEAR,
 * BorderMode REPLICATE and the given sampling policy set.
 */
#define ASSEMBLE_S8_DATASET(shape, sampling_policy_set) \
    combine( \
        combine( \
            combine( \
                combine((shape), framework::dataset::make("DataLayout", DataLayout::NHWC)), \
                framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::BILINEAR })), \
            framework::dataset::make("BorderMode", { BorderMode::REPLICATE })), \
        sampling_policy_set)
186*c217d954SCole Faust 
/** Assemble the NHWC-only dataset from the given shapes.
 *
 * Combines, in this order: shape, DataLayout NHWC, ScaleInterpolationPolicySet,
 * a BorderMode set holding CONSTANT and REPLICATE, and the given sampling
 * policy set.
 */
#define ASSEMBLE_NHWC_DATASET(shape, sampling_policy_set) \
    combine( \
        combine( \
            combine( \
                combine((shape), framework::dataset::make("DataLayout", DataLayout::NHWC)), \
                ScaleInterpolationPolicySet), \
            framework::dataset::make("BorderMode", { BorderMode::CONSTANT, BorderMode::REPLICATE })), \
        sampling_policy_set)
192*c217d954SCole Faust 
/** Assemble the dataset for quantized data types from the given shapes.
 *
 * Combines, in this order: shape, the quantization info set, ScaleDataLayouts,
 * ScaleInterpolationPolicySet, datasets::BorderModes() and the given sampling
 * policy set.
 */
#define ASSEMBLE_QUANTIZED_DATASET(shape, sampling_policy_set, quantization_info_set) \
    combine( \
        combine( \
            combine( \
                combine( \
                    combine(shape, quantization_info_set), \
                    ScaleDataLayouts), \
                ScaleInterpolationPolicySet), \
            datasets::BorderModes()), \
        sampling_policy_set)
201*c217d954SCole Faust 
/** Assemble the dynamic fusion dataset for quantized data types from the given shapes.
 *
 * Combines, in this order: shape, the quantization info set, a DataLayout set
 * holding only NHWC, ScaleInterpolationPolicySet and the given sampling policy
 * set. No border mode set is included.
 */
#define ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(shape, sampling_policy_set, quantization_info_set) \
    combine( \
        combine( \
            combine( \
                combine(shape, quantization_info_set), \
                framework::dataset::make("DataLayout", { DataLayout::NHWC })), \
            ScaleInterpolationPolicySet), \
        sampling_policy_set)
208*c217d954SCole Faust 
/** Assemble the dataset for quantized data types where input and output take
 * separate quantization info sets.
 *
 * Combines, in this order: shape, the input quantization info set, the output
 * quantization info set, DataLayout NHWC, InterpolationPolicy BILINEAR,
 * BorderMode REPLICATE and the given sampling policy set.
 */
#define ASSEMBLE_DIFFERENTLY_QUANTIZED_DATASET(shape, sampling_policy_set, input_quant_info_set, output_quant_info_set) \
    combine( \
        combine( \
            combine( \
                combine( \
                    combine( \
                        combine(shape, input_quant_info_set), \
                        output_quant_info_set), \
                    framework::dataset::make("DataLayout", { DataLayout::NHWC })), \
                framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::BILINEAR })), \
            framework::dataset::make("BorderMode", { BorderMode::REPLICATE })), \
        sampling_policy_set)
218*c217d954SCole Faust 
219*c217d954SCole Faust } // namespace datasets
220*c217d954SCole Faust } // namespace test
221*c217d954SCole Faust } // namespace arm_compute
222*c217d954SCole Faust #endif /* TESTS_DATASETS_SCALEVALIDATIONDATASET */
223