xref: /aosp_15_r20/external/ComputeLibrary/tests/datasets/ShapeDatasets.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2017-2023 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_SHAPE_DATASETS_H
25*c217d954SCole Faust #define ARM_COMPUTE_TEST_SHAPE_DATASETS_H
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "arm_compute/core/TensorShape.h"
28*c217d954SCole Faust #include "tests/framework/datasets/Datasets.h"
29*c217d954SCole Faust 
30*c217d954SCole Faust #include <type_traits>
31*c217d954SCole Faust 
32*c217d954SCole Faust namespace arm_compute
33*c217d954SCole Faust {
34*c217d954SCole Faust namespace test
35*c217d954SCole Faust {
36*c217d954SCole Faust namespace datasets
37*c217d954SCole Faust {
38*c217d954SCole Faust /** Parent type for all for shape datasets. */
39*c217d954SCole Faust using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>;
40*c217d954SCole Faust 
41*c217d954SCole Faust /** Data set containing tiny 1D tensor shapes. */
42*c217d954SCole Faust class Tiny1DShapes final : public ShapeDataset
43*c217d954SCole Faust {
44*c217d954SCole Faust public:
Tiny1DShapes()45*c217d954SCole Faust     Tiny1DShapes()
46*c217d954SCole Faust         : ShapeDataset("Shape",
47*c217d954SCole Faust     {
48*c217d954SCole Faust         TensorShape{ 2U },
49*c217d954SCole Faust                      TensorShape{ 3U },
50*c217d954SCole Faust     })
51*c217d954SCole Faust     {
52*c217d954SCole Faust     }
53*c217d954SCole Faust };
54*c217d954SCole Faust 
55*c217d954SCole Faust /** Data set containing small 1D tensor shapes. */
56*c217d954SCole Faust class Small1DShapes final : public ShapeDataset
57*c217d954SCole Faust {
58*c217d954SCole Faust public:
Small1DShapes()59*c217d954SCole Faust     Small1DShapes()
60*c217d954SCole Faust         : ShapeDataset("Shape",
61*c217d954SCole Faust     {
62*c217d954SCole Faust         TensorShape{ 128U },
63*c217d954SCole Faust                      TensorShape{ 256U },
64*c217d954SCole Faust                      TensorShape{ 512U },
65*c217d954SCole Faust                      TensorShape{ 1024U }
66*c217d954SCole Faust     })
67*c217d954SCole Faust     {
68*c217d954SCole Faust     }
69*c217d954SCole Faust };
70*c217d954SCole Faust 
71*c217d954SCole Faust /** Data set containing tiny 2D tensor shapes. */
72*c217d954SCole Faust class Tiny2DShapes final : public ShapeDataset
73*c217d954SCole Faust {
74*c217d954SCole Faust public:
Tiny2DShapes()75*c217d954SCole Faust     Tiny2DShapes()
76*c217d954SCole Faust         : ShapeDataset("Shape",
77*c217d954SCole Faust     {
78*c217d954SCole Faust         TensorShape{ 7U, 7U },
79*c217d954SCole Faust                      TensorShape{ 11U, 13U },
80*c217d954SCole Faust     })
81*c217d954SCole Faust     {
82*c217d954SCole Faust     }
83*c217d954SCole Faust };
84*c217d954SCole Faust /** Data set containing small 2D tensor shapes. */
85*c217d954SCole Faust class Small2DShapes final : public ShapeDataset
86*c217d954SCole Faust {
87*c217d954SCole Faust public:
Small2DShapes()88*c217d954SCole Faust     Small2DShapes()
89*c217d954SCole Faust         : ShapeDataset("Shape",
90*c217d954SCole Faust     {
91*c217d954SCole Faust         TensorShape{ 1U, 7U },
92*c217d954SCole Faust                      TensorShape{ 5U, 13U },
93*c217d954SCole Faust                      TensorShape{ 32U, 64U }
94*c217d954SCole Faust     })
95*c217d954SCole Faust     {
96*c217d954SCole Faust     }
97*c217d954SCole Faust };
98*c217d954SCole Faust 
99*c217d954SCole Faust /** Data set containing tiny 3D tensor shapes. */
100*c217d954SCole Faust class Tiny3DShapes final : public ShapeDataset
101*c217d954SCole Faust {
102*c217d954SCole Faust public:
Tiny3DShapes()103*c217d954SCole Faust     Tiny3DShapes()
104*c217d954SCole Faust         : ShapeDataset("Shape",
105*c217d954SCole Faust     {
106*c217d954SCole Faust         TensorShape{ 7U, 7U, 5U },
107*c217d954SCole Faust                      TensorShape{ 23U, 13U, 9U },
108*c217d954SCole Faust     })
109*c217d954SCole Faust     {
110*c217d954SCole Faust     }
111*c217d954SCole Faust };
112*c217d954SCole Faust 
113*c217d954SCole Faust /** Data set containing small 3D tensor shapes. */
114*c217d954SCole Faust class Small3DShapes final : public ShapeDataset
115*c217d954SCole Faust {
116*c217d954SCole Faust public:
Small3DShapes()117*c217d954SCole Faust     Small3DShapes()
118*c217d954SCole Faust         : ShapeDataset("Shape",
119*c217d954SCole Faust     {
120*c217d954SCole Faust         TensorShape{ 1U, 7U, 7U },
121*c217d954SCole Faust                      TensorShape{ 2U, 5U, 4U },
122*c217d954SCole Faust 
123*c217d954SCole Faust                      TensorShape{ 7U, 7U, 5U },
124*c217d954SCole Faust                      TensorShape{ 16U, 16U, 5U },
125*c217d954SCole Faust                      TensorShape{ 27U, 13U, 37U },
126*c217d954SCole Faust     })
127*c217d954SCole Faust     {
128*c217d954SCole Faust     }
129*c217d954SCole Faust };
130*c217d954SCole Faust 
131*c217d954SCole Faust /** Data set containing tiny 4D tensor shapes. */
132*c217d954SCole Faust class Tiny4DShapes final : public ShapeDataset
133*c217d954SCole Faust {
134*c217d954SCole Faust public:
Tiny4DShapes()135*c217d954SCole Faust     Tiny4DShapes()
136*c217d954SCole Faust         : ShapeDataset("Shape",
137*c217d954SCole Faust     {
138*c217d954SCole Faust         TensorShape{ 2U, 7U, 5U, 3U },
139*c217d954SCole Faust                      TensorShape{ 17U, 13U, 7U, 2U },
140*c217d954SCole Faust     })
141*c217d954SCole Faust     {
142*c217d954SCole Faust     }
143*c217d954SCole Faust };
144*c217d954SCole Faust /** Data set containing small 4D tensor shapes. */
145*c217d954SCole Faust class Small4DShapes final : public ShapeDataset
146*c217d954SCole Faust {
147*c217d954SCole Faust public:
Small4DShapes()148*c217d954SCole Faust     Small4DShapes()
149*c217d954SCole Faust         : ShapeDataset("Shape",
150*c217d954SCole Faust     {
151*c217d954SCole Faust         TensorShape{ 2U, 7U, 1U, 3U },
152*c217d954SCole Faust                      TensorShape{ 7U, 7U, 5U, 3U },
153*c217d954SCole Faust                      TensorShape{ 27U, 13U, 37U, 2U },
154*c217d954SCole Faust                      TensorShape{ 128U, 64U, 21U, 3U }
155*c217d954SCole Faust     })
156*c217d954SCole Faust     {
157*c217d954SCole Faust     }
158*c217d954SCole Faust };
159*c217d954SCole Faust 
160*c217d954SCole Faust /** Data set containing tiny tensor shapes. */
161*c217d954SCole Faust class TinyShapes final : public ShapeDataset
162*c217d954SCole Faust {
163*c217d954SCole Faust public:
TinyShapes()164*c217d954SCole Faust     TinyShapes()
165*c217d954SCole Faust         : ShapeDataset("Shape",
166*c217d954SCole Faust     {
167*c217d954SCole Faust         // Batch size 1
168*c217d954SCole Faust         TensorShape{ 1U, 9U },
169*c217d954SCole Faust                      TensorShape{ 27U, 13U, 2U },
170*c217d954SCole Faust     })
171*c217d954SCole Faust     {
172*c217d954SCole Faust     }
173*c217d954SCole Faust };
174*c217d954SCole Faust /** Data set containing small tensor shapes with none of the dimensions equal to 1 (unit). */
175*c217d954SCole Faust class SmallNoneUnitShapes final : public ShapeDataset
176*c217d954SCole Faust {
177*c217d954SCole Faust public:
SmallNoneUnitShapes()178*c217d954SCole Faust     SmallNoneUnitShapes()
179*c217d954SCole Faust         : ShapeDataset("Shape",
180*c217d954SCole Faust     {
181*c217d954SCole Faust         // Batch size 1
182*c217d954SCole Faust         TensorShape{ 13U, 11U },
183*c217d954SCole Faust                      TensorShape{ 16U, 16U },
184*c217d954SCole Faust                      TensorShape{ 24U, 26U, 5U },
185*c217d954SCole Faust                      TensorShape{ 7U, 7U, 17U, 2U },
186*c217d954SCole Faust                      // Batch size 4
187*c217d954SCole Faust                      TensorShape{ 27U, 13U, 2U, 4U },
188*c217d954SCole Faust                      // Arbitrary batch size
189*c217d954SCole Faust                      TensorShape{ 8U, 7U, 5U, 5U }
190*c217d954SCole Faust     })
191*c217d954SCole Faust     {
192*c217d954SCole Faust     }
193*c217d954SCole Faust };
194*c217d954SCole Faust /** Data set containing small tensor shapes. */
195*c217d954SCole Faust class SmallShapes final : public ShapeDataset
196*c217d954SCole Faust {
197*c217d954SCole Faust public:
SmallShapes()198*c217d954SCole Faust     SmallShapes()
199*c217d954SCole Faust         : ShapeDataset("Shape",
200*c217d954SCole Faust     {
201*c217d954SCole Faust         // Batch size 1
202*c217d954SCole Faust         TensorShape{ 3U, 11U },
203*c217d954SCole Faust                      TensorShape{ 1U, 16U },
204*c217d954SCole Faust                      TensorShape{ 27U, 13U, 7U },
205*c217d954SCole Faust                      TensorShape{ 7U, 7U, 17U, 2U },
206*c217d954SCole Faust                      // Batch size 4 and 2 SIMD iterations
207*c217d954SCole Faust                      TensorShape{ 33U, 13U, 2U, 4U },
208*c217d954SCole Faust                      // Arbitrary batch size
209*c217d954SCole Faust                      TensorShape{ 11U, 11U, 3U, 5U }
210*c217d954SCole Faust     })
211*c217d954SCole Faust     {
212*c217d954SCole Faust     }
213*c217d954SCole Faust };
214*c217d954SCole Faust 
215*c217d954SCole Faust /** Data set containing small tensor shapes. */
216*c217d954SCole Faust class SmallShapesNoBatches final : public ShapeDataset
217*c217d954SCole Faust {
218*c217d954SCole Faust public:
SmallShapesNoBatches()219*c217d954SCole Faust     SmallShapesNoBatches()
220*c217d954SCole Faust         : ShapeDataset("Shape",
221*c217d954SCole Faust     {
222*c217d954SCole Faust         // Batch size 1
223*c217d954SCole Faust         TensorShape{ 3U, 11U },
224*c217d954SCole Faust                      TensorShape{ 1U, 16U },
225*c217d954SCole Faust                      TensorShape{ 27U, 13U, 7U },
226*c217d954SCole Faust                      TensorShape{ 7U, 7U, 17U },
227*c217d954SCole Faust                      TensorShape{ 33U, 13U, 2U },
228*c217d954SCole Faust                      TensorShape{ 11U, 11U, 3U }
229*c217d954SCole Faust     })
230*c217d954SCole Faust     {
231*c217d954SCole Faust     }
232*c217d954SCole Faust };
233*c217d954SCole Faust 
234*c217d954SCole Faust /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */
235*c217d954SCole Faust class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
236*c217d954SCole Faust {
237*c217d954SCole Faust public:
TinyShapesBroadcast()238*c217d954SCole Faust     TinyShapesBroadcast()
239*c217d954SCole Faust         : ZipDataset<ShapeDataset, ShapeDataset>(
240*c217d954SCole Faust               ShapeDataset("Shape0",
241*c217d954SCole Faust     {
242*c217d954SCole Faust         TensorShape{ 9U, 9U },
243*c217d954SCole Faust                      TensorShape{ 10U, 2U, 14U, 2U },
244*c217d954SCole Faust     }),
245*c217d954SCole Faust     ShapeDataset("Shape1",
246*c217d954SCole Faust     {
247*c217d954SCole Faust         TensorShape{ 9U, 1U, 9U },
248*c217d954SCole Faust         TensorShape{ 10U },
249*c217d954SCole Faust     }))
250*c217d954SCole Faust     {
251*c217d954SCole Faust     }
252*c217d954SCole Faust };
253*c217d954SCole Faust /** Data set containing pairs of tiny tensor shapes that are broadcast compatible and can do in_place calculation. */
254*c217d954SCole Faust class TinyShapesBroadcastInplace final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
255*c217d954SCole Faust {
256*c217d954SCole Faust public:
TinyShapesBroadcastInplace()257*c217d954SCole Faust     TinyShapesBroadcastInplace()
258*c217d954SCole Faust         : ZipDataset<ShapeDataset, ShapeDataset>(
259*c217d954SCole Faust               ShapeDataset("Shape0",
260*c217d954SCole Faust     {
261*c217d954SCole Faust         TensorShape{ 9U },
262*c217d954SCole Faust                      TensorShape{ 10U, 2U, 14U, 2U },
263*c217d954SCole Faust     }),
264*c217d954SCole Faust     ShapeDataset("Shape1",
265*c217d954SCole Faust     {
266*c217d954SCole Faust         TensorShape{ 9U, 1U, 9U },
267*c217d954SCole Faust         TensorShape{ 10U },
268*c217d954SCole Faust     }))
269*c217d954SCole Faust     {
270*c217d954SCole Faust     }
271*c217d954SCole Faust };
272*c217d954SCole Faust /** Data set containing pairs of small tensor shapes that are broadcast compatible. */
273*c217d954SCole Faust class SmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
274*c217d954SCole Faust {
275*c217d954SCole Faust public:
SmallShapesBroadcast()276*c217d954SCole Faust     SmallShapesBroadcast()
277*c217d954SCole Faust         : ZipDataset<ShapeDataset, ShapeDataset>(
278*c217d954SCole Faust               ShapeDataset("Shape0",
279*c217d954SCole Faust     {
280*c217d954SCole Faust         TensorShape{ 9U, 9U },
281*c217d954SCole Faust                      TensorShape{ 27U, 13U, 2U },
282*c217d954SCole Faust                      TensorShape{ 128U, 1U, 5U, 3U },
283*c217d954SCole Faust                      TensorShape{ 9U, 9U, 3U, 4U },
284*c217d954SCole Faust                      TensorShape{ 27U, 13U, 2U, 4U },
285*c217d954SCole Faust                      TensorShape{ 1U, 1U, 1U, 5U },
286*c217d954SCole Faust                      TensorShape{ 1U, 16U, 10U, 2U, 128U },
287*c217d954SCole Faust                      TensorShape{ 1U, 16U, 10U, 2U, 128U }
288*c217d954SCole Faust     }),
289*c217d954SCole Faust     ShapeDataset("Shape1",
290*c217d954SCole Faust     {
291*c217d954SCole Faust         TensorShape{ 9U, 1U, 2U },
292*c217d954SCole Faust         TensorShape{ 1U, 13U, 2U },
293*c217d954SCole Faust         TensorShape{ 128U, 64U, 1U, 3U },
294*c217d954SCole Faust         TensorShape{ 9U, 1U, 3U },
295*c217d954SCole Faust         TensorShape{ 1U },
296*c217d954SCole Faust         TensorShape{ 9U, 9U, 3U, 5U },
297*c217d954SCole Faust         TensorShape{ 1U, 1U, 1U, 1U, 128U },
298*c217d954SCole Faust         TensorShape{ 128U }
299*c217d954SCole Faust     }))
300*c217d954SCole Faust     {
301*c217d954SCole Faust     }
302*c217d954SCole Faust };
303*c217d954SCole Faust 
304*c217d954SCole Faust class TemporaryLimitedSmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
305*c217d954SCole Faust {
306*c217d954SCole Faust public:
TemporaryLimitedSmallShapesBroadcast()307*c217d954SCole Faust     TemporaryLimitedSmallShapesBroadcast()
308*c217d954SCole Faust         : ZipDataset<ShapeDataset, ShapeDataset>(
309*c217d954SCole Faust               ShapeDataset("Shape0",
310*c217d954SCole Faust     {
311*c217d954SCole Faust         TensorShape{ 1U, 3U, 4U, 2U },  // LHS broadcast X
312*c217d954SCole Faust         TensorShape{ 6U, 4U, 2U, 3U },  // RHS broadcast X
313*c217d954SCole Faust         TensorShape{ 7U, 1U, 1U, 4U },  // LHS broadcast Y, Z
314*c217d954SCole Faust         TensorShape{ 8U, 5U, 6U, 3U },  // RHS broadcast Y, Z
315*c217d954SCole Faust         TensorShape{ 1U, 1U, 1U, 2U },  // LHS broadcast X, Y, Z
316*c217d954SCole Faust         TensorShape{ 2U, 6U, 4U, 3U },  // RHS broadcast X, Y, Z
317*c217d954SCole Faust     }),
318*c217d954SCole Faust     ShapeDataset("Shape1",
319*c217d954SCole Faust     {
320*c217d954SCole Faust         TensorShape{ 5U, 3U, 4U, 2U },
321*c217d954SCole Faust         TensorShape{ 1U, 4U, 2U, 3U },
322*c217d954SCole Faust         TensorShape{ 7U, 2U, 3U, 4U },
323*c217d954SCole Faust         TensorShape{ 8U, 1U, 1U, 3U },
324*c217d954SCole Faust         TensorShape{ 4U, 7U, 3U, 2U },
325*c217d954SCole Faust         TensorShape{ 1U, 1U, 1U, 3U },
326*c217d954SCole Faust     }))
327*c217d954SCole Faust     {
328*c217d954SCole Faust     }
329*c217d954SCole Faust };
330*c217d954SCole Faust 
331*c217d954SCole Faust class TemporaryLimitedLargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
332*c217d954SCole Faust {
333*c217d954SCole Faust public:
TemporaryLimitedLargeShapesBroadcast()334*c217d954SCole Faust     TemporaryLimitedLargeShapesBroadcast()
335*c217d954SCole Faust         : ZipDataset<ShapeDataset, ShapeDataset>(
336*c217d954SCole Faust               ShapeDataset("Shape0",
337*c217d954SCole Faust     {
338*c217d954SCole Faust         TensorShape{ 127U, 25U, 5U },
339*c217d954SCole Faust                      TensorShape{ 485, 40U, 10U }
340*c217d954SCole Faust     }),
341*c217d954SCole Faust     ShapeDataset("Shape1",
342*c217d954SCole Faust     {
343*c217d954SCole Faust         TensorShape{ 1U, 1U, 1U },   // Broadcast in X, Y, Z
344*c217d954SCole Faust         TensorShape{ 485U, 1U, 1U }, // Broadcast in Y, Z
345*c217d954SCole Faust     }))
346*c217d954SCole Faust     {
347*c217d954SCole Faust     }
348*c217d954SCole Faust };
349*c217d954SCole Faust 
350*c217d954SCole Faust /** Data set containing medium tensor shapes. */
351*c217d954SCole Faust class MediumShapes final : public ShapeDataset
352*c217d954SCole Faust {
353*c217d954SCole Faust public:
MediumShapes()354*c217d954SCole Faust     MediumShapes()
355*c217d954SCole Faust         : ShapeDataset("Shape",
356*c217d954SCole Faust     {
357*c217d954SCole Faust         // Batch size 1
358*c217d954SCole Faust         TensorShape{ 37U, 37U },
359*c217d954SCole Faust                      TensorShape{ 27U, 33U, 2U },
360*c217d954SCole Faust                      // Arbitrary batch size
361*c217d954SCole Faust                      TensorShape{ 37U, 37U, 3U, 5U }
362*c217d954SCole Faust     })
363*c217d954SCole Faust     {
364*c217d954SCole Faust     }
365*c217d954SCole Faust };
366*c217d954SCole Faust 
367*c217d954SCole Faust /** Data set containing medium 2D tensor shapes. */
368*c217d954SCole Faust class Medium2DShapes final : public ShapeDataset
369*c217d954SCole Faust {
370*c217d954SCole Faust public:
Medium2DShapes()371*c217d954SCole Faust     Medium2DShapes()
372*c217d954SCole Faust         : ShapeDataset("Shape",
373*c217d954SCole Faust     {
374*c217d954SCole Faust         TensorShape{ 42U, 37U },
375*c217d954SCole Faust                      TensorShape{ 57U, 60U },
376*c217d954SCole Faust                      TensorShape{ 128U, 64U },
377*c217d954SCole Faust                      TensorShape{ 83U, 72U },
378*c217d954SCole Faust                      TensorShape{ 40U, 40U }
379*c217d954SCole Faust     })
380*c217d954SCole Faust     {
381*c217d954SCole Faust     }
382*c217d954SCole Faust };
383*c217d954SCole Faust 
384*c217d954SCole Faust /** Data set containing medium 3D tensor shapes. */
385*c217d954SCole Faust class Medium3DShapes final : public ShapeDataset
386*c217d954SCole Faust {
387*c217d954SCole Faust public:
Medium3DShapes()388*c217d954SCole Faust     Medium3DShapes()
389*c217d954SCole Faust         : ShapeDataset("Shape",
390*c217d954SCole Faust     {
391*c217d954SCole Faust         TensorShape{ 42U, 37U, 8U },
392*c217d954SCole Faust                      TensorShape{ 57U, 60U, 13U },
393*c217d954SCole Faust                      TensorShape{ 83U, 72U, 14U }
394*c217d954SCole Faust     })
395*c217d954SCole Faust     {
396*c217d954SCole Faust     }
397*c217d954SCole Faust };
398*c217d954SCole Faust 
399*c217d954SCole Faust /** Data set containing medium 4D tensor shapes. */
400*c217d954SCole Faust class Medium4DShapes final : public ShapeDataset
401*c217d954SCole Faust {
402*c217d954SCole Faust public:
Medium4DShapes()403*c217d954SCole Faust     Medium4DShapes()
404*c217d954SCole Faust         : ShapeDataset("Shape",
405*c217d954SCole Faust     {
406*c217d954SCole Faust         TensorShape{ 42U, 37U, 8U, 15U },
407*c217d954SCole Faust                      TensorShape{ 57U, 60U, 13U, 8U },
408*c217d954SCole Faust                      TensorShape{ 83U, 72U, 14U, 5U }
409*c217d954SCole Faust     })
410*c217d954SCole Faust     {
411*c217d954SCole Faust     }
412*c217d954SCole Faust };
413*c217d954SCole Faust 
414*c217d954SCole Faust /** Data set containing large tensor shapes. */
415*c217d954SCole Faust class LargeShapes final : public ShapeDataset
416*c217d954SCole Faust {
417*c217d954SCole Faust public:
LargeShapes()418*c217d954SCole Faust     LargeShapes()
419*c217d954SCole Faust         : ShapeDataset("Shape",
420*c217d954SCole Faust     {
421*c217d954SCole Faust         TensorShape{ 582U, 131U, 1U, 4U },
422*c217d954SCole Faust     })
423*c217d954SCole Faust     {
424*c217d954SCole Faust     }
425*c217d954SCole Faust };
426*c217d954SCole Faust 
427*c217d954SCole Faust /** Data set containing large tensor shapes. */
428*c217d954SCole Faust class LargeShapesNoBatches final : public ShapeDataset
429*c217d954SCole Faust {
430*c217d954SCole Faust public:
LargeShapesNoBatches()431*c217d954SCole Faust     LargeShapesNoBatches()
432*c217d954SCole Faust         : ShapeDataset("Shape",
433*c217d954SCole Faust     {
434*c217d954SCole Faust         TensorShape{ 582U, 131U, 2U },
435*c217d954SCole Faust     })
436*c217d954SCole Faust     {
437*c217d954SCole Faust     }
438*c217d954SCole Faust };
439*c217d954SCole Faust 
440*c217d954SCole Faust /** Data set containing pairs of large tensor shapes that are broadcast compatible. */
441*c217d954SCole Faust class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
442*c217d954SCole Faust {
443*c217d954SCole Faust public:
LargeShapesBroadcast()444*c217d954SCole Faust     LargeShapesBroadcast()
445*c217d954SCole Faust         : ZipDataset<ShapeDataset, ShapeDataset>(
446*c217d954SCole Faust               ShapeDataset("Shape0",
447*c217d954SCole Faust     {
448*c217d954SCole Faust         TensorShape{ 1921U, 541U },
449*c217d954SCole Faust                      TensorShape{ 1U, 485U, 2U, 3U },
450*c217d954SCole Faust                      TensorShape{ 4159U, 1U },
451*c217d954SCole Faust                      TensorShape{ 799U }
452*c217d954SCole Faust     }),
453*c217d954SCole Faust     ShapeDataset("Shape1",
454*c217d954SCole Faust     {
455*c217d954SCole Faust         TensorShape{ 1921U, 1U, 2U },
456*c217d954SCole Faust         TensorShape{ 641U, 1U, 2U, 3U },
457*c217d954SCole Faust         TensorShape{ 1U, 127U, 25U },
458*c217d954SCole Faust         TensorShape{ 799U, 595U, 1U, 4U }
459*c217d954SCole Faust     }))
460*c217d954SCole Faust     {
461*c217d954SCole Faust     }
462*c217d954SCole Faust };
463*c217d954SCole Faust 
464*c217d954SCole Faust /** Data set containing large 1D tensor shapes. */
465*c217d954SCole Faust class Large1DShapes final : public ShapeDataset
466*c217d954SCole Faust {
467*c217d954SCole Faust public:
Large1DShapes()468*c217d954SCole Faust     Large1DShapes()
469*c217d954SCole Faust         : ShapeDataset("Shape",
470*c217d954SCole Faust     {
471*c217d954SCole Faust         TensorShape{ 1245U }
472*c217d954SCole Faust     })
473*c217d954SCole Faust     {
474*c217d954SCole Faust     }
475*c217d954SCole Faust };
476*c217d954SCole Faust 
477*c217d954SCole Faust /** Data set containing large 2D tensor shapes. */
478*c217d954SCole Faust class Large2DShapes final : public ShapeDataset
479*c217d954SCole Faust {
480*c217d954SCole Faust public:
Large2DShapes()481*c217d954SCole Faust     Large2DShapes()
482*c217d954SCole Faust         : ShapeDataset("Shape",
483*c217d954SCole Faust     {
484*c217d954SCole Faust         TensorShape{ 1245U, 652U }
485*c217d954SCole Faust     })
486*c217d954SCole Faust     {
487*c217d954SCole Faust     }
488*c217d954SCole Faust };
489*c217d954SCole Faust 
490*c217d954SCole Faust /** Data set containing large 3D tensor shapes. */
491*c217d954SCole Faust class Large3DShapes final : public ShapeDataset
492*c217d954SCole Faust {
493*c217d954SCole Faust public:
Large3DShapes()494*c217d954SCole Faust     Large3DShapes()
495*c217d954SCole Faust         : ShapeDataset("Shape",
496*c217d954SCole Faust     {
497*c217d954SCole Faust         TensorShape{ 320U, 240U, 3U }
498*c217d954SCole Faust     })
499*c217d954SCole Faust     {
500*c217d954SCole Faust     }
501*c217d954SCole Faust };
502*c217d954SCole Faust 
503*c217d954SCole Faust /** Data set containing large 4D tensor shapes. */
504*c217d954SCole Faust class Large4DShapes final : public ShapeDataset
505*c217d954SCole Faust {
506*c217d954SCole Faust public:
Large4DShapes()507*c217d954SCole Faust     Large4DShapes()
508*c217d954SCole Faust         : ShapeDataset("Shape",
509*c217d954SCole Faust     {
510*c217d954SCole Faust         TensorShape{ 320U, 123U, 3U, 3U }
511*c217d954SCole Faust     })
512*c217d954SCole Faust     {
513*c217d954SCole Faust     }
514*c217d954SCole Faust };
515*c217d954SCole Faust 
516*c217d954SCole Faust /** Data set containing small 3x3 tensor shapes. */
517*c217d954SCole Faust class Small3x3Shapes final : public ShapeDataset
518*c217d954SCole Faust {
519*c217d954SCole Faust public:
Small3x3Shapes()520*c217d954SCole Faust     Small3x3Shapes()
521*c217d954SCole Faust         : ShapeDataset("Shape",
522*c217d954SCole Faust     {
523*c217d954SCole Faust         TensorShape{ 3U, 3U, 7U, 4U },
524*c217d954SCole Faust                      TensorShape{ 3U, 3U, 4U, 13U },
525*c217d954SCole Faust                      TensorShape{ 3U, 3U, 3U, 5U },
526*c217d954SCole Faust     })
527*c217d954SCole Faust     {
528*c217d954SCole Faust     }
529*c217d954SCole Faust };
530*c217d954SCole Faust 
531*c217d954SCole Faust /** Data set containing small 3x1 tensor shapes. */
532*c217d954SCole Faust class Small3x1Shapes final : public ShapeDataset
533*c217d954SCole Faust {
534*c217d954SCole Faust public:
Small3x1Shapes()535*c217d954SCole Faust     Small3x1Shapes()
536*c217d954SCole Faust         : ShapeDataset("Shape",
537*c217d954SCole Faust     {
538*c217d954SCole Faust         TensorShape{ 3U, 1U, 7U, 4U },
539*c217d954SCole Faust                      TensorShape{ 3U, 1U, 4U, 13U },
540*c217d954SCole Faust                      TensorShape{ 3U, 1U, 3U, 5U },
541*c217d954SCole Faust     })
542*c217d954SCole Faust     {
543*c217d954SCole Faust     }
544*c217d954SCole Faust };
545*c217d954SCole Faust 
546*c217d954SCole Faust /** Data set containing small 1x3 tensor shapes. */
547*c217d954SCole Faust class Small1x3Shapes final : public ShapeDataset
548*c217d954SCole Faust {
549*c217d954SCole Faust public:
Small1x3Shapes()550*c217d954SCole Faust     Small1x3Shapes()
551*c217d954SCole Faust         : ShapeDataset("Shape",
552*c217d954SCole Faust     {
553*c217d954SCole Faust         TensorShape{ 1U, 3U, 7U, 4U },
554*c217d954SCole Faust                      TensorShape{ 1U, 3U, 4U, 13U },
555*c217d954SCole Faust                      TensorShape{ 1U, 3U, 3U, 5U },
556*c217d954SCole Faust     })
557*c217d954SCole Faust     {
558*c217d954SCole Faust     }
559*c217d954SCole Faust };
560*c217d954SCole Faust 
561*c217d954SCole Faust /** Data set containing large 3x3 tensor shapes. */
562*c217d954SCole Faust class Large3x3Shapes final : public ShapeDataset
563*c217d954SCole Faust {
564*c217d954SCole Faust public:
Large3x3Shapes()565*c217d954SCole Faust     Large3x3Shapes()
566*c217d954SCole Faust         : ShapeDataset("Shape",
567*c217d954SCole Faust     {
568*c217d954SCole Faust         TensorShape{ 3U, 3U, 32U, 64U },
569*c217d954SCole Faust                      TensorShape{ 3U, 3U, 51U, 13U },
570*c217d954SCole Faust                      TensorShape{ 3U, 3U, 53U, 47U },
571*c217d954SCole Faust     })
572*c217d954SCole Faust     {
573*c217d954SCole Faust     }
574*c217d954SCole Faust };
575*c217d954SCole Faust 
576*c217d954SCole Faust /** Data set containing large 3x1 tensor shapes. */
577*c217d954SCole Faust class Large3x1Shapes final : public ShapeDataset
578*c217d954SCole Faust {
579*c217d954SCole Faust public:
Large3x1Shapes()580*c217d954SCole Faust     Large3x1Shapes()
581*c217d954SCole Faust         : ShapeDataset("Shape",
582*c217d954SCole Faust     {
583*c217d954SCole Faust         TensorShape{ 3U, 1U, 32U, 64U },
584*c217d954SCole Faust                      TensorShape{ 3U, 1U, 51U, 13U },
585*c217d954SCole Faust                      TensorShape{ 3U, 1U, 53U, 47U },
586*c217d954SCole Faust     })
587*c217d954SCole Faust     {
588*c217d954SCole Faust     }
589*c217d954SCole Faust };
590*c217d954SCole Faust 
591*c217d954SCole Faust /** Data set containing large 1x3 tensor shapes. */
592*c217d954SCole Faust class Large1x3Shapes final : public ShapeDataset
593*c217d954SCole Faust {
594*c217d954SCole Faust public:
Large1x3Shapes()595*c217d954SCole Faust     Large1x3Shapes()
596*c217d954SCole Faust         : ShapeDataset("Shape",
597*c217d954SCole Faust     {
598*c217d954SCole Faust         TensorShape{ 1U, 3U, 32U, 64U },
599*c217d954SCole Faust                      TensorShape{ 1U, 3U, 51U, 13U },
600*c217d954SCole Faust                      TensorShape{ 1U, 3U, 53U, 47U },
601*c217d954SCole Faust     })
602*c217d954SCole Faust     {
603*c217d954SCole Faust     }
604*c217d954SCole Faust };
605*c217d954SCole Faust 
606*c217d954SCole Faust /** Data set containing small 5x5 tensor shapes. */
607*c217d954SCole Faust class Small5x5Shapes final : public ShapeDataset
608*c217d954SCole Faust {
609*c217d954SCole Faust public:
Small5x5Shapes()610*c217d954SCole Faust     Small5x5Shapes()
611*c217d954SCole Faust         : ShapeDataset("Shape",
612*c217d954SCole Faust     {
613*c217d954SCole Faust         TensorShape{ 5U, 5U, 7U, 4U },
614*c217d954SCole Faust                      TensorShape{ 5U, 5U, 4U, 13U },
615*c217d954SCole Faust                      TensorShape{ 5U, 5U, 3U, 5U },
616*c217d954SCole Faust     })
617*c217d954SCole Faust     {
618*c217d954SCole Faust     }
619*c217d954SCole Faust };
620*c217d954SCole Faust 
621*c217d954SCole Faust /** Data set containing small 5D tensor shapes. */
622*c217d954SCole Faust class Small5dShapes final : public ShapeDataset
623*c217d954SCole Faust {
624*c217d954SCole Faust public:
Small5dShapes()625*c217d954SCole Faust     Small5dShapes()
626*c217d954SCole Faust         : ShapeDataset("Shape",
627*c217d954SCole Faust     {
628*c217d954SCole Faust         TensorShape{ 5U, 5U, 7U, 4U, 3U },
629*c217d954SCole Faust                      TensorShape{ 5U, 5U, 4U, 13U, 2U },
630*c217d954SCole Faust                      TensorShape{ 5U, 5U, 3U, 5U, 2U },
631*c217d954SCole Faust     })
632*c217d954SCole Faust     {
633*c217d954SCole Faust     }
634*c217d954SCole Faust };
635*c217d954SCole Faust 
636*c217d954SCole Faust /** Data set containing large 5x5 tensor shapes. */
637*c217d954SCole Faust class Large5x5Shapes final : public ShapeDataset
638*c217d954SCole Faust {
639*c217d954SCole Faust public:
Large5x5Shapes()640*c217d954SCole Faust     Large5x5Shapes()
641*c217d954SCole Faust         : ShapeDataset("Shape",
642*c217d954SCole Faust     {
643*c217d954SCole Faust         TensorShape{ 5U, 5U, 32U, 64U }
644*c217d954SCole Faust     })
645*c217d954SCole Faust     {
646*c217d954SCole Faust     }
647*c217d954SCole Faust };
648*c217d954SCole Faust 
649*c217d954SCole Faust /** Data set containing large 5D tensor shapes. */
650*c217d954SCole Faust class Large5dShapes final : public ShapeDataset
651*c217d954SCole Faust {
652*c217d954SCole Faust public:
Large5dShapes()653*c217d954SCole Faust     Large5dShapes()
654*c217d954SCole Faust         : ShapeDataset("Shape",
655*c217d954SCole Faust     {
656*c217d954SCole Faust         TensorShape{ 30U, 40U, 30U, 32U, 3U }
657*c217d954SCole Faust     })
658*c217d954SCole Faust     {
659*c217d954SCole Faust     }
660*c217d954SCole Faust };
661*c217d954SCole Faust 
662*c217d954SCole Faust /** Data set containing small 5x1 tensor shapes. */
663*c217d954SCole Faust class Small5x1Shapes final : public ShapeDataset
664*c217d954SCole Faust {
665*c217d954SCole Faust public:
Small5x1Shapes()666*c217d954SCole Faust     Small5x1Shapes()
667*c217d954SCole Faust         : ShapeDataset("Shape",
668*c217d954SCole Faust     {
669*c217d954SCole Faust         TensorShape{ 5U, 1U, 7U, 4U }
670*c217d954SCole Faust     })
671*c217d954SCole Faust     {
672*c217d954SCole Faust     }
673*c217d954SCole Faust };
674*c217d954SCole Faust 
675*c217d954SCole Faust /** Data set containing large 5x1 tensor shapes. */
676*c217d954SCole Faust class Large5x1Shapes final : public ShapeDataset
677*c217d954SCole Faust {
678*c217d954SCole Faust public:
Large5x1Shapes()679*c217d954SCole Faust     Large5x1Shapes()
680*c217d954SCole Faust         : ShapeDataset("Shape",
681*c217d954SCole Faust     {
682*c217d954SCole Faust         TensorShape{ 5U, 1U, 32U, 64U }
683*c217d954SCole Faust     })
684*c217d954SCole Faust     {
685*c217d954SCole Faust     }
686*c217d954SCole Faust };
687*c217d954SCole Faust 
688*c217d954SCole Faust /** Data set containing small 1x5 tensor shapes. */
689*c217d954SCole Faust class Small1x5Shapes final : public ShapeDataset
690*c217d954SCole Faust {
691*c217d954SCole Faust public:
Small1x5Shapes()692*c217d954SCole Faust     Small1x5Shapes()
693*c217d954SCole Faust         : ShapeDataset("Shape",
694*c217d954SCole Faust     {
695*c217d954SCole Faust         TensorShape{ 1U, 5U, 7U, 4U }
696*c217d954SCole Faust     })
697*c217d954SCole Faust     {
698*c217d954SCole Faust     }
699*c217d954SCole Faust };
700*c217d954SCole Faust 
701*c217d954SCole Faust /** Data set containing large 1x5 tensor shapes. */
702*c217d954SCole Faust class Large1x5Shapes final : public ShapeDataset
703*c217d954SCole Faust {
704*c217d954SCole Faust public:
Large1x5Shapes()705*c217d954SCole Faust     Large1x5Shapes()
706*c217d954SCole Faust         : ShapeDataset("Shape",
707*c217d954SCole Faust     {
708*c217d954SCole Faust         TensorShape{ 1U, 5U, 32U, 64U }
709*c217d954SCole Faust     })
710*c217d954SCole Faust     {
711*c217d954SCole Faust     }
712*c217d954SCole Faust };
713*c217d954SCole Faust 
714*c217d954SCole Faust /** Data set containing small 1x7 tensor shapes. */
715*c217d954SCole Faust class Small1x7Shapes final : public ShapeDataset
716*c217d954SCole Faust {
717*c217d954SCole Faust public:
Small1x7Shapes()718*c217d954SCole Faust     Small1x7Shapes()
719*c217d954SCole Faust         : ShapeDataset("Shape",
720*c217d954SCole Faust     {
721*c217d954SCole Faust         TensorShape{ 1U, 7U, 7U, 4U }
722*c217d954SCole Faust     })
723*c217d954SCole Faust     {
724*c217d954SCole Faust     }
725*c217d954SCole Faust };
726*c217d954SCole Faust 
727*c217d954SCole Faust /** Data set containing large 1x7 tensor shapes. */
728*c217d954SCole Faust class Large1x7Shapes final : public ShapeDataset
729*c217d954SCole Faust {
730*c217d954SCole Faust public:
Large1x7Shapes()731*c217d954SCole Faust     Large1x7Shapes()
732*c217d954SCole Faust         : ShapeDataset("Shape",
733*c217d954SCole Faust     {
734*c217d954SCole Faust         TensorShape{ 1U, 7U, 32U, 64U }
735*c217d954SCole Faust     })
736*c217d954SCole Faust     {
737*c217d954SCole Faust     }
738*c217d954SCole Faust };
739*c217d954SCole Faust 
740*c217d954SCole Faust /** Data set containing small 7x7 tensor shapes. */
741*c217d954SCole Faust class Small7x7Shapes final : public ShapeDataset
742*c217d954SCole Faust {
743*c217d954SCole Faust public:
Small7x7Shapes()744*c217d954SCole Faust     Small7x7Shapes()
745*c217d954SCole Faust         : ShapeDataset("Shape",
746*c217d954SCole Faust     {
747*c217d954SCole Faust         TensorShape{ 7U, 7U, 7U, 4U }
748*c217d954SCole Faust     })
749*c217d954SCole Faust     {
750*c217d954SCole Faust     }
751*c217d954SCole Faust };
752*c217d954SCole Faust 
753*c217d954SCole Faust /** Data set containing large 7x7 tensor shapes. */
754*c217d954SCole Faust class Large7x7Shapes final : public ShapeDataset
755*c217d954SCole Faust {
756*c217d954SCole Faust public:
Large7x7Shapes()757*c217d954SCole Faust     Large7x7Shapes()
758*c217d954SCole Faust         : ShapeDataset("Shape",
759*c217d954SCole Faust     {
760*c217d954SCole Faust         TensorShape{ 7U, 7U, 32U, 64U }
761*c217d954SCole Faust     })
762*c217d954SCole Faust     {
763*c217d954SCole Faust     }
764*c217d954SCole Faust };
765*c217d954SCole Faust 
766*c217d954SCole Faust /** Data set containing small 7x1 tensor shapes. */
767*c217d954SCole Faust class Small7x1Shapes final : public ShapeDataset
768*c217d954SCole Faust {
769*c217d954SCole Faust public:
Small7x1Shapes()770*c217d954SCole Faust     Small7x1Shapes()
771*c217d954SCole Faust         : ShapeDataset("Shape",
772*c217d954SCole Faust     {
773*c217d954SCole Faust         TensorShape{ 7U, 1U, 7U, 4U }
774*c217d954SCole Faust     })
775*c217d954SCole Faust     {
776*c217d954SCole Faust     }
777*c217d954SCole Faust };
778*c217d954SCole Faust 
779*c217d954SCole Faust /** Data set containing large 7x1 tensor shapes. */
780*c217d954SCole Faust class Large7x1Shapes final : public ShapeDataset
781*c217d954SCole Faust {
782*c217d954SCole Faust public:
Large7x1Shapes()783*c217d954SCole Faust     Large7x1Shapes()
784*c217d954SCole Faust         : ShapeDataset("Shape",
785*c217d954SCole Faust     {
786*c217d954SCole Faust         TensorShape{ 7U, 1U, 32U, 64U }
787*c217d954SCole Faust     })
788*c217d954SCole Faust     {
789*c217d954SCole Faust     }
790*c217d954SCole Faust };
791*c217d954SCole Faust 
792*c217d954SCole Faust /** Data set containing small tensor shapes for deconvolution. */
793*c217d954SCole Faust class SmallDeconvolutionShapes final : public ShapeDataset
794*c217d954SCole Faust {
795*c217d954SCole Faust public:
SmallDeconvolutionShapes()796*c217d954SCole Faust     SmallDeconvolutionShapes()
797*c217d954SCole Faust         : ShapeDataset("InputShape",
798*c217d954SCole Faust     {
799*c217d954SCole Faust         // Multiple Vector Loops for FP32
800*c217d954SCole Faust         TensorShape{ 5U, 4U, 3U, 2U },
801*c217d954SCole Faust                      TensorShape{ 5U, 5U, 3U },
802*c217d954SCole Faust                      TensorShape{ 11U, 13U, 4U, 3U }
803*c217d954SCole Faust     })
804*c217d954SCole Faust     {
805*c217d954SCole Faust     }
806*c217d954SCole Faust };
807*c217d954SCole Faust 
808*c217d954SCole Faust class SmallDeconvolutionShapesWithLargerChannels final : public ShapeDataset
809*c217d954SCole Faust {
810*c217d954SCole Faust public:
SmallDeconvolutionShapesWithLargerChannels()811*c217d954SCole Faust     SmallDeconvolutionShapesWithLargerChannels()
812*c217d954SCole Faust         : ShapeDataset("InputShape",
813*c217d954SCole Faust     {
814*c217d954SCole Faust         // Multiple Vector Loops for all data types
815*c217d954SCole Faust         TensorShape{ 5U, 5U, 35U }
816*c217d954SCole Faust     })
817*c217d954SCole Faust     {
818*c217d954SCole Faust     }
819*c217d954SCole Faust };
820*c217d954SCole Faust 
821*c217d954SCole Faust /** Data set containing tiny tensor shapes for direct convolution. */
822*c217d954SCole Faust class TinyDirectConvolutionShapes final : public ShapeDataset
823*c217d954SCole Faust {
824*c217d954SCole Faust public:
TinyDirectConvolutionShapes()825*c217d954SCole Faust     TinyDirectConvolutionShapes()
826*c217d954SCole Faust         : ShapeDataset("InputShape",
827*c217d954SCole Faust     {
828*c217d954SCole Faust         // Batch size 1
829*c217d954SCole Faust         TensorShape{ 11U, 13U, 3U },
830*c217d954SCole Faust                      TensorShape{ 7U, 27U, 3U }
831*c217d954SCole Faust     })
832*c217d954SCole Faust     {
833*c217d954SCole Faust     }
834*c217d954SCole Faust };
835*c217d954SCole Faust /** Data set containing small tensor shapes for direct convolution. */
836*c217d954SCole Faust class SmallDirectConvolutionShapes final : public ShapeDataset
837*c217d954SCole Faust {
838*c217d954SCole Faust public:
SmallDirectConvolutionShapes()839*c217d954SCole Faust     SmallDirectConvolutionShapes()
840*c217d954SCole Faust         : ShapeDataset("InputShape",
841*c217d954SCole Faust     {
842*c217d954SCole Faust         // Batch size 1
843*c217d954SCole Faust         TensorShape{ 32U, 37U, 3U },
844*c217d954SCole Faust                      // Batch size 4
845*c217d954SCole Faust                      TensorShape{ 6U, 9U, 5U, 4U },
846*c217d954SCole Faust     })
847*c217d954SCole Faust     {
848*c217d954SCole Faust     }
849*c217d954SCole Faust };
850*c217d954SCole Faust 
851*c217d954SCole Faust class SmallDirectConv3DShapes final : public ShapeDataset
852*c217d954SCole Faust {
853*c217d954SCole Faust public:
SmallDirectConv3DShapes()854*c217d954SCole Faust     SmallDirectConv3DShapes()
855*c217d954SCole Faust         : ShapeDataset("InputShape",
856*c217d954SCole Faust     {
857*c217d954SCole Faust         // Batch size 2
858*c217d954SCole Faust         TensorShape{ 1U, 3U, 4U, 5U, 2U },
859*c217d954SCole Faust                      // Batch size 3
860*c217d954SCole Faust                      TensorShape{ 7U, 27U, 3U, 6U, 3U },
861*c217d954SCole Faust                      // Batch size 1
862*c217d954SCole Faust                      TensorShape{ 32U, 37U, 13U, 1U, 1U },
863*c217d954SCole Faust     })
864*c217d954SCole Faust     {
865*c217d954SCole Faust     }
866*c217d954SCole Faust };
867*c217d954SCole Faust 
868*c217d954SCole Faust /** Data set containing small tensor shapes for direct convolution. */
869*c217d954SCole Faust class SmallDirectConvolutionTensorShiftShapes final : public ShapeDataset
870*c217d954SCole Faust {
871*c217d954SCole Faust public:
SmallDirectConvolutionTensorShiftShapes()872*c217d954SCole Faust     SmallDirectConvolutionTensorShiftShapes()
873*c217d954SCole Faust         : ShapeDataset("InputShape",
874*c217d954SCole Faust     {
875*c217d954SCole Faust         // Batch size 1
876*c217d954SCole Faust         TensorShape{ 32U, 37U, 3U },
877*c217d954SCole Faust                      // Batch size 4
878*c217d954SCole Faust                      TensorShape{ 32U, 37U, 3U, 4U },
879*c217d954SCole Faust                      // Arbitrary batch size
880*c217d954SCole Faust                      TensorShape{ 32U, 37U, 3U, 8U }
881*c217d954SCole Faust     })
882*c217d954SCole Faust     {
883*c217d954SCole Faust     }
884*c217d954SCole Faust };
885*c217d954SCole Faust 
886*c217d954SCole Faust /** Data set containing small grouped im2col tensor shapes. */
887*c217d954SCole Faust class GroupedIm2ColSmallShapes final : public ShapeDataset
888*c217d954SCole Faust {
889*c217d954SCole Faust public:
GroupedIm2ColSmallShapes()890*c217d954SCole Faust     GroupedIm2ColSmallShapes()
891*c217d954SCole Faust         : ShapeDataset("Shape",
892*c217d954SCole Faust     {
893*c217d954SCole Faust         TensorShape{ 11U, 11U, 48U },
894*c217d954SCole Faust                      TensorShape{ 27U, 13U, 24U },
895*c217d954SCole Faust                      TensorShape{ 128U, 64U, 12U, 3U },
896*c217d954SCole Faust                      TensorShape{ 11U, 11U, 48U, 4U },
897*c217d954SCole Faust                      TensorShape{ 27U, 13U, 24U, 4U },
898*c217d954SCole Faust                      TensorShape{ 11U, 11U, 48U, 5U }
899*c217d954SCole Faust     })
900*c217d954SCole Faust     {
901*c217d954SCole Faust     }
902*c217d954SCole Faust };
903*c217d954SCole Faust 
904*c217d954SCole Faust /** Data set containing large grouped im2col tensor shapes. */
905*c217d954SCole Faust class GroupedIm2ColLargeShapes final : public ShapeDataset
906*c217d954SCole Faust {
907*c217d954SCole Faust public:
GroupedIm2ColLargeShapes()908*c217d954SCole Faust     GroupedIm2ColLargeShapes()
909*c217d954SCole Faust         : ShapeDataset("Shape",
910*c217d954SCole Faust     {
911*c217d954SCole Faust         TensorShape{ 153U, 231U, 12U },
912*c217d954SCole Faust                      TensorShape{ 123U, 191U, 12U, 2U },
913*c217d954SCole Faust     })
914*c217d954SCole Faust     {
915*c217d954SCole Faust     }
916*c217d954SCole Faust };
917*c217d954SCole Faust 
918*c217d954SCole Faust /** Data set containing small grouped weights tensor shapes. */
919*c217d954SCole Faust class GroupedWeightsSmallShapes final : public ShapeDataset
920*c217d954SCole Faust {
921*c217d954SCole Faust public:
GroupedWeightsSmallShapes()922*c217d954SCole Faust     GroupedWeightsSmallShapes()
923*c217d954SCole Faust         : ShapeDataset("Shape",
924*c217d954SCole Faust     {
925*c217d954SCole Faust         TensorShape{ 3U, 3U, 48U, 120U },
926*c217d954SCole Faust                      TensorShape{ 1U, 3U, 24U, 240U },
927*c217d954SCole Faust                      TensorShape{ 3U, 1U, 12U, 480U },
928*c217d954SCole Faust                      TensorShape{ 5U, 5U, 48U, 120U }
929*c217d954SCole Faust     })
930*c217d954SCole Faust     {
931*c217d954SCole Faust     }
932*c217d954SCole Faust };
933*c217d954SCole Faust 
934*c217d954SCole Faust /** Data set containing large grouped weights tensor shapes. */
935*c217d954SCole Faust class GroupedWeightsLargeShapes final : public ShapeDataset
936*c217d954SCole Faust {
937*c217d954SCole Faust public:
GroupedWeightsLargeShapes()938*c217d954SCole Faust     GroupedWeightsLargeShapes()
939*c217d954SCole Faust         : ShapeDataset("Shape",
940*c217d954SCole Faust     {
941*c217d954SCole Faust         TensorShape{ 9U, 9U, 96U, 240U },
942*c217d954SCole Faust                      TensorShape{ 13U, 13U, 96U, 240U }
943*c217d954SCole Faust     })
944*c217d954SCole Faust     {
945*c217d954SCole Faust     }
946*c217d954SCole Faust };
947*c217d954SCole Faust 
948*c217d954SCole Faust /** Data set containing 2D tensor shapes for DepthConcatenateLayer. */
949*c217d954SCole Faust class DepthConcatenateLayerShapes final : public ShapeDataset
950*c217d954SCole Faust {
951*c217d954SCole Faust public:
DepthConcatenateLayerShapes()952*c217d954SCole Faust     DepthConcatenateLayerShapes()
953*c217d954SCole Faust         : ShapeDataset("Shape",
954*c217d954SCole Faust     {
955*c217d954SCole Faust         TensorShape{ 322U, 243U },
956*c217d954SCole Faust                      TensorShape{ 463U, 879U },
957*c217d954SCole Faust                      TensorShape{ 416U, 651U }
958*c217d954SCole Faust     })
959*c217d954SCole Faust     {
960*c217d954SCole Faust     }
961*c217d954SCole Faust };
962*c217d954SCole Faust 
963*c217d954SCole Faust /** Data set containing tensor shapes for ConcatenateLayer. */
964*c217d954SCole Faust class ConcatenateLayerShapes final : public ShapeDataset
965*c217d954SCole Faust {
966*c217d954SCole Faust public:
ConcatenateLayerShapes()967*c217d954SCole Faust     ConcatenateLayerShapes()
968*c217d954SCole Faust         : ShapeDataset("Shape",
969*c217d954SCole Faust     {
970*c217d954SCole Faust         TensorShape{ 232U, 65U, 3U },
971*c217d954SCole Faust                      TensorShape{ 432U, 65U, 3U },
972*c217d954SCole Faust                      TensorShape{ 124U, 65U, 3U },
973*c217d954SCole Faust                      TensorShape{ 124U, 65U, 3U, 4U }
974*c217d954SCole Faust     })
975*c217d954SCole Faust     {
976*c217d954SCole Faust     }
977*c217d954SCole Faust };
978*c217d954SCole Faust 
979*c217d954SCole Faust /** Data set containing global pooling tensor shapes. */
980*c217d954SCole Faust class GlobalPoolingShapes final : public ShapeDataset
981*c217d954SCole Faust {
982*c217d954SCole Faust public:
GlobalPoolingShapes()983*c217d954SCole Faust     GlobalPoolingShapes()
984*c217d954SCole Faust         : ShapeDataset("Shape",
985*c217d954SCole Faust     {
986*c217d954SCole Faust         // Batch size 1
987*c217d954SCole Faust         TensorShape{ 9U, 9U },
988*c217d954SCole Faust                      TensorShape{ 13U, 13U, 2U },
989*c217d954SCole Faust                      TensorShape{ 27U, 27U, 1U, 3U },
990*c217d954SCole Faust                      // Batch size 4
991*c217d954SCole Faust                      TensorShape{ 31U, 31U, 3U, 4U },
992*c217d954SCole Faust                      TensorShape{ 34U, 34U, 2U, 4U }
993*c217d954SCole Faust     })
994*c217d954SCole Faust     {
995*c217d954SCole Faust     }
996*c217d954SCole Faust };
997*c217d954SCole Faust /** Data set containing tiny softmax layer shapes. */
998*c217d954SCole Faust class SoftmaxLayerTinyShapes final : public ShapeDataset
999*c217d954SCole Faust {
1000*c217d954SCole Faust public:
SoftmaxLayerTinyShapes()1001*c217d954SCole Faust     SoftmaxLayerTinyShapes()
1002*c217d954SCole Faust         : ShapeDataset("Shape",
1003*c217d954SCole Faust     {
1004*c217d954SCole Faust         TensorShape{ 9U, 9U },
1005*c217d954SCole Faust                      TensorShape{ 128U, 10U },
1006*c217d954SCole Faust     })
1007*c217d954SCole Faust     {
1008*c217d954SCole Faust     }
1009*c217d954SCole Faust };
1010*c217d954SCole Faust 
1011*c217d954SCole Faust /** Data set containing small softmax layer shapes. */
1012*c217d954SCole Faust class SoftmaxLayerSmallShapes final : public ShapeDataset
1013*c217d954SCole Faust {
1014*c217d954SCole Faust public:
SoftmaxLayerSmallShapes()1015*c217d954SCole Faust     SoftmaxLayerSmallShapes()
1016*c217d954SCole Faust         : ShapeDataset("Shape",
1017*c217d954SCole Faust     {
1018*c217d954SCole Faust         TensorShape{ 1U, 9U },
1019*c217d954SCole Faust                      TensorShape{ 256U, 10U },
1020*c217d954SCole Faust                      TensorShape{ 353U, 8U },
1021*c217d954SCole Faust                      TensorShape{ 781U, 5U },
1022*c217d954SCole Faust     })
1023*c217d954SCole Faust     {
1024*c217d954SCole Faust     }
1025*c217d954SCole Faust };
1026*c217d954SCole Faust 
1027*c217d954SCole Faust /** Data set containing large softmax layer shapes. */
1028*c217d954SCole Faust class SoftmaxLayerLargeShapes final : public ShapeDataset
1029*c217d954SCole Faust {
1030*c217d954SCole Faust public:
SoftmaxLayerLargeShapes()1031*c217d954SCole Faust     SoftmaxLayerLargeShapes()
1032*c217d954SCole Faust         : ShapeDataset("Shape",
1033*c217d954SCole Faust     {
1034*c217d954SCole Faust         TensorShape{ 1000U, 10U }
1035*c217d954SCole Faust 
1036*c217d954SCole Faust     })
1037*c217d954SCole Faust     {
1038*c217d954SCole Faust     }
1039*c217d954SCole Faust };
1040*c217d954SCole Faust 
1041*c217d954SCole Faust /** Data set containing large and small softmax layer 4D shapes. */
1042*c217d954SCole Faust class SoftmaxLayer4DShapes final : public ShapeDataset
1043*c217d954SCole Faust {
1044*c217d954SCole Faust public:
SoftmaxLayer4DShapes()1045*c217d954SCole Faust     SoftmaxLayer4DShapes()
1046*c217d954SCole Faust         : ShapeDataset("Shape",
1047*c217d954SCole Faust     {
1048*c217d954SCole Faust         TensorShape{ 9U, 9U, 9U, 9U },
1049*c217d954SCole Faust                      TensorShape{ 31U, 10U, 1U, 9U },
1050*c217d954SCole Faust     })
1051*c217d954SCole Faust     {
1052*c217d954SCole Faust     }
1053*c217d954SCole Faust };
1054*c217d954SCole Faust 
1055*c217d954SCole Faust /** Data set containing 2D tensor shapes relative to an image size. */
1056*c217d954SCole Faust class SmallImageShapes final : public ShapeDataset
1057*c217d954SCole Faust {
1058*c217d954SCole Faust public:
SmallImageShapes()1059*c217d954SCole Faust     SmallImageShapes()
1060*c217d954SCole Faust         : ShapeDataset("Shape",
1061*c217d954SCole Faust     {
1062*c217d954SCole Faust         TensorShape{ 640U, 480U },
1063*c217d954SCole Faust                      TensorShape{ 800U, 600U },
1064*c217d954SCole Faust     })
1065*c217d954SCole Faust     {
1066*c217d954SCole Faust     }
1067*c217d954SCole Faust };
1068*c217d954SCole Faust 
1069*c217d954SCole Faust /** Data set containing 2D tensor shapes relative to an image size. */
1070*c217d954SCole Faust class LargeImageShapes final : public ShapeDataset
1071*c217d954SCole Faust {
1072*c217d954SCole Faust public:
LargeImageShapes()1073*c217d954SCole Faust     LargeImageShapes()
1074*c217d954SCole Faust         : ShapeDataset("Shape",
1075*c217d954SCole Faust     {
1076*c217d954SCole Faust         TensorShape{ 1920U, 1080U },
1077*c217d954SCole Faust                      TensorShape{ 2560U, 1536U },
1078*c217d954SCole Faust                      TensorShape{ 3584U, 2048U }
1079*c217d954SCole Faust     })
1080*c217d954SCole Faust     {
1081*c217d954SCole Faust     }
1082*c217d954SCole Faust };
1083*c217d954SCole Faust 
1084*c217d954SCole Faust /** Data set containing small YOLO tensor shapes. */
1085*c217d954SCole Faust class SmallYOLOShapes final : public ShapeDataset
1086*c217d954SCole Faust {
1087*c217d954SCole Faust public:
SmallYOLOShapes()1088*c217d954SCole Faust     SmallYOLOShapes()
1089*c217d954SCole Faust         : ShapeDataset("Shape",
1090*c217d954SCole Faust     {
1091*c217d954SCole Faust         // Batch size 1
1092*c217d954SCole Faust         TensorShape{ 11U, 11U, 270U },
1093*c217d954SCole Faust                      TensorShape{ 27U, 13U, 90U },
1094*c217d954SCole Faust                      TensorShape{ 13U, 12U, 45U, 2U },
1095*c217d954SCole Faust     })
1096*c217d954SCole Faust     {
1097*c217d954SCole Faust     }
1098*c217d954SCole Faust };
1099*c217d954SCole Faust 
1100*c217d954SCole Faust /** Data set containing large YOLO tensor shapes. */
1101*c217d954SCole Faust class LargeYOLOShapes final : public ShapeDataset
1102*c217d954SCole Faust {
1103*c217d954SCole Faust public:
LargeYOLOShapes()1104*c217d954SCole Faust     LargeYOLOShapes()
1105*c217d954SCole Faust         : ShapeDataset("Shape",
1106*c217d954SCole Faust     {
1107*c217d954SCole Faust         TensorShape{ 24U, 23U, 270U },
1108*c217d954SCole Faust                      TensorShape{ 51U, 63U, 90U, 2U },
1109*c217d954SCole Faust                      TensorShape{ 76U, 91U, 45U, 3U }
1110*c217d954SCole Faust     })
1111*c217d954SCole Faust     {
1112*c217d954SCole Faust     }
1113*c217d954SCole Faust };
1114*c217d954SCole Faust 
1115*c217d954SCole Faust /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel */
1116*c217d954SCole Faust class SmallGEMMReshape2DShapes final : public ShapeDataset
1117*c217d954SCole Faust {
1118*c217d954SCole Faust public:
SmallGEMMReshape2DShapes()1119*c217d954SCole Faust     SmallGEMMReshape2DShapes()
1120*c217d954SCole Faust         : ShapeDataset("Shape",
1121*c217d954SCole Faust     {
1122*c217d954SCole Faust         TensorShape{ 63U, 72U },
1123*c217d954SCole Faust     })
1124*c217d954SCole Faust     {
1125*c217d954SCole Faust     }
1126*c217d954SCole Faust };
1127*c217d954SCole Faust 
1128*c217d954SCole Faust /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */
1129*c217d954SCole Faust class SmallGEMMReshape3DShapes final : public ShapeDataset
1130*c217d954SCole Faust {
1131*c217d954SCole Faust public:
SmallGEMMReshape3DShapes()1132*c217d954SCole Faust     SmallGEMMReshape3DShapes()
1133*c217d954SCole Faust         : ShapeDataset("Shape",
1134*c217d954SCole Faust     {
1135*c217d954SCole Faust         TensorShape{ 63U, 9U, 8U },
1136*c217d954SCole Faust     })
1137*c217d954SCole Faust     {
1138*c217d954SCole Faust     }
1139*c217d954SCole Faust };
1140*c217d954SCole Faust 
1141*c217d954SCole Faust /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel */
1142*c217d954SCole Faust class LargeGEMMReshape2DShapes final : public ShapeDataset
1143*c217d954SCole Faust {
1144*c217d954SCole Faust public:
LargeGEMMReshape2DShapes()1145*c217d954SCole Faust     LargeGEMMReshape2DShapes()
1146*c217d954SCole Faust         : ShapeDataset("Shape",
1147*c217d954SCole Faust     {
1148*c217d954SCole Faust         TensorShape{ 16U, 27U },
1149*c217d954SCole Faust                      TensorShape{ 345U, 171U }
1150*c217d954SCole Faust     })
1151*c217d954SCole Faust     {
1152*c217d954SCole Faust     }
1153*c217d954SCole Faust };
1154*c217d954SCole Faust 
1155*c217d954SCole Faust /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */
1156*c217d954SCole Faust class LargeGEMMReshape3DShapes final : public ShapeDataset
1157*c217d954SCole Faust {
1158*c217d954SCole Faust public:
LargeGEMMReshape3DShapes()1159*c217d954SCole Faust     LargeGEMMReshape3DShapes()
1160*c217d954SCole Faust         : ShapeDataset("Shape",
1161*c217d954SCole Faust     {
1162*c217d954SCole Faust         TensorShape{ 16U, 3U, 9U },
1163*c217d954SCole Faust                      TensorShape{ 345U, 34U, 18U }
1164*c217d954SCole Faust     })
1165*c217d954SCole Faust     {
1166*c217d954SCole Faust     }
1167*c217d954SCole Faust };
1168*c217d954SCole Faust 
1169*c217d954SCole Faust /** Data set containing small 2D tensor shapes. */
1170*c217d954SCole Faust class Small2DNonMaxSuppressionShapes final : public ShapeDataset
1171*c217d954SCole Faust {
1172*c217d954SCole Faust public:
Small2DNonMaxSuppressionShapes()1173*c217d954SCole Faust     Small2DNonMaxSuppressionShapes()
1174*c217d954SCole Faust         : ShapeDataset("Shape",
1175*c217d954SCole Faust     {
1176*c217d954SCole Faust         TensorShape{ 4U, 7U },
1177*c217d954SCole Faust                      TensorShape{ 4U, 13U },
1178*c217d954SCole Faust                      TensorShape{ 4U, 64U }
1179*c217d954SCole Faust     })
1180*c217d954SCole Faust     {
1181*c217d954SCole Faust     }
1182*c217d954SCole Faust };
1183*c217d954SCole Faust 
1184*c217d954SCole Faust /** Data set containing large 2D tensor shapes. */
1185*c217d954SCole Faust class Large2DNonMaxSuppressionShapes final : public ShapeDataset
1186*c217d954SCole Faust {
1187*c217d954SCole Faust public:
Large2DNonMaxSuppressionShapes()1188*c217d954SCole Faust     Large2DNonMaxSuppressionShapes()
1189*c217d954SCole Faust         : ShapeDataset("Shape",
1190*c217d954SCole Faust     {
1191*c217d954SCole Faust         TensorShape{ 4U, 113U }
1192*c217d954SCole Faust     })
1193*c217d954SCole Faust     {
1194*c217d954SCole Faust     }
1195*c217d954SCole Faust };
1196*c217d954SCole Faust 
1197*c217d954SCole Faust } // namespace datasets
1198*c217d954SCole Faust } // namespace test
1199*c217d954SCole Faust } // namespace arm_compute
1200*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_SHAPE_DATASETS_H */
1201