1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2017-2023 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_SHAPE_DATASETS_H 25*c217d954SCole Faust #define ARM_COMPUTE_TEST_SHAPE_DATASETS_H 26*c217d954SCole Faust 27*c217d954SCole Faust #include "arm_compute/core/TensorShape.h" 28*c217d954SCole Faust #include "tests/framework/datasets/Datasets.h" 29*c217d954SCole Faust 30*c217d954SCole Faust #include <type_traits> 31*c217d954SCole Faust 32*c217d954SCole Faust namespace arm_compute 33*c217d954SCole Faust { 34*c217d954SCole Faust namespace test 35*c217d954SCole Faust { 36*c217d954SCole Faust namespace datasets 37*c217d954SCole Faust { 38*c217d954SCole Faust /** Parent type for all for shape datasets. */ 39*c217d954SCole Faust using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>; 40*c217d954SCole Faust 41*c217d954SCole Faust /** Data set containing tiny 1D tensor shapes. */ 42*c217d954SCole Faust class Tiny1DShapes final : public ShapeDataset 43*c217d954SCole Faust { 44*c217d954SCole Faust public: Tiny1DShapes()45*c217d954SCole Faust Tiny1DShapes() 46*c217d954SCole Faust : ShapeDataset("Shape", 47*c217d954SCole Faust { 48*c217d954SCole Faust TensorShape{ 2U }, 49*c217d954SCole Faust TensorShape{ 3U }, 50*c217d954SCole Faust }) 51*c217d954SCole Faust { 52*c217d954SCole Faust } 53*c217d954SCole Faust }; 54*c217d954SCole Faust 55*c217d954SCole Faust /** Data set containing small 1D tensor shapes. */ 56*c217d954SCole Faust class Small1DShapes final : public ShapeDataset 57*c217d954SCole Faust { 58*c217d954SCole Faust public: Small1DShapes()59*c217d954SCole Faust Small1DShapes() 60*c217d954SCole Faust : ShapeDataset("Shape", 61*c217d954SCole Faust { 62*c217d954SCole Faust TensorShape{ 128U }, 63*c217d954SCole Faust TensorShape{ 256U }, 64*c217d954SCole Faust TensorShape{ 512U }, 65*c217d954SCole Faust TensorShape{ 1024U } 66*c217d954SCole Faust }) 67*c217d954SCole Faust { 68*c217d954SCole Faust } 69*c217d954SCole Faust }; 70*c217d954SCole Faust 71*c217d954SCole Faust /** Data set containing tiny 2D tensor shapes. */ 72*c217d954SCole Faust class Tiny2DShapes final : public ShapeDataset 73*c217d954SCole Faust { 74*c217d954SCole Faust public: Tiny2DShapes()75*c217d954SCole Faust Tiny2DShapes() 76*c217d954SCole Faust : ShapeDataset("Shape", 77*c217d954SCole Faust { 78*c217d954SCole Faust TensorShape{ 7U, 7U }, 79*c217d954SCole Faust TensorShape{ 11U, 13U }, 80*c217d954SCole Faust }) 81*c217d954SCole Faust { 82*c217d954SCole Faust } 83*c217d954SCole Faust }; 84*c217d954SCole Faust /** Data set containing small 2D tensor shapes. */ 85*c217d954SCole Faust class Small2DShapes final : public ShapeDataset 86*c217d954SCole Faust { 87*c217d954SCole Faust public: Small2DShapes()88*c217d954SCole Faust Small2DShapes() 89*c217d954SCole Faust : ShapeDataset("Shape", 90*c217d954SCole Faust { 91*c217d954SCole Faust TensorShape{ 1U, 7U }, 92*c217d954SCole Faust TensorShape{ 5U, 13U }, 93*c217d954SCole Faust TensorShape{ 32U, 64U } 94*c217d954SCole Faust }) 95*c217d954SCole Faust { 96*c217d954SCole Faust } 97*c217d954SCole Faust }; 98*c217d954SCole Faust 99*c217d954SCole Faust /** Data set containing tiny 3D tensor shapes. */ 100*c217d954SCole Faust class Tiny3DShapes final : public ShapeDataset 101*c217d954SCole Faust { 102*c217d954SCole Faust public: Tiny3DShapes()103*c217d954SCole Faust Tiny3DShapes() 104*c217d954SCole Faust : ShapeDataset("Shape", 105*c217d954SCole Faust { 106*c217d954SCole Faust TensorShape{ 7U, 7U, 5U }, 107*c217d954SCole Faust TensorShape{ 23U, 13U, 9U }, 108*c217d954SCole Faust }) 109*c217d954SCole Faust { 110*c217d954SCole Faust } 111*c217d954SCole Faust }; 112*c217d954SCole Faust 113*c217d954SCole Faust /** Data set containing small 3D tensor shapes. */ 114*c217d954SCole Faust class Small3DShapes final : public ShapeDataset 115*c217d954SCole Faust { 116*c217d954SCole Faust public: Small3DShapes()117*c217d954SCole Faust Small3DShapes() 118*c217d954SCole Faust : ShapeDataset("Shape", 119*c217d954SCole Faust { 120*c217d954SCole Faust TensorShape{ 1U, 7U, 7U }, 121*c217d954SCole Faust TensorShape{ 2U, 5U, 4U }, 122*c217d954SCole Faust 123*c217d954SCole Faust TensorShape{ 7U, 7U, 5U }, 124*c217d954SCole Faust TensorShape{ 16U, 16U, 5U }, 125*c217d954SCole Faust TensorShape{ 27U, 13U, 37U }, 126*c217d954SCole Faust }) 127*c217d954SCole Faust { 128*c217d954SCole Faust } 129*c217d954SCole Faust }; 130*c217d954SCole Faust 131*c217d954SCole Faust /** Data set containing tiny 4D tensor shapes. */ 132*c217d954SCole Faust class Tiny4DShapes final : public ShapeDataset 133*c217d954SCole Faust { 134*c217d954SCole Faust public: Tiny4DShapes()135*c217d954SCole Faust Tiny4DShapes() 136*c217d954SCole Faust : ShapeDataset("Shape", 137*c217d954SCole Faust { 138*c217d954SCole Faust TensorShape{ 2U, 7U, 5U, 3U }, 139*c217d954SCole Faust TensorShape{ 17U, 13U, 7U, 2U }, 140*c217d954SCole Faust }) 141*c217d954SCole Faust { 142*c217d954SCole Faust } 143*c217d954SCole Faust }; 144*c217d954SCole Faust /** Data set containing small 4D tensor shapes. */ 145*c217d954SCole Faust class Small4DShapes final : public ShapeDataset 146*c217d954SCole Faust { 147*c217d954SCole Faust public: Small4DShapes()148*c217d954SCole Faust Small4DShapes() 149*c217d954SCole Faust : ShapeDataset("Shape", 150*c217d954SCole Faust { 151*c217d954SCole Faust TensorShape{ 2U, 7U, 1U, 3U }, 152*c217d954SCole Faust TensorShape{ 7U, 7U, 5U, 3U }, 153*c217d954SCole Faust TensorShape{ 27U, 13U, 37U, 2U }, 154*c217d954SCole Faust TensorShape{ 128U, 64U, 21U, 3U } 155*c217d954SCole Faust }) 156*c217d954SCole Faust { 157*c217d954SCole Faust } 158*c217d954SCole Faust }; 159*c217d954SCole Faust 160*c217d954SCole Faust /** Data set containing tiny tensor shapes. */ 161*c217d954SCole Faust class TinyShapes final : public ShapeDataset 162*c217d954SCole Faust { 163*c217d954SCole Faust public: TinyShapes()164*c217d954SCole Faust TinyShapes() 165*c217d954SCole Faust : ShapeDataset("Shape", 166*c217d954SCole Faust { 167*c217d954SCole Faust // Batch size 1 168*c217d954SCole Faust TensorShape{ 1U, 9U }, 169*c217d954SCole Faust TensorShape{ 27U, 13U, 2U }, 170*c217d954SCole Faust }) 171*c217d954SCole Faust { 172*c217d954SCole Faust } 173*c217d954SCole Faust }; 174*c217d954SCole Faust /** Data set containing small tensor shapes with none of the dimensions equal to 1 (unit). */ 175*c217d954SCole Faust class SmallNoneUnitShapes final : public ShapeDataset 176*c217d954SCole Faust { 177*c217d954SCole Faust public: SmallNoneUnitShapes()178*c217d954SCole Faust SmallNoneUnitShapes() 179*c217d954SCole Faust : ShapeDataset("Shape", 180*c217d954SCole Faust { 181*c217d954SCole Faust // Batch size 1 182*c217d954SCole Faust TensorShape{ 13U, 11U }, 183*c217d954SCole Faust TensorShape{ 16U, 16U }, 184*c217d954SCole Faust TensorShape{ 24U, 26U, 5U }, 185*c217d954SCole Faust TensorShape{ 7U, 7U, 17U, 2U }, 186*c217d954SCole Faust // Batch size 4 187*c217d954SCole Faust TensorShape{ 27U, 13U, 2U, 4U }, 188*c217d954SCole Faust // Arbitrary batch size 189*c217d954SCole Faust TensorShape{ 8U, 7U, 5U, 5U } 190*c217d954SCole Faust }) 191*c217d954SCole Faust { 192*c217d954SCole Faust } 193*c217d954SCole Faust }; 194*c217d954SCole Faust /** Data set containing small tensor shapes. */ 195*c217d954SCole Faust class SmallShapes final : public ShapeDataset 196*c217d954SCole Faust { 197*c217d954SCole Faust public: SmallShapes()198*c217d954SCole Faust SmallShapes() 199*c217d954SCole Faust : ShapeDataset("Shape", 200*c217d954SCole Faust { 201*c217d954SCole Faust // Batch size 1 202*c217d954SCole Faust TensorShape{ 3U, 11U }, 203*c217d954SCole Faust TensorShape{ 1U, 16U }, 204*c217d954SCole Faust TensorShape{ 27U, 13U, 7U }, 205*c217d954SCole Faust TensorShape{ 7U, 7U, 17U, 2U }, 206*c217d954SCole Faust // Batch size 4 and 2 SIMD iterations 207*c217d954SCole Faust TensorShape{ 33U, 13U, 2U, 4U }, 208*c217d954SCole Faust // Arbitrary batch size 209*c217d954SCole Faust TensorShape{ 11U, 11U, 3U, 5U } 210*c217d954SCole Faust }) 211*c217d954SCole Faust { 212*c217d954SCole Faust } 213*c217d954SCole Faust }; 214*c217d954SCole Faust 215*c217d954SCole Faust /** Data set containing small tensor shapes. */ 216*c217d954SCole Faust class SmallShapesNoBatches final : public ShapeDataset 217*c217d954SCole Faust { 218*c217d954SCole Faust public: SmallShapesNoBatches()219*c217d954SCole Faust SmallShapesNoBatches() 220*c217d954SCole Faust : ShapeDataset("Shape", 221*c217d954SCole Faust { 222*c217d954SCole Faust // Batch size 1 223*c217d954SCole Faust TensorShape{ 3U, 11U }, 224*c217d954SCole Faust TensorShape{ 1U, 16U }, 225*c217d954SCole Faust TensorShape{ 27U, 13U, 7U }, 226*c217d954SCole Faust TensorShape{ 7U, 7U, 17U }, 227*c217d954SCole Faust TensorShape{ 33U, 13U, 2U }, 228*c217d954SCole Faust TensorShape{ 11U, 11U, 3U } 229*c217d954SCole Faust }) 230*c217d954SCole Faust { 231*c217d954SCole Faust } 232*c217d954SCole Faust }; 233*c217d954SCole Faust 234*c217d954SCole Faust /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */ 235*c217d954SCole Faust class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 236*c217d954SCole Faust { 237*c217d954SCole Faust public: TinyShapesBroadcast()238*c217d954SCole Faust TinyShapesBroadcast() 239*c217d954SCole Faust : ZipDataset<ShapeDataset, ShapeDataset>( 240*c217d954SCole Faust ShapeDataset("Shape0", 241*c217d954SCole Faust { 242*c217d954SCole Faust TensorShape{ 9U, 9U }, 243*c217d954SCole Faust TensorShape{ 10U, 2U, 14U, 2U }, 244*c217d954SCole Faust }), 245*c217d954SCole Faust ShapeDataset("Shape1", 246*c217d954SCole Faust { 247*c217d954SCole Faust TensorShape{ 9U, 1U, 9U }, 248*c217d954SCole Faust TensorShape{ 10U }, 249*c217d954SCole Faust })) 250*c217d954SCole Faust { 251*c217d954SCole Faust } 252*c217d954SCole Faust }; 253*c217d954SCole Faust /** Data set containing pairs of tiny tensor shapes that are broadcast compatible and can do in_place calculation. */ 254*c217d954SCole Faust class TinyShapesBroadcastInplace final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 255*c217d954SCole Faust { 256*c217d954SCole Faust public: TinyShapesBroadcastInplace()257*c217d954SCole Faust TinyShapesBroadcastInplace() 258*c217d954SCole Faust : ZipDataset<ShapeDataset, ShapeDataset>( 259*c217d954SCole Faust ShapeDataset("Shape0", 260*c217d954SCole Faust { 261*c217d954SCole Faust TensorShape{ 9U }, 262*c217d954SCole Faust TensorShape{ 10U, 2U, 14U, 2U }, 263*c217d954SCole Faust }), 264*c217d954SCole Faust ShapeDataset("Shape1", 265*c217d954SCole Faust { 266*c217d954SCole Faust TensorShape{ 9U, 1U, 9U }, 267*c217d954SCole Faust TensorShape{ 10U }, 268*c217d954SCole Faust })) 269*c217d954SCole Faust { 270*c217d954SCole Faust } 271*c217d954SCole Faust }; 272*c217d954SCole Faust /** Data set containing pairs of small tensor shapes that are broadcast compatible. */ 273*c217d954SCole Faust class SmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 274*c217d954SCole Faust { 275*c217d954SCole Faust public: SmallShapesBroadcast()276*c217d954SCole Faust SmallShapesBroadcast() 277*c217d954SCole Faust : ZipDataset<ShapeDataset, ShapeDataset>( 278*c217d954SCole Faust ShapeDataset("Shape0", 279*c217d954SCole Faust { 280*c217d954SCole Faust TensorShape{ 9U, 9U }, 281*c217d954SCole Faust TensorShape{ 27U, 13U, 2U }, 282*c217d954SCole Faust TensorShape{ 128U, 1U, 5U, 3U }, 283*c217d954SCole Faust TensorShape{ 9U, 9U, 3U, 4U }, 284*c217d954SCole Faust TensorShape{ 27U, 13U, 2U, 4U }, 285*c217d954SCole Faust TensorShape{ 1U, 1U, 1U, 5U }, 286*c217d954SCole Faust TensorShape{ 1U, 16U, 10U, 2U, 128U }, 287*c217d954SCole Faust TensorShape{ 1U, 16U, 10U, 2U, 128U } 288*c217d954SCole Faust }), 289*c217d954SCole Faust ShapeDataset("Shape1", 290*c217d954SCole Faust { 291*c217d954SCole Faust TensorShape{ 9U, 1U, 2U }, 292*c217d954SCole Faust TensorShape{ 1U, 13U, 2U }, 293*c217d954SCole Faust TensorShape{ 128U, 64U, 1U, 3U }, 294*c217d954SCole Faust TensorShape{ 9U, 1U, 3U }, 295*c217d954SCole Faust TensorShape{ 1U }, 296*c217d954SCole Faust TensorShape{ 9U, 9U, 3U, 5U }, 297*c217d954SCole Faust TensorShape{ 1U, 1U, 1U, 1U, 128U }, 298*c217d954SCole Faust TensorShape{ 128U } 299*c217d954SCole Faust })) 300*c217d954SCole Faust { 301*c217d954SCole Faust } 302*c217d954SCole Faust }; 303*c217d954SCole Faust 304*c217d954SCole Faust class TemporaryLimitedSmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 305*c217d954SCole Faust { 306*c217d954SCole Faust public: TemporaryLimitedSmallShapesBroadcast()307*c217d954SCole Faust TemporaryLimitedSmallShapesBroadcast() 308*c217d954SCole Faust : ZipDataset<ShapeDataset, ShapeDataset>( 309*c217d954SCole Faust ShapeDataset("Shape0", 310*c217d954SCole Faust { 311*c217d954SCole Faust TensorShape{ 1U, 3U, 4U, 2U }, // LHS broadcast X 312*c217d954SCole Faust TensorShape{ 6U, 4U, 2U, 3U }, // RHS broadcast X 313*c217d954SCole Faust TensorShape{ 7U, 1U, 1U, 4U }, // LHS broadcast Y, Z 314*c217d954SCole Faust TensorShape{ 8U, 5U, 6U, 3U }, // RHS broadcast Y, Z 315*c217d954SCole Faust TensorShape{ 1U, 1U, 1U, 2U }, // LHS broadcast X, Y, Z 316*c217d954SCole Faust TensorShape{ 2U, 6U, 4U, 3U }, // RHS broadcast X, Y, Z 317*c217d954SCole Faust }), 318*c217d954SCole Faust ShapeDataset("Shape1", 319*c217d954SCole Faust { 320*c217d954SCole Faust TensorShape{ 5U, 3U, 4U, 2U }, 321*c217d954SCole Faust TensorShape{ 1U, 4U, 2U, 3U }, 322*c217d954SCole Faust TensorShape{ 7U, 2U, 3U, 4U }, 323*c217d954SCole Faust TensorShape{ 8U, 1U, 1U, 3U }, 324*c217d954SCole Faust TensorShape{ 4U, 7U, 3U, 2U }, 325*c217d954SCole Faust TensorShape{ 1U, 1U, 1U, 3U }, 326*c217d954SCole Faust })) 327*c217d954SCole Faust { 328*c217d954SCole Faust } 329*c217d954SCole Faust }; 330*c217d954SCole Faust 331*c217d954SCole Faust class TemporaryLimitedLargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 332*c217d954SCole Faust { 333*c217d954SCole Faust public: TemporaryLimitedLargeShapesBroadcast()334*c217d954SCole Faust TemporaryLimitedLargeShapesBroadcast() 335*c217d954SCole Faust : ZipDataset<ShapeDataset, ShapeDataset>( 336*c217d954SCole Faust ShapeDataset("Shape0", 337*c217d954SCole Faust { 338*c217d954SCole Faust TensorShape{ 127U, 25U, 5U }, 339*c217d954SCole Faust TensorShape{ 485, 40U, 10U } 340*c217d954SCole Faust }), 341*c217d954SCole Faust ShapeDataset("Shape1", 342*c217d954SCole Faust { 343*c217d954SCole Faust TensorShape{ 1U, 1U, 1U }, // Broadcast in X, Y, Z 344*c217d954SCole Faust TensorShape{ 485U, 1U, 1U }, // Broadcast in Y, Z 345*c217d954SCole Faust })) 346*c217d954SCole Faust { 347*c217d954SCole Faust } 348*c217d954SCole Faust }; 349*c217d954SCole Faust 350*c217d954SCole Faust /** Data set containing medium tensor shapes. */ 351*c217d954SCole Faust class MediumShapes final : public ShapeDataset 352*c217d954SCole Faust { 353*c217d954SCole Faust public: MediumShapes()354*c217d954SCole Faust MediumShapes() 355*c217d954SCole Faust : ShapeDataset("Shape", 356*c217d954SCole Faust { 357*c217d954SCole Faust // Batch size 1 358*c217d954SCole Faust TensorShape{ 37U, 37U }, 359*c217d954SCole Faust TensorShape{ 27U, 33U, 2U }, 360*c217d954SCole Faust // Arbitrary batch size 361*c217d954SCole Faust TensorShape{ 37U, 37U, 3U, 5U } 362*c217d954SCole Faust }) 363*c217d954SCole Faust { 364*c217d954SCole Faust } 365*c217d954SCole Faust }; 366*c217d954SCole Faust 367*c217d954SCole Faust /** Data set containing medium 2D tensor shapes. */ 368*c217d954SCole Faust class Medium2DShapes final : public ShapeDataset 369*c217d954SCole Faust { 370*c217d954SCole Faust public: Medium2DShapes()371*c217d954SCole Faust Medium2DShapes() 372*c217d954SCole Faust : ShapeDataset("Shape", 373*c217d954SCole Faust { 374*c217d954SCole Faust TensorShape{ 42U, 37U }, 375*c217d954SCole Faust TensorShape{ 57U, 60U }, 376*c217d954SCole Faust TensorShape{ 128U, 64U }, 377*c217d954SCole Faust TensorShape{ 83U, 72U }, 378*c217d954SCole Faust TensorShape{ 40U, 40U } 379*c217d954SCole Faust }) 380*c217d954SCole Faust { 381*c217d954SCole Faust } 382*c217d954SCole Faust }; 383*c217d954SCole Faust 384*c217d954SCole Faust /** Data set containing medium 3D tensor shapes. */ 385*c217d954SCole Faust class Medium3DShapes final : public ShapeDataset 386*c217d954SCole Faust { 387*c217d954SCole Faust public: Medium3DShapes()388*c217d954SCole Faust Medium3DShapes() 389*c217d954SCole Faust : ShapeDataset("Shape", 390*c217d954SCole Faust { 391*c217d954SCole Faust TensorShape{ 42U, 37U, 8U }, 392*c217d954SCole Faust TensorShape{ 57U, 60U, 13U }, 393*c217d954SCole Faust TensorShape{ 83U, 72U, 14U } 394*c217d954SCole Faust }) 395*c217d954SCole Faust { 396*c217d954SCole Faust } 397*c217d954SCole Faust }; 398*c217d954SCole Faust 399*c217d954SCole Faust /** Data set containing medium 4D tensor shapes. */ 400*c217d954SCole Faust class Medium4DShapes final : public ShapeDataset 401*c217d954SCole Faust { 402*c217d954SCole Faust public: Medium4DShapes()403*c217d954SCole Faust Medium4DShapes() 404*c217d954SCole Faust : ShapeDataset("Shape", 405*c217d954SCole Faust { 406*c217d954SCole Faust TensorShape{ 42U, 37U, 8U, 15U }, 407*c217d954SCole Faust TensorShape{ 57U, 60U, 13U, 8U }, 408*c217d954SCole Faust TensorShape{ 83U, 72U, 14U, 5U } 409*c217d954SCole Faust }) 410*c217d954SCole Faust { 411*c217d954SCole Faust } 412*c217d954SCole Faust }; 413*c217d954SCole Faust 414*c217d954SCole Faust /** Data set containing large tensor shapes. */ 415*c217d954SCole Faust class LargeShapes final : public ShapeDataset 416*c217d954SCole Faust { 417*c217d954SCole Faust public: LargeShapes()418*c217d954SCole Faust LargeShapes() 419*c217d954SCole Faust : ShapeDataset("Shape", 420*c217d954SCole Faust { 421*c217d954SCole Faust TensorShape{ 582U, 131U, 1U, 4U }, 422*c217d954SCole Faust }) 423*c217d954SCole Faust { 424*c217d954SCole Faust } 425*c217d954SCole Faust }; 426*c217d954SCole Faust 427*c217d954SCole Faust /** Data set containing large tensor shapes. */ 428*c217d954SCole Faust class LargeShapesNoBatches final : public ShapeDataset 429*c217d954SCole Faust { 430*c217d954SCole Faust public: LargeShapesNoBatches()431*c217d954SCole Faust LargeShapesNoBatches() 432*c217d954SCole Faust : ShapeDataset("Shape", 433*c217d954SCole Faust { 434*c217d954SCole Faust TensorShape{ 582U, 131U, 2U }, 435*c217d954SCole Faust }) 436*c217d954SCole Faust { 437*c217d954SCole Faust } 438*c217d954SCole Faust }; 439*c217d954SCole Faust 440*c217d954SCole Faust /** Data set containing pairs of large tensor shapes that are broadcast compatible. */ 441*c217d954SCole Faust class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> 442*c217d954SCole Faust { 443*c217d954SCole Faust public: LargeShapesBroadcast()444*c217d954SCole Faust LargeShapesBroadcast() 445*c217d954SCole Faust : ZipDataset<ShapeDataset, ShapeDataset>( 446*c217d954SCole Faust ShapeDataset("Shape0", 447*c217d954SCole Faust { 448*c217d954SCole Faust TensorShape{ 1921U, 541U }, 449*c217d954SCole Faust TensorShape{ 1U, 485U, 2U, 3U }, 450*c217d954SCole Faust TensorShape{ 4159U, 1U }, 451*c217d954SCole Faust TensorShape{ 799U } 452*c217d954SCole Faust }), 453*c217d954SCole Faust ShapeDataset("Shape1", 454*c217d954SCole Faust { 455*c217d954SCole Faust TensorShape{ 1921U, 1U, 2U }, 456*c217d954SCole Faust TensorShape{ 641U, 1U, 2U, 3U }, 457*c217d954SCole Faust TensorShape{ 1U, 127U, 25U }, 458*c217d954SCole Faust TensorShape{ 799U, 595U, 1U, 4U } 459*c217d954SCole Faust })) 460*c217d954SCole Faust { 461*c217d954SCole Faust } 462*c217d954SCole Faust }; 463*c217d954SCole Faust 464*c217d954SCole Faust /** Data set containing large 1D tensor shapes. */ 465*c217d954SCole Faust class Large1DShapes final : public ShapeDataset 466*c217d954SCole Faust { 467*c217d954SCole Faust public: Large1DShapes()468*c217d954SCole Faust Large1DShapes() 469*c217d954SCole Faust : ShapeDataset("Shape", 470*c217d954SCole Faust { 471*c217d954SCole Faust TensorShape{ 1245U } 472*c217d954SCole Faust }) 473*c217d954SCole Faust { 474*c217d954SCole Faust } 475*c217d954SCole Faust }; 476*c217d954SCole Faust 477*c217d954SCole Faust /** Data set containing large 2D tensor shapes. */ 478*c217d954SCole Faust class Large2DShapes final : public ShapeDataset 479*c217d954SCole Faust { 480*c217d954SCole Faust public: Large2DShapes()481*c217d954SCole Faust Large2DShapes() 482*c217d954SCole Faust : ShapeDataset("Shape", 483*c217d954SCole Faust { 484*c217d954SCole Faust TensorShape{ 1245U, 652U } 485*c217d954SCole Faust }) 486*c217d954SCole Faust { 487*c217d954SCole Faust } 488*c217d954SCole Faust }; 489*c217d954SCole Faust 490*c217d954SCole Faust /** Data set containing large 3D tensor shapes. */ 491*c217d954SCole Faust class Large3DShapes final : public ShapeDataset 492*c217d954SCole Faust { 493*c217d954SCole Faust public: Large3DShapes()494*c217d954SCole Faust Large3DShapes() 495*c217d954SCole Faust : ShapeDataset("Shape", 496*c217d954SCole Faust { 497*c217d954SCole Faust TensorShape{ 320U, 240U, 3U } 498*c217d954SCole Faust }) 499*c217d954SCole Faust { 500*c217d954SCole Faust } 501*c217d954SCole Faust }; 502*c217d954SCole Faust 503*c217d954SCole Faust /** Data set containing large 4D tensor shapes. */ 504*c217d954SCole Faust class Large4DShapes final : public ShapeDataset 505*c217d954SCole Faust { 506*c217d954SCole Faust public: Large4DShapes()507*c217d954SCole Faust Large4DShapes() 508*c217d954SCole Faust : ShapeDataset("Shape", 509*c217d954SCole Faust { 510*c217d954SCole Faust TensorShape{ 320U, 123U, 3U, 3U } 511*c217d954SCole Faust }) 512*c217d954SCole Faust { 513*c217d954SCole Faust } 514*c217d954SCole Faust }; 515*c217d954SCole Faust 516*c217d954SCole Faust /** Data set containing small 3x3 tensor shapes. */ 517*c217d954SCole Faust class Small3x3Shapes final : public ShapeDataset 518*c217d954SCole Faust { 519*c217d954SCole Faust public: Small3x3Shapes()520*c217d954SCole Faust Small3x3Shapes() 521*c217d954SCole Faust : ShapeDataset("Shape", 522*c217d954SCole Faust { 523*c217d954SCole Faust TensorShape{ 3U, 3U, 7U, 4U }, 524*c217d954SCole Faust TensorShape{ 3U, 3U, 4U, 13U }, 525*c217d954SCole Faust TensorShape{ 3U, 3U, 3U, 5U }, 526*c217d954SCole Faust }) 527*c217d954SCole Faust { 528*c217d954SCole Faust } 529*c217d954SCole Faust }; 530*c217d954SCole Faust 531*c217d954SCole Faust /** Data set containing small 3x1 tensor shapes. */ 532*c217d954SCole Faust class Small3x1Shapes final : public ShapeDataset 533*c217d954SCole Faust { 534*c217d954SCole Faust public: Small3x1Shapes()535*c217d954SCole Faust Small3x1Shapes() 536*c217d954SCole Faust : ShapeDataset("Shape", 537*c217d954SCole Faust { 538*c217d954SCole Faust TensorShape{ 3U, 1U, 7U, 4U }, 539*c217d954SCole Faust TensorShape{ 3U, 1U, 4U, 13U }, 540*c217d954SCole Faust TensorShape{ 3U, 1U, 3U, 5U }, 541*c217d954SCole Faust }) 542*c217d954SCole Faust { 543*c217d954SCole Faust } 544*c217d954SCole Faust }; 545*c217d954SCole Faust 546*c217d954SCole Faust /** Data set containing small 1x3 tensor shapes. */ 547*c217d954SCole Faust class Small1x3Shapes final : public ShapeDataset 548*c217d954SCole Faust { 549*c217d954SCole Faust public: Small1x3Shapes()550*c217d954SCole Faust Small1x3Shapes() 551*c217d954SCole Faust : ShapeDataset("Shape", 552*c217d954SCole Faust { 553*c217d954SCole Faust TensorShape{ 1U, 3U, 7U, 4U }, 554*c217d954SCole Faust TensorShape{ 1U, 3U, 4U, 13U }, 555*c217d954SCole Faust TensorShape{ 1U, 3U, 3U, 5U }, 556*c217d954SCole Faust }) 557*c217d954SCole Faust { 558*c217d954SCole Faust } 559*c217d954SCole Faust }; 560*c217d954SCole Faust 561*c217d954SCole Faust /** Data set containing large 3x3 tensor shapes. */ 562*c217d954SCole Faust class Large3x3Shapes final : public ShapeDataset 563*c217d954SCole Faust { 564*c217d954SCole Faust public: Large3x3Shapes()565*c217d954SCole Faust Large3x3Shapes() 566*c217d954SCole Faust : ShapeDataset("Shape", 567*c217d954SCole Faust { 568*c217d954SCole Faust TensorShape{ 3U, 3U, 32U, 64U }, 569*c217d954SCole Faust TensorShape{ 3U, 3U, 51U, 13U }, 570*c217d954SCole Faust TensorShape{ 3U, 3U, 53U, 47U }, 571*c217d954SCole Faust }) 572*c217d954SCole Faust { 573*c217d954SCole Faust } 574*c217d954SCole Faust }; 575*c217d954SCole Faust 576*c217d954SCole Faust /** Data set containing large 3x1 tensor shapes. */ 577*c217d954SCole Faust class Large3x1Shapes final : public ShapeDataset 578*c217d954SCole Faust { 579*c217d954SCole Faust public: Large3x1Shapes()580*c217d954SCole Faust Large3x1Shapes() 581*c217d954SCole Faust : ShapeDataset("Shape", 582*c217d954SCole Faust { 583*c217d954SCole Faust TensorShape{ 3U, 1U, 32U, 64U }, 584*c217d954SCole Faust TensorShape{ 3U, 1U, 51U, 13U }, 585*c217d954SCole Faust TensorShape{ 3U, 1U, 53U, 47U }, 586*c217d954SCole Faust }) 587*c217d954SCole Faust { 588*c217d954SCole Faust } 589*c217d954SCole Faust }; 590*c217d954SCole Faust 591*c217d954SCole Faust /** Data set containing large 1x3 tensor shapes. */ 592*c217d954SCole Faust class Large1x3Shapes final : public ShapeDataset 593*c217d954SCole Faust { 594*c217d954SCole Faust public: Large1x3Shapes()595*c217d954SCole Faust Large1x3Shapes() 596*c217d954SCole Faust : ShapeDataset("Shape", 597*c217d954SCole Faust { 598*c217d954SCole Faust TensorShape{ 1U, 3U, 32U, 64U }, 599*c217d954SCole Faust TensorShape{ 1U, 3U, 51U, 13U }, 600*c217d954SCole Faust TensorShape{ 1U, 3U, 53U, 47U }, 601*c217d954SCole Faust }) 602*c217d954SCole Faust { 603*c217d954SCole Faust } 604*c217d954SCole Faust }; 605*c217d954SCole Faust 606*c217d954SCole Faust /** Data set containing small 5x5 tensor shapes. */ 607*c217d954SCole Faust class Small5x5Shapes final : public ShapeDataset 608*c217d954SCole Faust { 609*c217d954SCole Faust public: Small5x5Shapes()610*c217d954SCole Faust Small5x5Shapes() 611*c217d954SCole Faust : ShapeDataset("Shape", 612*c217d954SCole Faust { 613*c217d954SCole Faust TensorShape{ 5U, 5U, 7U, 4U }, 614*c217d954SCole Faust TensorShape{ 5U, 5U, 4U, 13U }, 615*c217d954SCole Faust TensorShape{ 5U, 5U, 3U, 5U }, 616*c217d954SCole Faust }) 617*c217d954SCole Faust { 618*c217d954SCole Faust } 619*c217d954SCole Faust }; 620*c217d954SCole Faust 621*c217d954SCole Faust /** Data set containing small 5D tensor shapes. */ 622*c217d954SCole Faust class Small5dShapes final : public ShapeDataset 623*c217d954SCole Faust { 624*c217d954SCole Faust public: Small5dShapes()625*c217d954SCole Faust Small5dShapes() 626*c217d954SCole Faust : ShapeDataset("Shape", 627*c217d954SCole Faust { 628*c217d954SCole Faust TensorShape{ 5U, 5U, 7U, 4U, 3U }, 629*c217d954SCole Faust TensorShape{ 5U, 5U, 4U, 13U, 2U }, 630*c217d954SCole Faust TensorShape{ 5U, 5U, 3U, 5U, 2U }, 631*c217d954SCole Faust }) 632*c217d954SCole Faust { 633*c217d954SCole Faust } 634*c217d954SCole Faust }; 635*c217d954SCole Faust 636*c217d954SCole Faust /** Data set containing large 5x5 tensor shapes. */ 637*c217d954SCole Faust class Large5x5Shapes final : public ShapeDataset 638*c217d954SCole Faust { 639*c217d954SCole Faust public: Large5x5Shapes()640*c217d954SCole Faust Large5x5Shapes() 641*c217d954SCole Faust : ShapeDataset("Shape", 642*c217d954SCole Faust { 643*c217d954SCole Faust TensorShape{ 5U, 5U, 32U, 64U } 644*c217d954SCole Faust }) 645*c217d954SCole Faust { 646*c217d954SCole Faust } 647*c217d954SCole Faust }; 648*c217d954SCole Faust 649*c217d954SCole Faust /** Data set containing large 5D tensor shapes. */ 650*c217d954SCole Faust class Large5dShapes final : public ShapeDataset 651*c217d954SCole Faust { 652*c217d954SCole Faust public: Large5dShapes()653*c217d954SCole Faust Large5dShapes() 654*c217d954SCole Faust : ShapeDataset("Shape", 655*c217d954SCole Faust { 656*c217d954SCole Faust TensorShape{ 30U, 40U, 30U, 32U, 3U } 657*c217d954SCole Faust }) 658*c217d954SCole Faust { 659*c217d954SCole Faust } 660*c217d954SCole Faust }; 661*c217d954SCole Faust 662*c217d954SCole Faust /** Data set containing small 5x1 tensor shapes. */ 663*c217d954SCole Faust class Small5x1Shapes final : public ShapeDataset 664*c217d954SCole Faust { 665*c217d954SCole Faust public: Small5x1Shapes()666*c217d954SCole Faust Small5x1Shapes() 667*c217d954SCole Faust : ShapeDataset("Shape", 668*c217d954SCole Faust { 669*c217d954SCole Faust TensorShape{ 5U, 1U, 7U, 4U } 670*c217d954SCole Faust }) 671*c217d954SCole Faust { 672*c217d954SCole Faust } 673*c217d954SCole Faust }; 674*c217d954SCole Faust 675*c217d954SCole Faust /** Data set containing large 5x1 tensor shapes. */ 676*c217d954SCole Faust class Large5x1Shapes final : public ShapeDataset 677*c217d954SCole Faust { 678*c217d954SCole Faust public: Large5x1Shapes()679*c217d954SCole Faust Large5x1Shapes() 680*c217d954SCole Faust : ShapeDataset("Shape", 681*c217d954SCole Faust { 682*c217d954SCole Faust TensorShape{ 5U, 1U, 32U, 64U } 683*c217d954SCole Faust }) 684*c217d954SCole Faust { 685*c217d954SCole Faust } 686*c217d954SCole Faust }; 687*c217d954SCole Faust 688*c217d954SCole Faust /** Data set containing small 1x5 tensor shapes. */ 689*c217d954SCole Faust class Small1x5Shapes final : public ShapeDataset 690*c217d954SCole Faust { 691*c217d954SCole Faust public: Small1x5Shapes()692*c217d954SCole Faust Small1x5Shapes() 693*c217d954SCole Faust : ShapeDataset("Shape", 694*c217d954SCole Faust { 695*c217d954SCole Faust TensorShape{ 1U, 5U, 7U, 4U } 696*c217d954SCole Faust }) 697*c217d954SCole Faust { 698*c217d954SCole Faust } 699*c217d954SCole Faust }; 700*c217d954SCole Faust 701*c217d954SCole Faust /** Data set containing large 1x5 tensor shapes. */ 702*c217d954SCole Faust class Large1x5Shapes final : public ShapeDataset 703*c217d954SCole Faust { 704*c217d954SCole Faust public: Large1x5Shapes()705*c217d954SCole Faust Large1x5Shapes() 706*c217d954SCole Faust : ShapeDataset("Shape", 707*c217d954SCole Faust { 708*c217d954SCole Faust TensorShape{ 1U, 5U, 32U, 64U } 709*c217d954SCole Faust }) 710*c217d954SCole Faust { 711*c217d954SCole Faust } 712*c217d954SCole Faust }; 713*c217d954SCole Faust 714*c217d954SCole Faust /** Data set containing small 1x7 tensor shapes. */ 715*c217d954SCole Faust class Small1x7Shapes final : public ShapeDataset 716*c217d954SCole Faust { 717*c217d954SCole Faust public: Small1x7Shapes()718*c217d954SCole Faust Small1x7Shapes() 719*c217d954SCole Faust : ShapeDataset("Shape", 720*c217d954SCole Faust { 721*c217d954SCole Faust TensorShape{ 1U, 7U, 7U, 4U } 722*c217d954SCole Faust }) 723*c217d954SCole Faust { 724*c217d954SCole Faust } 725*c217d954SCole Faust }; 726*c217d954SCole Faust 727*c217d954SCole Faust /** Data set containing large 1x7 tensor shapes. */ 728*c217d954SCole Faust class Large1x7Shapes final : public ShapeDataset 729*c217d954SCole Faust { 730*c217d954SCole Faust public: Large1x7Shapes()731*c217d954SCole Faust Large1x7Shapes() 732*c217d954SCole Faust : ShapeDataset("Shape", 733*c217d954SCole Faust { 734*c217d954SCole Faust TensorShape{ 1U, 7U, 32U, 64U } 735*c217d954SCole Faust }) 736*c217d954SCole Faust { 737*c217d954SCole Faust } 738*c217d954SCole Faust }; 739*c217d954SCole Faust 740*c217d954SCole Faust /** Data set containing small 7x7 tensor shapes. */ 741*c217d954SCole Faust class Small7x7Shapes final : public ShapeDataset 742*c217d954SCole Faust { 743*c217d954SCole Faust public: Small7x7Shapes()744*c217d954SCole Faust Small7x7Shapes() 745*c217d954SCole Faust : ShapeDataset("Shape", 746*c217d954SCole Faust { 747*c217d954SCole Faust TensorShape{ 7U, 7U, 7U, 4U } 748*c217d954SCole Faust }) 749*c217d954SCole Faust { 750*c217d954SCole Faust } 751*c217d954SCole Faust }; 752*c217d954SCole Faust 753*c217d954SCole Faust /** Data set containing large 7x7 tensor shapes. */ 754*c217d954SCole Faust class Large7x7Shapes final : public ShapeDataset 755*c217d954SCole Faust { 756*c217d954SCole Faust public: Large7x7Shapes()757*c217d954SCole Faust Large7x7Shapes() 758*c217d954SCole Faust : ShapeDataset("Shape", 759*c217d954SCole Faust { 760*c217d954SCole Faust TensorShape{ 7U, 7U, 32U, 64U } 761*c217d954SCole Faust }) 762*c217d954SCole Faust { 763*c217d954SCole Faust } 764*c217d954SCole Faust }; 765*c217d954SCole Faust 766*c217d954SCole Faust /** Data set containing small 7x1 tensor shapes. */ 767*c217d954SCole Faust class Small7x1Shapes final : public ShapeDataset 768*c217d954SCole Faust { 769*c217d954SCole Faust public: Small7x1Shapes()770*c217d954SCole Faust Small7x1Shapes() 771*c217d954SCole Faust : ShapeDataset("Shape", 772*c217d954SCole Faust { 773*c217d954SCole Faust TensorShape{ 7U, 1U, 7U, 4U } 774*c217d954SCole Faust }) 775*c217d954SCole Faust { 776*c217d954SCole Faust } 777*c217d954SCole Faust }; 778*c217d954SCole Faust 779*c217d954SCole Faust /** Data set containing large 7x1 tensor shapes. */ 780*c217d954SCole Faust class Large7x1Shapes final : public ShapeDataset 781*c217d954SCole Faust { 782*c217d954SCole Faust public: Large7x1Shapes()783*c217d954SCole Faust Large7x1Shapes() 784*c217d954SCole Faust : ShapeDataset("Shape", 785*c217d954SCole Faust { 786*c217d954SCole Faust TensorShape{ 7U, 1U, 32U, 64U } 787*c217d954SCole Faust }) 788*c217d954SCole Faust { 789*c217d954SCole Faust } 790*c217d954SCole Faust }; 791*c217d954SCole Faust 792*c217d954SCole Faust /** Data set containing small tensor shapes for deconvolution. */ 793*c217d954SCole Faust class SmallDeconvolutionShapes final : public ShapeDataset 794*c217d954SCole Faust { 795*c217d954SCole Faust public: SmallDeconvolutionShapes()796*c217d954SCole Faust SmallDeconvolutionShapes() 797*c217d954SCole Faust : ShapeDataset("InputShape", 798*c217d954SCole Faust { 799*c217d954SCole Faust // Multiple Vector Loops for FP32 800*c217d954SCole Faust TensorShape{ 5U, 4U, 3U, 2U }, 801*c217d954SCole Faust TensorShape{ 5U, 5U, 3U }, 802*c217d954SCole Faust TensorShape{ 11U, 13U, 4U, 3U } 803*c217d954SCole Faust }) 804*c217d954SCole Faust { 805*c217d954SCole Faust } 806*c217d954SCole Faust }; 807*c217d954SCole Faust 808*c217d954SCole Faust class SmallDeconvolutionShapesWithLargerChannels final : public ShapeDataset 809*c217d954SCole Faust { 810*c217d954SCole Faust public: SmallDeconvolutionShapesWithLargerChannels()811*c217d954SCole Faust SmallDeconvolutionShapesWithLargerChannels() 812*c217d954SCole Faust : ShapeDataset("InputShape", 813*c217d954SCole Faust { 814*c217d954SCole Faust // Multiple Vector Loops for all data types 815*c217d954SCole Faust TensorShape{ 5U, 5U, 35U } 816*c217d954SCole Faust }) 817*c217d954SCole Faust { 818*c217d954SCole Faust } 819*c217d954SCole Faust }; 820*c217d954SCole Faust 821*c217d954SCole Faust /** Data set containing tiny tensor shapes for direct convolution. */ 822*c217d954SCole Faust class TinyDirectConvolutionShapes final : public ShapeDataset 823*c217d954SCole Faust { 824*c217d954SCole Faust public: TinyDirectConvolutionShapes()825*c217d954SCole Faust TinyDirectConvolutionShapes() 826*c217d954SCole Faust : ShapeDataset("InputShape", 827*c217d954SCole Faust { 828*c217d954SCole Faust // Batch size 1 829*c217d954SCole Faust TensorShape{ 11U, 13U, 3U }, 830*c217d954SCole Faust TensorShape{ 7U, 27U, 3U } 831*c217d954SCole Faust }) 832*c217d954SCole Faust { 833*c217d954SCole Faust } 834*c217d954SCole Faust }; 835*c217d954SCole Faust /** Data set containing small tensor shapes for direct convolution. */ 836*c217d954SCole Faust class SmallDirectConvolutionShapes final : public ShapeDataset 837*c217d954SCole Faust { 838*c217d954SCole Faust public: SmallDirectConvolutionShapes()839*c217d954SCole Faust SmallDirectConvolutionShapes() 840*c217d954SCole Faust : ShapeDataset("InputShape", 841*c217d954SCole Faust { 842*c217d954SCole Faust // Batch size 1 843*c217d954SCole Faust TensorShape{ 32U, 37U, 3U }, 844*c217d954SCole Faust // Batch size 4 845*c217d954SCole Faust TensorShape{ 6U, 9U, 5U, 4U }, 846*c217d954SCole Faust }) 847*c217d954SCole Faust { 848*c217d954SCole Faust } 849*c217d954SCole Faust }; 850*c217d954SCole Faust 851*c217d954SCole Faust class SmallDirectConv3DShapes final : public ShapeDataset 852*c217d954SCole Faust { 853*c217d954SCole Faust public: SmallDirectConv3DShapes()854*c217d954SCole Faust SmallDirectConv3DShapes() 855*c217d954SCole Faust : ShapeDataset("InputShape", 856*c217d954SCole Faust { 857*c217d954SCole Faust // Batch size 2 858*c217d954SCole Faust TensorShape{ 1U, 3U, 4U, 5U, 2U }, 859*c217d954SCole Faust // Batch size 3 860*c217d954SCole Faust TensorShape{ 7U, 27U, 3U, 6U, 3U }, 861*c217d954SCole Faust // Batch size 1 862*c217d954SCole Faust TensorShape{ 32U, 37U, 13U, 1U, 1U }, 863*c217d954SCole Faust }) 864*c217d954SCole Faust { 865*c217d954SCole Faust } 866*c217d954SCole Faust }; 867*c217d954SCole Faust 868*c217d954SCole Faust /** Data set containing small tensor shapes for direct convolution. */ 869*c217d954SCole Faust class SmallDirectConvolutionTensorShiftShapes final : public ShapeDataset 870*c217d954SCole Faust { 871*c217d954SCole Faust public: SmallDirectConvolutionTensorShiftShapes()872*c217d954SCole Faust SmallDirectConvolutionTensorShiftShapes() 873*c217d954SCole Faust : ShapeDataset("InputShape", 874*c217d954SCole Faust { 875*c217d954SCole Faust // Batch size 1 876*c217d954SCole Faust TensorShape{ 32U, 37U, 3U }, 877*c217d954SCole Faust // Batch size 4 878*c217d954SCole Faust TensorShape{ 32U, 37U, 3U, 4U }, 879*c217d954SCole Faust // Arbitrary batch size 880*c217d954SCole Faust TensorShape{ 32U, 37U, 3U, 8U } 881*c217d954SCole Faust }) 882*c217d954SCole Faust { 883*c217d954SCole Faust } 884*c217d954SCole Faust }; 885*c217d954SCole Faust 886*c217d954SCole Faust /** Data set containing small grouped im2col tensor shapes. */ 887*c217d954SCole Faust class GroupedIm2ColSmallShapes final : public ShapeDataset 888*c217d954SCole Faust { 889*c217d954SCole Faust public: GroupedIm2ColSmallShapes()890*c217d954SCole Faust GroupedIm2ColSmallShapes() 891*c217d954SCole Faust : ShapeDataset("Shape", 892*c217d954SCole Faust { 893*c217d954SCole Faust TensorShape{ 11U, 11U, 48U }, 894*c217d954SCole Faust TensorShape{ 27U, 13U, 24U }, 895*c217d954SCole Faust TensorShape{ 128U, 64U, 12U, 3U }, 896*c217d954SCole Faust TensorShape{ 11U, 11U, 48U, 4U }, 897*c217d954SCole Faust TensorShape{ 27U, 13U, 24U, 4U }, 898*c217d954SCole Faust TensorShape{ 11U, 11U, 48U, 5U } 899*c217d954SCole Faust }) 900*c217d954SCole Faust { 901*c217d954SCole Faust } 902*c217d954SCole Faust }; 903*c217d954SCole Faust 904*c217d954SCole Faust /** Data set containing large grouped im2col tensor shapes. */ 905*c217d954SCole Faust class GroupedIm2ColLargeShapes final : public ShapeDataset 906*c217d954SCole Faust { 907*c217d954SCole Faust public: GroupedIm2ColLargeShapes()908*c217d954SCole Faust GroupedIm2ColLargeShapes() 909*c217d954SCole Faust : ShapeDataset("Shape", 910*c217d954SCole Faust { 911*c217d954SCole Faust TensorShape{ 153U, 231U, 12U }, 912*c217d954SCole Faust TensorShape{ 123U, 191U, 12U, 2U }, 913*c217d954SCole Faust }) 914*c217d954SCole Faust { 915*c217d954SCole Faust } 916*c217d954SCole Faust }; 917*c217d954SCole Faust 918*c217d954SCole Faust /** Data set containing small grouped weights tensor shapes. */ 919*c217d954SCole Faust class GroupedWeightsSmallShapes final : public ShapeDataset 920*c217d954SCole Faust { 921*c217d954SCole Faust public: GroupedWeightsSmallShapes()922*c217d954SCole Faust GroupedWeightsSmallShapes() 923*c217d954SCole Faust : ShapeDataset("Shape", 924*c217d954SCole Faust { 925*c217d954SCole Faust TensorShape{ 3U, 3U, 48U, 120U }, 926*c217d954SCole Faust TensorShape{ 1U, 3U, 24U, 240U }, 927*c217d954SCole Faust TensorShape{ 3U, 1U, 12U, 480U }, 928*c217d954SCole Faust TensorShape{ 5U, 5U, 48U, 120U } 929*c217d954SCole Faust }) 930*c217d954SCole Faust { 931*c217d954SCole Faust } 932*c217d954SCole Faust }; 933*c217d954SCole Faust 934*c217d954SCole Faust /** Data set containing large grouped weights tensor shapes. */ 935*c217d954SCole Faust class GroupedWeightsLargeShapes final : public ShapeDataset 936*c217d954SCole Faust { 937*c217d954SCole Faust public: GroupedWeightsLargeShapes()938*c217d954SCole Faust GroupedWeightsLargeShapes() 939*c217d954SCole Faust : ShapeDataset("Shape", 940*c217d954SCole Faust { 941*c217d954SCole Faust TensorShape{ 9U, 9U, 96U, 240U }, 942*c217d954SCole Faust TensorShape{ 13U, 13U, 96U, 240U } 943*c217d954SCole Faust }) 944*c217d954SCole Faust { 945*c217d954SCole Faust } 946*c217d954SCole Faust }; 947*c217d954SCole Faust 948*c217d954SCole Faust /** Data set containing 2D tensor shapes for DepthConcatenateLayer. */ 949*c217d954SCole Faust class DepthConcatenateLayerShapes final : public ShapeDataset 950*c217d954SCole Faust { 951*c217d954SCole Faust public: DepthConcatenateLayerShapes()952*c217d954SCole Faust DepthConcatenateLayerShapes() 953*c217d954SCole Faust : ShapeDataset("Shape", 954*c217d954SCole Faust { 955*c217d954SCole Faust TensorShape{ 322U, 243U }, 956*c217d954SCole Faust TensorShape{ 463U, 879U }, 957*c217d954SCole Faust TensorShape{ 416U, 651U } 958*c217d954SCole Faust }) 959*c217d954SCole Faust { 960*c217d954SCole Faust } 961*c217d954SCole Faust }; 962*c217d954SCole Faust 963*c217d954SCole Faust /** Data set containing tensor shapes for ConcatenateLayer. */ 964*c217d954SCole Faust class ConcatenateLayerShapes final : public ShapeDataset 965*c217d954SCole Faust { 966*c217d954SCole Faust public: ConcatenateLayerShapes()967*c217d954SCole Faust ConcatenateLayerShapes() 968*c217d954SCole Faust : ShapeDataset("Shape", 969*c217d954SCole Faust { 970*c217d954SCole Faust TensorShape{ 232U, 65U, 3U }, 971*c217d954SCole Faust TensorShape{ 432U, 65U, 3U }, 972*c217d954SCole Faust TensorShape{ 124U, 65U, 3U }, 973*c217d954SCole Faust TensorShape{ 124U, 65U, 3U, 4U } 974*c217d954SCole Faust }) 975*c217d954SCole Faust { 976*c217d954SCole Faust } 977*c217d954SCole Faust }; 978*c217d954SCole Faust 979*c217d954SCole Faust /** Data set containing global pooling tensor shapes. */ 980*c217d954SCole Faust class GlobalPoolingShapes final : public ShapeDataset 981*c217d954SCole Faust { 982*c217d954SCole Faust public: GlobalPoolingShapes()983*c217d954SCole Faust GlobalPoolingShapes() 984*c217d954SCole Faust : ShapeDataset("Shape", 985*c217d954SCole Faust { 986*c217d954SCole Faust // Batch size 1 987*c217d954SCole Faust TensorShape{ 9U, 9U }, 988*c217d954SCole Faust TensorShape{ 13U, 13U, 2U }, 989*c217d954SCole Faust TensorShape{ 27U, 27U, 1U, 3U }, 990*c217d954SCole Faust // Batch size 4 991*c217d954SCole Faust TensorShape{ 31U, 31U, 3U, 4U }, 992*c217d954SCole Faust TensorShape{ 34U, 34U, 2U, 4U } 993*c217d954SCole Faust }) 994*c217d954SCole Faust { 995*c217d954SCole Faust } 996*c217d954SCole Faust }; 997*c217d954SCole Faust /** Data set containing tiny softmax layer shapes. */ 998*c217d954SCole Faust class SoftmaxLayerTinyShapes final : public ShapeDataset 999*c217d954SCole Faust { 1000*c217d954SCole Faust public: SoftmaxLayerTinyShapes()1001*c217d954SCole Faust SoftmaxLayerTinyShapes() 1002*c217d954SCole Faust : ShapeDataset("Shape", 1003*c217d954SCole Faust { 1004*c217d954SCole Faust TensorShape{ 9U, 9U }, 1005*c217d954SCole Faust TensorShape{ 128U, 10U }, 1006*c217d954SCole Faust }) 1007*c217d954SCole Faust { 1008*c217d954SCole Faust } 1009*c217d954SCole Faust }; 1010*c217d954SCole Faust 1011*c217d954SCole Faust /** Data set containing small softmax layer shapes. */ 1012*c217d954SCole Faust class SoftmaxLayerSmallShapes final : public ShapeDataset 1013*c217d954SCole Faust { 1014*c217d954SCole Faust public: SoftmaxLayerSmallShapes()1015*c217d954SCole Faust SoftmaxLayerSmallShapes() 1016*c217d954SCole Faust : ShapeDataset("Shape", 1017*c217d954SCole Faust { 1018*c217d954SCole Faust TensorShape{ 1U, 9U }, 1019*c217d954SCole Faust TensorShape{ 256U, 10U }, 1020*c217d954SCole Faust TensorShape{ 353U, 8U }, 1021*c217d954SCole Faust TensorShape{ 781U, 5U }, 1022*c217d954SCole Faust }) 1023*c217d954SCole Faust { 1024*c217d954SCole Faust } 1025*c217d954SCole Faust }; 1026*c217d954SCole Faust 1027*c217d954SCole Faust /** Data set containing large softmax layer shapes. */ 1028*c217d954SCole Faust class SoftmaxLayerLargeShapes final : public ShapeDataset 1029*c217d954SCole Faust { 1030*c217d954SCole Faust public: SoftmaxLayerLargeShapes()1031*c217d954SCole Faust SoftmaxLayerLargeShapes() 1032*c217d954SCole Faust : ShapeDataset("Shape", 1033*c217d954SCole Faust { 1034*c217d954SCole Faust TensorShape{ 1000U, 10U } 1035*c217d954SCole Faust 1036*c217d954SCole Faust }) 1037*c217d954SCole Faust { 1038*c217d954SCole Faust } 1039*c217d954SCole Faust }; 1040*c217d954SCole Faust 1041*c217d954SCole Faust /** Data set containing large and small softmax layer 4D shapes. */ 1042*c217d954SCole Faust class SoftmaxLayer4DShapes final : public ShapeDataset 1043*c217d954SCole Faust { 1044*c217d954SCole Faust public: SoftmaxLayer4DShapes()1045*c217d954SCole Faust SoftmaxLayer4DShapes() 1046*c217d954SCole Faust : ShapeDataset("Shape", 1047*c217d954SCole Faust { 1048*c217d954SCole Faust TensorShape{ 9U, 9U, 9U, 9U }, 1049*c217d954SCole Faust TensorShape{ 31U, 10U, 1U, 9U }, 1050*c217d954SCole Faust }) 1051*c217d954SCole Faust { 1052*c217d954SCole Faust } 1053*c217d954SCole Faust }; 1054*c217d954SCole Faust 1055*c217d954SCole Faust /** Data set containing 2D tensor shapes relative to an image size. */ 1056*c217d954SCole Faust class SmallImageShapes final : public ShapeDataset 1057*c217d954SCole Faust { 1058*c217d954SCole Faust public: SmallImageShapes()1059*c217d954SCole Faust SmallImageShapes() 1060*c217d954SCole Faust : ShapeDataset("Shape", 1061*c217d954SCole Faust { 1062*c217d954SCole Faust TensorShape{ 640U, 480U }, 1063*c217d954SCole Faust TensorShape{ 800U, 600U }, 1064*c217d954SCole Faust }) 1065*c217d954SCole Faust { 1066*c217d954SCole Faust } 1067*c217d954SCole Faust }; 1068*c217d954SCole Faust 1069*c217d954SCole Faust /** Data set containing 2D tensor shapes relative to an image size. */ 1070*c217d954SCole Faust class LargeImageShapes final : public ShapeDataset 1071*c217d954SCole Faust { 1072*c217d954SCole Faust public: LargeImageShapes()1073*c217d954SCole Faust LargeImageShapes() 1074*c217d954SCole Faust : ShapeDataset("Shape", 1075*c217d954SCole Faust { 1076*c217d954SCole Faust TensorShape{ 1920U, 1080U }, 1077*c217d954SCole Faust TensorShape{ 2560U, 1536U }, 1078*c217d954SCole Faust TensorShape{ 3584U, 2048U } 1079*c217d954SCole Faust }) 1080*c217d954SCole Faust { 1081*c217d954SCole Faust } 1082*c217d954SCole Faust }; 1083*c217d954SCole Faust 1084*c217d954SCole Faust /** Data set containing small YOLO tensor shapes. */ 1085*c217d954SCole Faust class SmallYOLOShapes final : public ShapeDataset 1086*c217d954SCole Faust { 1087*c217d954SCole Faust public: SmallYOLOShapes()1088*c217d954SCole Faust SmallYOLOShapes() 1089*c217d954SCole Faust : ShapeDataset("Shape", 1090*c217d954SCole Faust { 1091*c217d954SCole Faust // Batch size 1 1092*c217d954SCole Faust TensorShape{ 11U, 11U, 270U }, 1093*c217d954SCole Faust TensorShape{ 27U, 13U, 90U }, 1094*c217d954SCole Faust TensorShape{ 13U, 12U, 45U, 2U }, 1095*c217d954SCole Faust }) 1096*c217d954SCole Faust { 1097*c217d954SCole Faust } 1098*c217d954SCole Faust }; 1099*c217d954SCole Faust 1100*c217d954SCole Faust /** Data set containing large YOLO tensor shapes. */ 1101*c217d954SCole Faust class LargeYOLOShapes final : public ShapeDataset 1102*c217d954SCole Faust { 1103*c217d954SCole Faust public: LargeYOLOShapes()1104*c217d954SCole Faust LargeYOLOShapes() 1105*c217d954SCole Faust : ShapeDataset("Shape", 1106*c217d954SCole Faust { 1107*c217d954SCole Faust TensorShape{ 24U, 23U, 270U }, 1108*c217d954SCole Faust TensorShape{ 51U, 63U, 90U, 2U }, 1109*c217d954SCole Faust TensorShape{ 76U, 91U, 45U, 3U } 1110*c217d954SCole Faust }) 1111*c217d954SCole Faust { 1112*c217d954SCole Faust } 1113*c217d954SCole Faust }; 1114*c217d954SCole Faust 1115*c217d954SCole Faust /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel */ 1116*c217d954SCole Faust class SmallGEMMReshape2DShapes final : public ShapeDataset 1117*c217d954SCole Faust { 1118*c217d954SCole Faust public: SmallGEMMReshape2DShapes()1119*c217d954SCole Faust SmallGEMMReshape2DShapes() 1120*c217d954SCole Faust : ShapeDataset("Shape", 1121*c217d954SCole Faust { 1122*c217d954SCole Faust TensorShape{ 63U, 72U }, 1123*c217d954SCole Faust }) 1124*c217d954SCole Faust { 1125*c217d954SCole Faust } 1126*c217d954SCole Faust }; 1127*c217d954SCole Faust 1128*c217d954SCole Faust /** Data set containing small tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */ 1129*c217d954SCole Faust class SmallGEMMReshape3DShapes final : public ShapeDataset 1130*c217d954SCole Faust { 1131*c217d954SCole Faust public: SmallGEMMReshape3DShapes()1132*c217d954SCole Faust SmallGEMMReshape3DShapes() 1133*c217d954SCole Faust : ShapeDataset("Shape", 1134*c217d954SCole Faust { 1135*c217d954SCole Faust TensorShape{ 63U, 9U, 8U }, 1136*c217d954SCole Faust }) 1137*c217d954SCole Faust { 1138*c217d954SCole Faust } 1139*c217d954SCole Faust }; 1140*c217d954SCole Faust 1141*c217d954SCole Faust /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel */ 1142*c217d954SCole Faust class LargeGEMMReshape2DShapes final : public ShapeDataset 1143*c217d954SCole Faust { 1144*c217d954SCole Faust public: LargeGEMMReshape2DShapes()1145*c217d954SCole Faust LargeGEMMReshape2DShapes() 1146*c217d954SCole Faust : ShapeDataset("Shape", 1147*c217d954SCole Faust { 1148*c217d954SCole Faust TensorShape{ 16U, 27U }, 1149*c217d954SCole Faust TensorShape{ 345U, 171U } 1150*c217d954SCole Faust }) 1151*c217d954SCole Faust { 1152*c217d954SCole Faust } 1153*c217d954SCole Faust }; 1154*c217d954SCole Faust 1155*c217d954SCole Faust /** Data set containing large tensor shapes to be used with the GEMM reshaping kernel when the input has to be reinterpreted as 3D */ 1156*c217d954SCole Faust class LargeGEMMReshape3DShapes final : public ShapeDataset 1157*c217d954SCole Faust { 1158*c217d954SCole Faust public: LargeGEMMReshape3DShapes()1159*c217d954SCole Faust LargeGEMMReshape3DShapes() 1160*c217d954SCole Faust : ShapeDataset("Shape", 1161*c217d954SCole Faust { 1162*c217d954SCole Faust TensorShape{ 16U, 3U, 9U }, 1163*c217d954SCole Faust TensorShape{ 345U, 34U, 18U } 1164*c217d954SCole Faust }) 1165*c217d954SCole Faust { 1166*c217d954SCole Faust } 1167*c217d954SCole Faust }; 1168*c217d954SCole Faust 1169*c217d954SCole Faust /** Data set containing small 2D tensor shapes. */ 1170*c217d954SCole Faust class Small2DNonMaxSuppressionShapes final : public ShapeDataset 1171*c217d954SCole Faust { 1172*c217d954SCole Faust public: Small2DNonMaxSuppressionShapes()1173*c217d954SCole Faust Small2DNonMaxSuppressionShapes() 1174*c217d954SCole Faust : ShapeDataset("Shape", 1175*c217d954SCole Faust { 1176*c217d954SCole Faust TensorShape{ 4U, 7U }, 1177*c217d954SCole Faust TensorShape{ 4U, 13U }, 1178*c217d954SCole Faust TensorShape{ 4U, 64U } 1179*c217d954SCole Faust }) 1180*c217d954SCole Faust { 1181*c217d954SCole Faust } 1182*c217d954SCole Faust }; 1183*c217d954SCole Faust 1184*c217d954SCole Faust /** Data set containing large 2D tensor shapes. */ 1185*c217d954SCole Faust class Large2DNonMaxSuppressionShapes final : public ShapeDataset 1186*c217d954SCole Faust { 1187*c217d954SCole Faust public: Large2DNonMaxSuppressionShapes()1188*c217d954SCole Faust Large2DNonMaxSuppressionShapes() 1189*c217d954SCole Faust : ShapeDataset("Shape", 1190*c217d954SCole Faust { 1191*c217d954SCole Faust TensorShape{ 4U, 113U } 1192*c217d954SCole Faust }) 1193*c217d954SCole Faust { 1194*c217d954SCole Faust } 1195*c217d954SCole Faust }; 1196*c217d954SCole Faust 1197*c217d954SCole Faust } // namespace datasets 1198*c217d954SCole Faust } // namespace test 1199*c217d954SCole Faust } // namespace arm_compute 1200*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_SHAPE_DATASETS_H */ 1201