1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "RefLayerSupport.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker #include <LayerSupportCommon.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/LayerSupportRules.hpp>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker #include <vector>
18*89c4ff92SAndroid Build Coastguard Worker #include <array>
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker namespace armnn
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker namespace
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker template<typename Float32Func, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeRef(Optional<std::string &> reasonIfUnsupported,DataType dataType,Float32Func floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)27*89c4ff92SAndroid Build Coastguard Worker bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28*89c4ff92SAndroid Build Coastguard Worker DataType dataType,
29*89c4ff92SAndroid Build Coastguard Worker Float32Func floatFuncPtr,
30*89c4ff92SAndroid Build Coastguard Worker Uint8Func uint8FuncPtr,
31*89c4ff92SAndroid Build Coastguard Worker Params&&... params)
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34*89c4ff92SAndroid Build Coastguard Worker dataType,
35*89c4ff92SAndroid Build Coastguard Worker &FalseFunc<Params...>,
36*89c4ff92SAndroid Build Coastguard Worker floatFuncPtr,
37*89c4ff92SAndroid Build Coastguard Worker uint8FuncPtr,
38*89c4ff92SAndroid Build Coastguard Worker &FalseFunc<Params...>,
39*89c4ff92SAndroid Build Coastguard Worker &FalseFunc<Params...>,
40*89c4ff92SAndroid Build Coastguard Worker std::forward<Params>(params)...);
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
44*89c4ff92SAndroid Build Coastguard Worker
namespace
{

/// Builds the error message reported when a tensor has an unexpected number of dimensions.
///
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Layer name inserted after "Reference " in the message.
/// @param tensorName Name of the offending tensor; it appears in single quotes in the message.
/// @return A message of the form
///         "Reference <layer>: Expected <expected> dimensions but got <actual> dimensions instead,
///          for the '<tensorName>' tensor." (on a single line).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // The strings are only read, so they are taken by const reference. Existing callers that pass
    // std::string lvalues are unaffected, and literals/temporaries/const strings are now accepted
    // (the previous non-const references rejected them).
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
60*89c4ff92SAndroid Build Coastguard Worker
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmInputParamsInfo,Optional<std::string &> reasonIfUnsupported) const61*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorInfo>& infos,
63*89c4ff92SAndroid Build Coastguard Worker const BaseDescriptor& descriptor,
64*89c4ff92SAndroid Build Coastguard Worker const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65*89c4ff92SAndroid Build Coastguard Worker const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker switch (type)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker case LayerType::Activation:
71*89c4ff92SAndroid Build Coastguard Worker return IsActivationSupported(infos[0],
72*89c4ff92SAndroid Build Coastguard Worker infos[1],
73*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
75*89c4ff92SAndroid Build Coastguard Worker case LayerType::Addition:
76*89c4ff92SAndroid Build Coastguard Worker return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77*89c4ff92SAndroid Build Coastguard Worker case LayerType::ArgMinMax:
78*89c4ff92SAndroid Build Coastguard Worker return IsArgMinMaxSupported(infos[0],
79*89c4ff92SAndroid Build Coastguard Worker infos[1],
80*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
82*89c4ff92SAndroid Build Coastguard Worker case LayerType::BatchMatMul:
83*89c4ff92SAndroid Build Coastguard Worker return IsBatchMatMulSupported(infos[0],
84*89c4ff92SAndroid Build Coastguard Worker infos[1],
85*89c4ff92SAndroid Build Coastguard Worker infos[2],
86*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
88*89c4ff92SAndroid Build Coastguard Worker case LayerType::BatchNormalization:
89*89c4ff92SAndroid Build Coastguard Worker return IsBatchNormalizationSupported(infos[0],
90*89c4ff92SAndroid Build Coastguard Worker infos[1],
91*89c4ff92SAndroid Build Coastguard Worker infos[2],
92*89c4ff92SAndroid Build Coastguard Worker infos[3],
93*89c4ff92SAndroid Build Coastguard Worker infos[4],
94*89c4ff92SAndroid Build Coastguard Worker infos[5],
95*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96*89c4ff92SAndroid Build Coastguard Worker (&descriptor)),
97*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
98*89c4ff92SAndroid Build Coastguard Worker case LayerType::BatchToSpaceNd:
99*89c4ff92SAndroid Build Coastguard Worker return IsBatchToSpaceNdSupported(infos[0],
100*89c4ff92SAndroid Build Coastguard Worker infos[1],
101*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
103*89c4ff92SAndroid Build Coastguard Worker case LayerType::Comparison:
104*89c4ff92SAndroid Build Coastguard Worker return IsComparisonSupported(infos[0],
105*89c4ff92SAndroid Build Coastguard Worker infos[1],
106*89c4ff92SAndroid Build Coastguard Worker infos[2],
107*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
109*89c4ff92SAndroid Build Coastguard Worker case LayerType::Concat:
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker std::vector<const TensorInfo*> inputInfos;
112*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < (infos.size() - 1); i++)
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker inputInfos.push_back(&infos[i]);
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker return IsConcatSupported(inputInfos,
117*89c4ff92SAndroid Build Coastguard Worker infos[infos.size() - 1],
118*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker case LayerType::Constant:
122*89c4ff92SAndroid Build Coastguard Worker return IsConstantSupported(infos[0], reasonIfUnsupported);
123*89c4ff92SAndroid Build Coastguard Worker case LayerType::ConvertFp16ToFp32:
124*89c4ff92SAndroid Build Coastguard Worker return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
125*89c4ff92SAndroid Build Coastguard Worker case LayerType::ConvertFp32ToFp16:
126*89c4ff92SAndroid Build Coastguard Worker return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127*89c4ff92SAndroid Build Coastguard Worker case LayerType::Convolution2d:
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker return IsConvolution2dSupported(infos[0],
139*89c4ff92SAndroid Build Coastguard Worker infos[1],
140*89c4ff92SAndroid Build Coastguard Worker desc,
141*89c4ff92SAndroid Build Coastguard Worker infos[2],
142*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
143*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker else
146*89c4ff92SAndroid Build Coastguard Worker {
147*89c4ff92SAndroid Build Coastguard Worker return IsConvolution2dSupported(infos[0],
148*89c4ff92SAndroid Build Coastguard Worker infos[1],
149*89c4ff92SAndroid Build Coastguard Worker desc,
150*89c4ff92SAndroid Build Coastguard Worker infos[2],
151*89c4ff92SAndroid Build Coastguard Worker infos[3],
152*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker }
155*89c4ff92SAndroid Build Coastguard Worker case LayerType::DepthToSpace:
156*89c4ff92SAndroid Build Coastguard Worker return IsDepthToSpaceSupported(infos[0],
157*89c4ff92SAndroid Build Coastguard Worker infos[1],
158*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
160*89c4ff92SAndroid Build Coastguard Worker case LayerType::DepthwiseConvolution2d:
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker
168*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker return IsDepthwiseConvolutionSupported(infos[0],
172*89c4ff92SAndroid Build Coastguard Worker infos[1],
173*89c4ff92SAndroid Build Coastguard Worker desc,
174*89c4ff92SAndroid Build Coastguard Worker infos[2],
175*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
176*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker else
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker return IsDepthwiseConvolutionSupported(infos[0],
181*89c4ff92SAndroid Build Coastguard Worker infos[1],
182*89c4ff92SAndroid Build Coastguard Worker desc,
183*89c4ff92SAndroid Build Coastguard Worker infos[2],
184*89c4ff92SAndroid Build Coastguard Worker infos[3],
185*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker }
188*89c4ff92SAndroid Build Coastguard Worker case LayerType::Dequantize:
189*89c4ff92SAndroid Build Coastguard Worker return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190*89c4ff92SAndroid Build Coastguard Worker case LayerType::Division:
191*89c4ff92SAndroid Build Coastguard Worker return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
192*89c4ff92SAndroid Build Coastguard Worker case LayerType::ElementwiseBinary:
193*89c4ff92SAndroid Build Coastguard Worker {
194*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 7> supportedTypes =
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
197*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
198*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
199*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
200*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
201*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
202*89c4ff92SAndroid Build Coastguard Worker };
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
205*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input type not supported");
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input type not supported");
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: output type not supported");
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input types not matching");
216*89c4ff92SAndroid Build Coastguard Worker
217*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input and output types not matching");
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker return supported;
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker case LayerType::ElementwiseUnary:
223*89c4ff92SAndroid Build Coastguard Worker return IsElementwiseUnarySupported(infos[0],
224*89c4ff92SAndroid Build Coastguard Worker infos[1],
225*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
227*89c4ff92SAndroid Build Coastguard Worker case LayerType::Fill:
228*89c4ff92SAndroid Build Coastguard Worker return IsFillSupported(infos[0],
229*89c4ff92SAndroid Build Coastguard Worker infos[1],
230*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
232*89c4ff92SAndroid Build Coastguard Worker case LayerType::Floor:
233*89c4ff92SAndroid Build Coastguard Worker return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234*89c4ff92SAndroid Build Coastguard Worker case LayerType::FullyConnected:
235*89c4ff92SAndroid Build Coastguard Worker return IsFullyConnectedSupported(infos[0],
236*89c4ff92SAndroid Build Coastguard Worker infos[1],
237*89c4ff92SAndroid Build Coastguard Worker infos[2],
238*89c4ff92SAndroid Build Coastguard Worker infos[3],
239*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
241*89c4ff92SAndroid Build Coastguard Worker case LayerType::Gather:
242*89c4ff92SAndroid Build Coastguard Worker return IsGatherSupported(infos[0],
243*89c4ff92SAndroid Build Coastguard Worker infos[1],
244*89c4ff92SAndroid Build Coastguard Worker infos[2],
245*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
247*89c4ff92SAndroid Build Coastguard Worker case LayerType::GatherNd:
248*89c4ff92SAndroid Build Coastguard Worker return IsGatherNdSupported(infos[0],
249*89c4ff92SAndroid Build Coastguard Worker infos[1],
250*89c4ff92SAndroid Build Coastguard Worker infos[2],
251*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
252*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
253*89c4ff92SAndroid Build Coastguard Worker return IsInputSupported(infos[0], reasonIfUnsupported);
254*89c4ff92SAndroid Build Coastguard Worker case LayerType::InstanceNormalization:
255*89c4ff92SAndroid Build Coastguard Worker return IsInstanceNormalizationSupported(infos[0],
256*89c4ff92SAndroid Build Coastguard Worker infos[1],
257*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258*89c4ff92SAndroid Build Coastguard Worker (&descriptor)),
259*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
260*89c4ff92SAndroid Build Coastguard Worker case LayerType::L2Normalization:
261*89c4ff92SAndroid Build Coastguard Worker return IsL2NormalizationSupported(infos[0],
262*89c4ff92SAndroid Build Coastguard Worker infos[1],
263*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
265*89c4ff92SAndroid Build Coastguard Worker case LayerType::LogicalBinary:
266*89c4ff92SAndroid Build Coastguard Worker return IsLogicalBinarySupported(infos[0],
267*89c4ff92SAndroid Build Coastguard Worker infos[1],
268*89c4ff92SAndroid Build Coastguard Worker infos[2],
269*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
271*89c4ff92SAndroid Build Coastguard Worker case LayerType::LogSoftmax:
272*89c4ff92SAndroid Build Coastguard Worker return IsLogSoftmaxSupported(infos[0],
273*89c4ff92SAndroid Build Coastguard Worker infos[1],
274*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
276*89c4ff92SAndroid Build Coastguard Worker case LayerType::Lstm:
277*89c4ff92SAndroid Build Coastguard Worker return IsLstmSupported(infos[0],
278*89c4ff92SAndroid Build Coastguard Worker infos[1],
279*89c4ff92SAndroid Build Coastguard Worker infos[2],
280*89c4ff92SAndroid Build Coastguard Worker infos[3],
281*89c4ff92SAndroid Build Coastguard Worker infos[4],
282*89c4ff92SAndroid Build Coastguard Worker infos[5],
283*89c4ff92SAndroid Build Coastguard Worker infos[6],
284*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285*89c4ff92SAndroid Build Coastguard Worker lstmParamsInfo.value(),
286*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
287*89c4ff92SAndroid Build Coastguard Worker case LayerType::QLstm:
288*89c4ff92SAndroid Build Coastguard Worker return IsQLstmSupported(infos[0],
289*89c4ff92SAndroid Build Coastguard Worker infos[1],
290*89c4ff92SAndroid Build Coastguard Worker infos[2],
291*89c4ff92SAndroid Build Coastguard Worker infos[3],
292*89c4ff92SAndroid Build Coastguard Worker infos[4],
293*89c4ff92SAndroid Build Coastguard Worker infos[5],
294*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295*89c4ff92SAndroid Build Coastguard Worker lstmParamsInfo.value(),
296*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
297*89c4ff92SAndroid Build Coastguard Worker case LayerType::Maximum:
298*89c4ff92SAndroid Build Coastguard Worker return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299*89c4ff92SAndroid Build Coastguard Worker case LayerType::Mean:
300*89c4ff92SAndroid Build Coastguard Worker return IsMeanSupported(infos[0],
301*89c4ff92SAndroid Build Coastguard Worker infos[1],
302*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
304*89c4ff92SAndroid Build Coastguard Worker case LayerType::Minimum:
305*89c4ff92SAndroid Build Coastguard Worker return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306*89c4ff92SAndroid Build Coastguard Worker case LayerType::Multiplication:
307*89c4ff92SAndroid Build Coastguard Worker return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308*89c4ff92SAndroid Build Coastguard Worker case LayerType::Normalization:
309*89c4ff92SAndroid Build Coastguard Worker return IsNormalizationSupported(infos[0],
310*89c4ff92SAndroid Build Coastguard Worker infos[1],
311*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
313*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output:
314*89c4ff92SAndroid Build Coastguard Worker return IsOutputSupported(infos[0], reasonIfUnsupported);
315*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pad:
316*89c4ff92SAndroid Build Coastguard Worker return IsPadSupported(infos[0],
317*89c4ff92SAndroid Build Coastguard Worker infos[1],
318*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
320*89c4ff92SAndroid Build Coastguard Worker case LayerType::Permute:
321*89c4ff92SAndroid Build Coastguard Worker return IsPermuteSupported(infos[0],
322*89c4ff92SAndroid Build Coastguard Worker infos[1],
323*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
325*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pooling2d:
326*89c4ff92SAndroid Build Coastguard Worker return IsPooling2dSupported(infos[0],
327*89c4ff92SAndroid Build Coastguard Worker infos[1],
328*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
330*89c4ff92SAndroid Build Coastguard Worker case LayerType::Prelu:
331*89c4ff92SAndroid Build Coastguard Worker return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332*89c4ff92SAndroid Build Coastguard Worker case LayerType::Quantize:
333*89c4ff92SAndroid Build Coastguard Worker return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334*89c4ff92SAndroid Build Coastguard Worker case LayerType::Reshape:
335*89c4ff92SAndroid Build Coastguard Worker return IsReshapeSupported(infos[0],
336*89c4ff92SAndroid Build Coastguard Worker infos[1],
337*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
339*89c4ff92SAndroid Build Coastguard Worker case LayerType::Resize:
340*89c4ff92SAndroid Build Coastguard Worker return IsResizeSupported(infos[0],
341*89c4ff92SAndroid Build Coastguard Worker infos[1],
342*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
344*89c4ff92SAndroid Build Coastguard Worker case LayerType::Reduce:
345*89c4ff92SAndroid Build Coastguard Worker return IsReduceSupported(infos[0],
346*89c4ff92SAndroid Build Coastguard Worker infos[1],
347*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
348*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
349*89c4ff92SAndroid Build Coastguard Worker case LayerType::Slice:
350*89c4ff92SAndroid Build Coastguard Worker return IsSliceSupported(infos[0],
351*89c4ff92SAndroid Build Coastguard Worker infos[1],
352*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
353*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
354*89c4ff92SAndroid Build Coastguard Worker case LayerType::Softmax:
355*89c4ff92SAndroid Build Coastguard Worker return IsSoftmaxSupported(infos[0],
356*89c4ff92SAndroid Build Coastguard Worker infos[1],
357*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
358*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
359*89c4ff92SAndroid Build Coastguard Worker case LayerType::SpaceToBatchNd:
360*89c4ff92SAndroid Build Coastguard Worker return IsSpaceToBatchNdSupported(infos[0],
361*89c4ff92SAndroid Build Coastguard Worker infos[1],
362*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
363*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
364*89c4ff92SAndroid Build Coastguard Worker case LayerType::SpaceToDepth:
365*89c4ff92SAndroid Build Coastguard Worker return IsSpaceToDepthSupported(infos[0],
366*89c4ff92SAndroid Build Coastguard Worker infos[1],
367*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
368*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
369*89c4ff92SAndroid Build Coastguard Worker case LayerType::Splitter:
370*89c4ff92SAndroid Build Coastguard Worker {
371*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorInfo> outputInfos;
372*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 1; i < infos.size(); i++)
373*89c4ff92SAndroid Build Coastguard Worker {
374*89c4ff92SAndroid Build Coastguard Worker outputInfos.push_back(infos[i]);
375*89c4ff92SAndroid Build Coastguard Worker }
376*89c4ff92SAndroid Build Coastguard Worker return IsSplitterSupported(infos[0],
377*89c4ff92SAndroid Build Coastguard Worker {outputInfos.begin(), outputInfos.end()},
378*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
379*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
380*89c4ff92SAndroid Build Coastguard Worker }
381*89c4ff92SAndroid Build Coastguard Worker case LayerType::Stack:
382*89c4ff92SAndroid Build Coastguard Worker {
383*89c4ff92SAndroid Build Coastguard Worker std::vector<const TensorInfo*> inputInfos;
384*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < infos.size() - 1; i++)
385*89c4ff92SAndroid Build Coastguard Worker {
386*89c4ff92SAndroid Build Coastguard Worker inputInfos.push_back(&infos[i]);
387*89c4ff92SAndroid Build Coastguard Worker }
388*89c4ff92SAndroid Build Coastguard Worker return IsStackSupported(inputInfos,
389*89c4ff92SAndroid Build Coastguard Worker infos[infos.size() - 1],
390*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
391*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
392*89c4ff92SAndroid Build Coastguard Worker }
393*89c4ff92SAndroid Build Coastguard Worker case LayerType::StridedSlice:
394*89c4ff92SAndroid Build Coastguard Worker return IsStridedSliceSupported(infos[0],
395*89c4ff92SAndroid Build Coastguard Worker infos[1],
396*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
397*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
398*89c4ff92SAndroid Build Coastguard Worker case LayerType::Subtraction:
399*89c4ff92SAndroid Build Coastguard Worker return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
400*89c4ff92SAndroid Build Coastguard Worker case LayerType::Transpose:
401*89c4ff92SAndroid Build Coastguard Worker return IsTransposeSupported(infos[0],
402*89c4ff92SAndroid Build Coastguard Worker infos[1],
403*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
404*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
405*89c4ff92SAndroid Build Coastguard Worker case LayerType::TransposeConvolution2d:
406*89c4ff92SAndroid Build Coastguard Worker {
407*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
408*89c4ff92SAndroid Build Coastguard Worker {
409*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
410*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
411*89c4ff92SAndroid Build Coastguard Worker }
412*89c4ff92SAndroid Build Coastguard Worker
413*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
414*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
415*89c4ff92SAndroid Build Coastguard Worker {
416*89c4ff92SAndroid Build Coastguard Worker return IsTransposeConvolution2dSupported(infos[0],
417*89c4ff92SAndroid Build Coastguard Worker infos[1],
418*89c4ff92SAndroid Build Coastguard Worker desc,
419*89c4ff92SAndroid Build Coastguard Worker infos[2],
420*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
421*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
422*89c4ff92SAndroid Build Coastguard Worker }
423*89c4ff92SAndroid Build Coastguard Worker else
424*89c4ff92SAndroid Build Coastguard Worker {
425*89c4ff92SAndroid Build Coastguard Worker return IsTransposeConvolution2dSupported(infos[0],
426*89c4ff92SAndroid Build Coastguard Worker infos[1],
427*89c4ff92SAndroid Build Coastguard Worker desc,
428*89c4ff92SAndroid Build Coastguard Worker infos[2],
429*89c4ff92SAndroid Build Coastguard Worker infos[3],
430*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
431*89c4ff92SAndroid Build Coastguard Worker }
432*89c4ff92SAndroid Build Coastguard Worker }
433*89c4ff92SAndroid Build Coastguard Worker case LayerType::Cast:
434*89c4ff92SAndroid Build Coastguard Worker return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
435*89c4ff92SAndroid Build Coastguard Worker case LayerType::ChannelShuffle:
436*89c4ff92SAndroid Build Coastguard Worker return IsChannelShuffleSupported(infos[0],
437*89c4ff92SAndroid Build Coastguard Worker infos[1],
438*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
439*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
440*89c4ff92SAndroid Build Coastguard Worker case LayerType::Convolution3d:
441*89c4ff92SAndroid Build Coastguard Worker {
442*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
443*89c4ff92SAndroid Build Coastguard Worker {
444*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
445*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
446*89c4ff92SAndroid Build Coastguard Worker }
447*89c4ff92SAndroid Build Coastguard Worker
448*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
449*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
450*89c4ff92SAndroid Build Coastguard Worker {
451*89c4ff92SAndroid Build Coastguard Worker return IsConvolution3dSupported(infos[0],
452*89c4ff92SAndroid Build Coastguard Worker infos[1],
453*89c4ff92SAndroid Build Coastguard Worker desc,
454*89c4ff92SAndroid Build Coastguard Worker infos[2],
455*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
456*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
457*89c4ff92SAndroid Build Coastguard Worker }
458*89c4ff92SAndroid Build Coastguard Worker else
459*89c4ff92SAndroid Build Coastguard Worker {
460*89c4ff92SAndroid Build Coastguard Worker return IsConvolution3dSupported(infos[0],
461*89c4ff92SAndroid Build Coastguard Worker infos[1],
462*89c4ff92SAndroid Build Coastguard Worker desc,
463*89c4ff92SAndroid Build Coastguard Worker infos[2],
464*89c4ff92SAndroid Build Coastguard Worker infos[3],
465*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
466*89c4ff92SAndroid Build Coastguard Worker }
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker case LayerType::Debug:
469*89c4ff92SAndroid Build Coastguard Worker return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
470*89c4ff92SAndroid Build Coastguard Worker case LayerType::DetectionPostProcess:
471*89c4ff92SAndroid Build Coastguard Worker return IsDetectionPostProcessSupported(infos[0],
472*89c4ff92SAndroid Build Coastguard Worker infos[1],
473*89c4ff92SAndroid Build Coastguard Worker infos[2],
474*89c4ff92SAndroid Build Coastguard Worker infos[3],
475*89c4ff92SAndroid Build Coastguard Worker infos[4],
476*89c4ff92SAndroid Build Coastguard Worker infos[5],
477*89c4ff92SAndroid Build Coastguard Worker infos[6],
478*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
479*89c4ff92SAndroid Build Coastguard Worker (&descriptor)),
480*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
481*89c4ff92SAndroid Build Coastguard Worker case LayerType::FakeQuantization:
482*89c4ff92SAndroid Build Coastguard Worker return IsFakeQuantizationSupported(infos[0],
483*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
484*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
485*89c4ff92SAndroid Build Coastguard Worker case LayerType::MemCopy:
486*89c4ff92SAndroid Build Coastguard Worker return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
487*89c4ff92SAndroid Build Coastguard Worker case LayerType::Rank:
488*89c4ff92SAndroid Build Coastguard Worker return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
489*89c4ff92SAndroid Build Coastguard Worker case LayerType::Shape:
490*89c4ff92SAndroid Build Coastguard Worker return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
491*89c4ff92SAndroid Build Coastguard Worker case LayerType::UnidirectionalSequenceLstm:
492*89c4ff92SAndroid Build Coastguard Worker {
493*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 6)
494*89c4ff92SAndroid Build Coastguard Worker {
495*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
496*89c4ff92SAndroid Build Coastguard Worker "should be of format: {input, outputStateIn, cellStateIn, "
497*89c4ff92SAndroid Build Coastguard Worker "hiddenStateOutputVal, cellStateOutputVal, output}");
498*89c4ff92SAndroid Build Coastguard Worker }
499*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
500*89c4ff92SAndroid Build Coastguard Worker return IsUnidirectionalSequenceLstmSupported(infos[0],
501*89c4ff92SAndroid Build Coastguard Worker infos[1],
502*89c4ff92SAndroid Build Coastguard Worker infos[2],
503*89c4ff92SAndroid Build Coastguard Worker infos[3],
504*89c4ff92SAndroid Build Coastguard Worker infos[4],
505*89c4ff92SAndroid Build Coastguard Worker infos[5],
506*89c4ff92SAndroid Build Coastguard Worker desc,
507*89c4ff92SAndroid Build Coastguard Worker lstmParamsInfo.value(),
508*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
509*89c4ff92SAndroid Build Coastguard Worker }
510*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pooling3d:
511*89c4ff92SAndroid Build Coastguard Worker return IsPooling3dSupported(infos[0],
512*89c4ff92SAndroid Build Coastguard Worker infos[1],
513*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
514*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
515*89c4ff92SAndroid Build Coastguard Worker case LayerType::Map:
516*89c4ff92SAndroid Build Coastguard Worker return true;
517*89c4ff92SAndroid Build Coastguard Worker case LayerType::Unmap:
518*89c4ff92SAndroid Build Coastguard Worker return true;
519*89c4ff92SAndroid Build Coastguard Worker case LayerType::MemImport:
520*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
521*89c4ff92SAndroid Build Coastguard Worker case LayerType::Merge:
522*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
523*89c4ff92SAndroid Build Coastguard Worker case LayerType::QuantizedLstm:
524*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
525*89c4ff92SAndroid Build Coastguard Worker infos[1],
526*89c4ff92SAndroid Build Coastguard Worker infos[2],
527*89c4ff92SAndroid Build Coastguard Worker infos[3],
528*89c4ff92SAndroid Build Coastguard Worker infos[4],
529*89c4ff92SAndroid Build Coastguard Worker quantizedLstmInputParamsInfo.value(),
530*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
531*89c4ff92SAndroid Build Coastguard Worker default:
532*89c4ff92SAndroid Build Coastguard Worker // layers not supported in neon by default:
533*89c4ff92SAndroid Build Coastguard Worker // precompiled, standin, switch
534*89c4ff92SAndroid Build Coastguard Worker return false;
535*89c4ff92SAndroid Build Coastguard Worker }
536*89c4ff92SAndroid Build Coastguard Worker }
537*89c4ff92SAndroid Build Coastguard Worker
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const538*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
539*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
540*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor& descriptor,
541*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
542*89c4ff92SAndroid Build Coastguard Worker {
543*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
544*89c4ff92SAndroid Build Coastguard Worker
545*89c4ff92SAndroid Build Coastguard Worker // Define supported types.
546*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes = {
547*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
548*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
549*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
550*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
551*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
552*89c4ff92SAndroid Build Coastguard Worker };
553*89c4ff92SAndroid Build Coastguard Worker
554*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
555*89c4ff92SAndroid Build Coastguard Worker "Reference activation: input type not supported.");
556*89c4ff92SAndroid Build Coastguard Worker
557*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
558*89c4ff92SAndroid Build Coastguard Worker "Reference activation: output type not supported.");
559*89c4ff92SAndroid Build Coastguard Worker
560*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
561*89c4ff92SAndroid Build Coastguard Worker "Reference activation: input and output types mismatched.");
562*89c4ff92SAndroid Build Coastguard Worker
563*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
564*89c4ff92SAndroid Build Coastguard Worker "Reference activation: input and output shapes are of different rank.");
565*89c4ff92SAndroid Build Coastguard Worker
566*89c4ff92SAndroid Build Coastguard Worker
567*89c4ff92SAndroid Build Coastguard Worker struct ActivationFunctionSupported : public Rule
568*89c4ff92SAndroid Build Coastguard Worker {
569*89c4ff92SAndroid Build Coastguard Worker ActivationFunctionSupported(const ActivationDescriptor& desc)
570*89c4ff92SAndroid Build Coastguard Worker {
571*89c4ff92SAndroid Build Coastguard Worker switch(desc.m_Function)
572*89c4ff92SAndroid Build Coastguard Worker {
573*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Abs:
574*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::BoundedReLu:
575*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Elu:
576*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::HardSwish:
577*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::LeakyReLu:
578*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Linear:
579*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::ReLu:
580*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Sigmoid:
581*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::SoftReLu:
582*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Sqrt:
583*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Square:
584*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::TanH:
585*89c4ff92SAndroid Build Coastguard Worker {
586*89c4ff92SAndroid Build Coastguard Worker m_Res = true;
587*89c4ff92SAndroid Build Coastguard Worker break;
588*89c4ff92SAndroid Build Coastguard Worker }
589*89c4ff92SAndroid Build Coastguard Worker default:
590*89c4ff92SAndroid Build Coastguard Worker {
591*89c4ff92SAndroid Build Coastguard Worker m_Res = false;
592*89c4ff92SAndroid Build Coastguard Worker break;
593*89c4ff92SAndroid Build Coastguard Worker }
594*89c4ff92SAndroid Build Coastguard Worker }
595*89c4ff92SAndroid Build Coastguard Worker }
596*89c4ff92SAndroid Build Coastguard Worker };
597*89c4ff92SAndroid Build Coastguard Worker
598*89c4ff92SAndroid Build Coastguard Worker // Function is supported
599*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
600*89c4ff92SAndroid Build Coastguard Worker "Reference activation: function not supported.");
601*89c4ff92SAndroid Build Coastguard Worker
602*89c4ff92SAndroid Build Coastguard Worker return supported;
603*89c4ff92SAndroid Build Coastguard Worker }
604*89c4ff92SAndroid Build Coastguard Worker
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const605*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
606*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
607*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
608*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
609*89c4ff92SAndroid Build Coastguard Worker {
610*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
611*89c4ff92SAndroid Build Coastguard Worker
612*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes = {
613*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
614*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
615*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
616*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
617*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
618*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
619*89c4ff92SAndroid Build Coastguard Worker };
620*89c4ff92SAndroid Build Coastguard Worker
621*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
622*89c4ff92SAndroid Build Coastguard Worker "Reference addition: input 0 is not a supported type.");
623*89c4ff92SAndroid Build Coastguard Worker
624*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
625*89c4ff92SAndroid Build Coastguard Worker "Reference addition: input 1 is not a supported type.");
626*89c4ff92SAndroid Build Coastguard Worker
627*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
628*89c4ff92SAndroid Build Coastguard Worker "Reference addition: output is not a supported type.");
629*89c4ff92SAndroid Build Coastguard Worker
630*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
631*89c4ff92SAndroid Build Coastguard Worker "Reference addition: input 0 and Input 1 types are mismatched");
632*89c4ff92SAndroid Build Coastguard Worker
633*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
634*89c4ff92SAndroid Build Coastguard Worker "Reference addition: input and output types are mismatched");
635*89c4ff92SAndroid Build Coastguard Worker
636*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
637*89c4ff92SAndroid Build Coastguard Worker "Reference addition: shapes are not suitable for implicit broadcast.");
638*89c4ff92SAndroid Build Coastguard Worker
639*89c4ff92SAndroid Build Coastguard Worker return supported;
640*89c4ff92SAndroid Build Coastguard Worker }
641*89c4ff92SAndroid Build Coastguard Worker
IsArgMinMaxSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & output,const armnn::ArgMinMaxDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const642*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
643*89c4ff92SAndroid Build Coastguard Worker const armnn::ArgMinMaxDescriptor &descriptor,
644*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string &> reasonIfUnsupported) const
645*89c4ff92SAndroid Build Coastguard Worker {
646*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
647*89c4ff92SAndroid Build Coastguard Worker
648*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 8> supportedInputTypes =
649*89c4ff92SAndroid Build Coastguard Worker {
650*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
651*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
652*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
653*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
654*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
655*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
656*89c4ff92SAndroid Build Coastguard Worker DataType::Signed64
657*89c4ff92SAndroid Build Coastguard Worker };
658*89c4ff92SAndroid Build Coastguard Worker
659*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,2> supportedOutputTypes = {
660*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
661*89c4ff92SAndroid Build Coastguard Worker DataType::Signed64
662*89c4ff92SAndroid Build Coastguard Worker };
663*89c4ff92SAndroid Build Coastguard Worker
664*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
665*89c4ff92SAndroid Build Coastguard Worker
666*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
667*89c4ff92SAndroid Build Coastguard Worker "Reference ArgMinMax: input is not a supported type.");
668*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
669*89c4ff92SAndroid Build Coastguard Worker "Reference ArgMinMax: output type not supported");
670*89c4ff92SAndroid Build Coastguard Worker
671*89c4ff92SAndroid Build Coastguard Worker return supported;
672*89c4ff92SAndroid Build Coastguard Worker }
673*89c4ff92SAndroid Build Coastguard Worker
IsBatchMatMulSupported(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const674*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
675*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputY,
676*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
677*89c4ff92SAndroid Build Coastguard Worker const BatchMatMulDescriptor& descriptor,
678*89c4ff92SAndroid Build Coastguard Worker Optional<std::string &> reasonIfUnsupported) const
679*89c4ff92SAndroid Build Coastguard Worker {
680*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
681*89c4ff92SAndroid Build Coastguard Worker
682*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes =
683*89c4ff92SAndroid Build Coastguard Worker {
684*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
685*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
686*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
687*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
688*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
689*89c4ff92SAndroid Build Coastguard Worker };
690*89c4ff92SAndroid Build Coastguard Worker
691*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
692*89c4ff92SAndroid Build Coastguard Worker
693*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
694*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: input X is not a supported type");
695*89c4ff92SAndroid Build Coastguard Worker
696*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
697*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: input Y is not a supported type");
698*89c4ff92SAndroid Build Coastguard Worker
699*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
700*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: output is not a supported type");
701*89c4ff92SAndroid Build Coastguard Worker
702*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
703*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: input X and input Y types are mismatched");
704*89c4ff92SAndroid Build Coastguard Worker
705*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
706*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: inputs and output types are mismatched");
707*89c4ff92SAndroid Build Coastguard Worker
708*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
709*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
710*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: input X is not of rank 2 or greater");
711*89c4ff92SAndroid Build Coastguard Worker
712*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
713*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
714*89c4ff92SAndroid Build Coastguard Worker "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
715*89c4ff92SAndroid Build Coastguard Worker
716*89c4ff92SAndroid Build Coastguard Worker return supported;
717*89c4ff92SAndroid Build Coastguard Worker }
718*89c4ff92SAndroid Build Coastguard Worker
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & variance,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const719*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
720*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
721*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& mean,
722*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& variance,
723*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& beta,
724*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& gamma,
725*89c4ff92SAndroid Build Coastguard Worker const BatchNormalizationDescriptor& descriptor,
726*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
727*89c4ff92SAndroid Build Coastguard Worker {
728*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
729*89c4ff92SAndroid Build Coastguard Worker
730*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes =
731*89c4ff92SAndroid Build Coastguard Worker {
732*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
733*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
734*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
735*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
736*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
737*89c4ff92SAndroid Build Coastguard Worker };
738*89c4ff92SAndroid Build Coastguard Worker
739*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
740*89c4ff92SAndroid Build Coastguard Worker
741*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: input is not a supported type.");
743*89c4ff92SAndroid Build Coastguard Worker
744*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
745*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: output is not a supported type.");
746*89c4ff92SAndroid Build Coastguard Worker
747*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
748*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: input and output types are mismatched");
749*89c4ff92SAndroid Build Coastguard Worker
750*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
751*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: mean is not a supported type.");
752*89c4ff92SAndroid Build Coastguard Worker
753*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
754*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: variance is not a supported type.");
755*89c4ff92SAndroid Build Coastguard Worker
756*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
757*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: beta is not a supported type.");
758*89c4ff92SAndroid Build Coastguard Worker
759*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
760*89c4ff92SAndroid Build Coastguard Worker "Reference batch normalization: gamma is not a supported type.");
761*89c4ff92SAndroid Build Coastguard Worker
762*89c4ff92SAndroid Build Coastguard Worker return supported;
763*89c4ff92SAndroid Build Coastguard Worker }
764*89c4ff92SAndroid Build Coastguard Worker
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const765*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
766*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
767*89c4ff92SAndroid Build Coastguard Worker const BatchToSpaceNdDescriptor& descriptor,
768*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
769*89c4ff92SAndroid Build Coastguard Worker {
770*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
771*89c4ff92SAndroid Build Coastguard Worker
772*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
773*89c4ff92SAndroid Build Coastguard Worker
774*89c4ff92SAndroid Build Coastguard Worker std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
775*89c4ff92SAndroid Build Coastguard Worker std::string inputTensorStr = "input";
776*89c4ff92SAndroid Build Coastguard Worker std::string outputTensorStr = "output";
777*89c4ff92SAndroid Build Coastguard Worker
778*89c4ff92SAndroid Build Coastguard Worker // Define supported types.
779*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
780*89c4ff92SAndroid Build Coastguard Worker {
781*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
782*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
783*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
784*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
785*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
786*89c4ff92SAndroid Build Coastguard Worker };
787*89c4ff92SAndroid Build Coastguard Worker
788*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789*89c4ff92SAndroid Build Coastguard Worker "Reference BatchToSpaceNd: input type not supported.");
790*89c4ff92SAndroid Build Coastguard Worker
791*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792*89c4ff92SAndroid Build Coastguard Worker "Reference BatchToSpaceNd: output type not supported.");
793*89c4ff92SAndroid Build Coastguard Worker
794*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795*89c4ff92SAndroid Build Coastguard Worker "Reference BatchToSpaceNd: input and output types mismatched.");
796*89c4ff92SAndroid Build Coastguard Worker
797*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
798*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
799*89c4ff92SAndroid Build Coastguard Worker CreateIncorrectDimensionsErrorMsg(4,
800*89c4ff92SAndroid Build Coastguard Worker output.GetNumDimensions(),
801*89c4ff92SAndroid Build Coastguard Worker batchToSpaceNdLayerStr,
802*89c4ff92SAndroid Build Coastguard Worker outputTensorStr).data());
803*89c4ff92SAndroid Build Coastguard Worker
804*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
805*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
806*89c4ff92SAndroid Build Coastguard Worker CreateIncorrectDimensionsErrorMsg(4,
807*89c4ff92SAndroid Build Coastguard Worker input.GetNumDimensions(),
808*89c4ff92SAndroid Build Coastguard Worker batchToSpaceNdLayerStr,
809*89c4ff92SAndroid Build Coastguard Worker inputTensorStr).data());
810*89c4ff92SAndroid Build Coastguard Worker
811*89c4ff92SAndroid Build Coastguard Worker return supported;
812*89c4ff92SAndroid Build Coastguard Worker }
813*89c4ff92SAndroid Build Coastguard Worker
IsCastSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const814*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
815*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
816*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
817*89c4ff92SAndroid Build Coastguard Worker {
818*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 9> supportedInputTypes =
819*89c4ff92SAndroid Build Coastguard Worker {
820*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
821*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
822*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
823*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
824*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
825*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
826*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
827*89c4ff92SAndroid Build Coastguard Worker };
828*89c4ff92SAndroid Build Coastguard Worker
829*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
830*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
831*89c4ff92SAndroid Build Coastguard Worker "Reference cast: input is not a supported type");
832*89c4ff92SAndroid Build Coastguard Worker
833*89c4ff92SAndroid Build Coastguard Worker
834*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
835*89c4ff92SAndroid Build Coastguard Worker "Reference cast: output is not a supported type");
836*89c4ff92SAndroid Build Coastguard Worker
837*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838*89c4ff92SAndroid Build Coastguard Worker "Reference cast: input and output shapes have different number of total elements");
839*89c4ff92SAndroid Build Coastguard Worker
840*89c4ff92SAndroid Build Coastguard Worker return supported;
841*89c4ff92SAndroid Build Coastguard Worker }
842*89c4ff92SAndroid Build Coastguard Worker
IsChannelShuffleSupported(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const843*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
844*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
845*89c4ff92SAndroid Build Coastguard Worker const ChannelShuffleDescriptor& descriptor,
846*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
847*89c4ff92SAndroid Build Coastguard Worker {
848*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
849*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
850*89c4ff92SAndroid Build Coastguard Worker
851*89c4ff92SAndroid Build Coastguard Worker // Define supported output and inputs types.
852*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 7> supportedTypes =
853*89c4ff92SAndroid Build Coastguard Worker {
854*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
855*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
856*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
857*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
858*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
859*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
860*89c4ff92SAndroid Build Coastguard Worker };
861*89c4ff92SAndroid Build Coastguard Worker
862*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863*89c4ff92SAndroid Build Coastguard Worker "Reference ChannelShuffle: input is not a supported type.");
864*89c4ff92SAndroid Build Coastguard Worker
865*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866*89c4ff92SAndroid Build Coastguard Worker "Reference ChannelShuffle: output is not a supported type.");
867*89c4ff92SAndroid Build Coastguard Worker
868*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
869*89c4ff92SAndroid Build Coastguard Worker "Reference ChannelShuffle: input and output types are mismatched.");
870*89c4ff92SAndroid Build Coastguard Worker
871*89c4ff92SAndroid Build Coastguard Worker return supported;
872*89c4ff92SAndroid Build Coastguard Worker }
873*89c4ff92SAndroid Build Coastguard Worker
874*89c4ff92SAndroid Build Coastguard Worker
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const875*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
876*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
877*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
878*89c4ff92SAndroid Build Coastguard Worker const ComparisonDescriptor& descriptor,
879*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
880*89c4ff92SAndroid Build Coastguard Worker {
881*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
882*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 8> supportedInputTypes =
883*89c4ff92SAndroid Build Coastguard Worker {
884*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean,
885*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
886*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
887*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
888*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
889*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
890*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
891*89c4ff92SAndroid Build Coastguard Worker };
892*89c4ff92SAndroid Build Coastguard Worker
893*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
894*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
895*89c4ff92SAndroid Build Coastguard Worker "Reference comparison: input 0 is not a supported type");
896*89c4ff92SAndroid Build Coastguard Worker
897*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
898*89c4ff92SAndroid Build Coastguard Worker "Reference comparison: input 0 and Input 1 types are mismatched");
899*89c4ff92SAndroid Build Coastguard Worker
900*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
901*89c4ff92SAndroid Build Coastguard Worker "Reference comparison: output is not of type Boolean");
902*89c4ff92SAndroid Build Coastguard Worker
903*89c4ff92SAndroid Build Coastguard Worker return supported;
904*89c4ff92SAndroid Build Coastguard Worker }
905*89c4ff92SAndroid Build Coastguard Worker
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const OriginsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const906*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
907*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
908*89c4ff92SAndroid Build Coastguard Worker const OriginsDescriptor& descriptor,
909*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
910*89c4ff92SAndroid Build Coastguard Worker {
911*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
912*89c4ff92SAndroid Build Coastguard Worker
913*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
914*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
915*89c4ff92SAndroid Build Coastguard Worker {
916*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
917*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
918*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
919*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
920*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
921*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
922*89c4ff92SAndroid Build Coastguard Worker };
923*89c4ff92SAndroid Build Coastguard Worker
924*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
925*89c4ff92SAndroid Build Coastguard Worker "Reference concatenation: output type not supported");
926*89c4ff92SAndroid Build Coastguard Worker for (const TensorInfo* input : inputs)
927*89c4ff92SAndroid Build Coastguard Worker {
928*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(input != nullptr);
929*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
930*89c4ff92SAndroid Build Coastguard Worker "Reference concatenation: input type not supported");
931*89c4ff92SAndroid Build Coastguard Worker
932*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
933*89c4ff92SAndroid Build Coastguard Worker "Reference concatenation: input and output types mismatched.");
934*89c4ff92SAndroid Build Coastguard Worker }
935*89c4ff92SAndroid Build Coastguard Worker
936*89c4ff92SAndroid Build Coastguard Worker return supported;
937*89c4ff92SAndroid Build Coastguard Worker }
938*89c4ff92SAndroid Build Coastguard Worker
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const939*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
940*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
941*89c4ff92SAndroid Build Coastguard Worker {
942*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,8> supportedTypes =
943*89c4ff92SAndroid Build Coastguard Worker {
944*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
945*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
946*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
947*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
948*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
949*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
950*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
951*89c4ff92SAndroid Build Coastguard Worker };
952*89c4ff92SAndroid Build Coastguard Worker
953*89c4ff92SAndroid Build Coastguard Worker return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
954*89c4ff92SAndroid Build Coastguard Worker "Reference constant: output is not a supported type.");
955*89c4ff92SAndroid Build Coastguard Worker }
956*89c4ff92SAndroid Build Coastguard Worker
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const957*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
958*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
959*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
960*89c4ff92SAndroid Build Coastguard Worker {
961*89c4ff92SAndroid Build Coastguard Worker return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
962*89c4ff92SAndroid Build Coastguard Worker input.GetDataType(),
963*89c4ff92SAndroid Build Coastguard Worker &TrueFunc<>,
964*89c4ff92SAndroid Build Coastguard Worker &FalseInputFuncF32<>,
965*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>,
966*89c4ff92SAndroid Build Coastguard Worker &FalseFuncI32<>,
967*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>) &&
968*89c4ff92SAndroid Build Coastguard Worker IsSupportedForDataTypeGeneric(reasonIfUnsupported,
969*89c4ff92SAndroid Build Coastguard Worker output.GetDataType(),
970*89c4ff92SAndroid Build Coastguard Worker &FalseOutputFuncF16<>,
971*89c4ff92SAndroid Build Coastguard Worker &TrueFunc<>,
972*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>,
973*89c4ff92SAndroid Build Coastguard Worker &FalseFuncI32<>,
974*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>));
975*89c4ff92SAndroid Build Coastguard Worker }
976*89c4ff92SAndroid Build Coastguard Worker
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const977*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
978*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
979*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
980*89c4ff92SAndroid Build Coastguard Worker {
981*89c4ff92SAndroid Build Coastguard Worker return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
982*89c4ff92SAndroid Build Coastguard Worker input.GetDataType(),
983*89c4ff92SAndroid Build Coastguard Worker &FalseInputFuncF16<>,
984*89c4ff92SAndroid Build Coastguard Worker &TrueFunc<>,
985*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>,
986*89c4ff92SAndroid Build Coastguard Worker &FalseFuncI32<>,
987*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>) &&
988*89c4ff92SAndroid Build Coastguard Worker IsSupportedForDataTypeGeneric(reasonIfUnsupported,
989*89c4ff92SAndroid Build Coastguard Worker output.GetDataType(),
990*89c4ff92SAndroid Build Coastguard Worker &TrueFunc<>,
991*89c4ff92SAndroid Build Coastguard Worker &FalseOutputFuncF32<>,
992*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>,
993*89c4ff92SAndroid Build Coastguard Worker &FalseFuncI32<>,
994*89c4ff92SAndroid Build Coastguard Worker &FalseFuncU8<>));
995*89c4ff92SAndroid Build Coastguard Worker }
996*89c4ff92SAndroid Build Coastguard Worker
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const997*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
998*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
999*89c4ff92SAndroid Build Coastguard Worker const Convolution2dDescriptor& descriptor,
1000*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1001*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
1002*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1003*89c4ff92SAndroid Build Coastguard Worker {
1004*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1005*89c4ff92SAndroid Build Coastguard Worker
1006*89c4ff92SAndroid Build Coastguard Worker // Define supported types.
1007*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
1008*89c4ff92SAndroid Build Coastguard Worker {
1009*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1010*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1011*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1012*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1013*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1014*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1015*89c4ff92SAndroid Build Coastguard Worker };
1016*89c4ff92SAndroid Build Coastguard Worker
1017*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1018*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: input is not a supported type.");
1019*89c4ff92SAndroid Build Coastguard Worker
1020*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1021*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: output is not a supported type.");
1022*89c4ff92SAndroid Build Coastguard Worker
1023*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1024*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: input and output types mismatched.");
1025*89c4ff92SAndroid Build Coastguard Worker
1026*89c4ff92SAndroid Build Coastguard Worker
1027*89c4ff92SAndroid Build Coastguard Worker const DataType inputType = input.GetDataType();
1028*89c4ff92SAndroid Build Coastguard Worker if (IsQuantized8BitType(inputType))
1029*89c4ff92SAndroid Build Coastguard Worker {
1030*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedWeightTypes =
1031*89c4ff92SAndroid Build Coastguard Worker {
1032*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1033*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1034*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
1035*89c4ff92SAndroid Build Coastguard Worker };
1036*89c4ff92SAndroid Build Coastguard Worker
1037*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1038*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: weights type not supported for quantized input.");
1039*89c4ff92SAndroid Build Coastguard Worker }
1040*89c4ff92SAndroid Build Coastguard Worker else
1041*89c4ff92SAndroid Build Coastguard Worker {
1042*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1043*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: weights is not a supported type.");
1044*89c4ff92SAndroid Build Coastguard Worker
1045*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1046*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: input and weights types mismatched.");
1047*89c4ff92SAndroid Build Coastguard Worker }
1048*89c4ff92SAndroid Build Coastguard Worker
1049*89c4ff92SAndroid Build Coastguard Worker if (biases.has_value())
1050*89c4ff92SAndroid Build Coastguard Worker {
1051*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,4> biasesSupportedTypes =
1052*89c4ff92SAndroid Build Coastguard Worker {
1053*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1054*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1055*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1056*89c4ff92SAndroid Build Coastguard Worker };
1057*89c4ff92SAndroid Build Coastguard Worker
1058*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1059*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution2d: biases is not a supported type.");
1060*89c4ff92SAndroid Build Coastguard Worker }
1061*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1062*89c4ff92SAndroid Build Coastguard Worker
1063*89c4ff92SAndroid Build Coastguard Worker return supported;
1064*89c4ff92SAndroid Build Coastguard Worker }
1065*89c4ff92SAndroid Build Coastguard Worker
IsConvolution3dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution3dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1066*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1067*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1068*89c4ff92SAndroid Build Coastguard Worker const Convolution3dDescriptor& descriptor,
1069*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1070*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
1071*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1072*89c4ff92SAndroid Build Coastguard Worker {
1073*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1074*89c4ff92SAndroid Build Coastguard Worker
1075*89c4ff92SAndroid Build Coastguard Worker // Define supported types.
1076*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
1077*89c4ff92SAndroid Build Coastguard Worker {
1078*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1079*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1080*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1081*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1082*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1083*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1084*89c4ff92SAndroid Build Coastguard Worker };
1085*89c4ff92SAndroid Build Coastguard Worker
1086*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1087*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: input is not a supported type.");
1088*89c4ff92SAndroid Build Coastguard Worker
1089*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1090*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: output is not a supported type.");
1091*89c4ff92SAndroid Build Coastguard Worker
1092*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1093*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: input and output types mismatched.");
1094*89c4ff92SAndroid Build Coastguard Worker
1095*89c4ff92SAndroid Build Coastguard Worker const DataType inputType = input.GetDataType();
1096*89c4ff92SAndroid Build Coastguard Worker if (IsQuantized8BitType(inputType))
1097*89c4ff92SAndroid Build Coastguard Worker {
1098*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedWeightTypes =
1099*89c4ff92SAndroid Build Coastguard Worker {
1100*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1101*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1102*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
1103*89c4ff92SAndroid Build Coastguard Worker };
1104*89c4ff92SAndroid Build Coastguard Worker
1105*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1106*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: weights type not supported for quantized input.");
1107*89c4ff92SAndroid Build Coastguard Worker }
1108*89c4ff92SAndroid Build Coastguard Worker else
1109*89c4ff92SAndroid Build Coastguard Worker {
1110*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1111*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: weights is not a supported type.");
1112*89c4ff92SAndroid Build Coastguard Worker
1113*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1114*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: input and weights types mismatched.");
1115*89c4ff92SAndroid Build Coastguard Worker }
1116*89c4ff92SAndroid Build Coastguard Worker
1117*89c4ff92SAndroid Build Coastguard Worker if (biases.has_value())
1118*89c4ff92SAndroid Build Coastguard Worker {
1119*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,4> biasesSupportedTypes =
1120*89c4ff92SAndroid Build Coastguard Worker {
1121*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1122*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1123*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1124*89c4ff92SAndroid Build Coastguard Worker };
1125*89c4ff92SAndroid Build Coastguard Worker
1126*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1127*89c4ff92SAndroid Build Coastguard Worker "Reference Convolution3d: biases is not a supported type.");
1128*89c4ff92SAndroid Build Coastguard Worker }
1129*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1130*89c4ff92SAndroid Build Coastguard Worker
1131*89c4ff92SAndroid Build Coastguard Worker return supported;
1132*89c4ff92SAndroid Build Coastguard Worker }
1133*89c4ff92SAndroid Build Coastguard Worker
IsDebugSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1134*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1135*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1136*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1137*89c4ff92SAndroid Build Coastguard Worker {
1138*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1139*89c4ff92SAndroid Build Coastguard Worker
1140*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 8> supportedTypes =
1141*89c4ff92SAndroid Build Coastguard Worker {
1142*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1143*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1144*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1145*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1146*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1147*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1148*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1149*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1150*89c4ff92SAndroid Build Coastguard Worker };
1151*89c4ff92SAndroid Build Coastguard Worker
1152*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1153*89c4ff92SAndroid Build Coastguard Worker "Reference for Debug layer: input type not supported");
1154*89c4ff92SAndroid Build Coastguard Worker
1155*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1156*89c4ff92SAndroid Build Coastguard Worker "Reference for Debug layer: output type not supported");
1157*89c4ff92SAndroid Build Coastguard Worker
1158*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1159*89c4ff92SAndroid Build Coastguard Worker "Reference for Debug layer: input and output types are mismatched");
1160*89c4ff92SAndroid Build Coastguard Worker
1161*89c4ff92SAndroid Build Coastguard Worker return supported;
1162*89c4ff92SAndroid Build Coastguard Worker }
1163*89c4ff92SAndroid Build Coastguard Worker
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1164*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1165*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1166*89c4ff92SAndroid Build Coastguard Worker const DepthToSpaceDescriptor& descriptor,
1167*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1168*89c4ff92SAndroid Build Coastguard Worker {
1169*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1170*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1171*89c4ff92SAndroid Build Coastguard Worker
1172*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
1173*89c4ff92SAndroid Build Coastguard Worker {
1174*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1175*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1176*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1177*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1178*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1179*89c4ff92SAndroid Build Coastguard Worker };
1180*89c4ff92SAndroid Build Coastguard Worker
1181*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1182*89c4ff92SAndroid Build Coastguard Worker "Reference DepthToSpace: input type not supported");
1183*89c4ff92SAndroid Build Coastguard Worker
1184*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1185*89c4ff92SAndroid Build Coastguard Worker "Reference DepthToSpace: output type not supported");
1186*89c4ff92SAndroid Build Coastguard Worker
1187*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1188*89c4ff92SAndroid Build Coastguard Worker "Reference DepthToSpace: input and output types are mismatched");
1189*89c4ff92SAndroid Build Coastguard Worker
1190*89c4ff92SAndroid Build Coastguard Worker return supported;
1191*89c4ff92SAndroid Build Coastguard Worker }
1192*89c4ff92SAndroid Build Coastguard Worker
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1193*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1194*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1195*89c4ff92SAndroid Build Coastguard Worker const DepthwiseConvolution2dDescriptor& descriptor,
1196*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1197*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
1198*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1199*89c4ff92SAndroid Build Coastguard Worker {
1200*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1201*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1202*89c4ff92SAndroid Build Coastguard Worker
1203*89c4ff92SAndroid Build Coastguard Worker // Define supported types.
1204*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
1205*89c4ff92SAndroid Build Coastguard Worker {
1206*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1207*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1208*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1209*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1210*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1211*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1212*89c4ff92SAndroid Build Coastguard Worker };
1213*89c4ff92SAndroid Build Coastguard Worker
1214*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1215*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: input is not a supported type.");
1216*89c4ff92SAndroid Build Coastguard Worker
1217*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1218*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: output is not a supported type.");
1219*89c4ff92SAndroid Build Coastguard Worker
1220*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1221*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: input and output types mismatched.");
1222*89c4ff92SAndroid Build Coastguard Worker
1223*89c4ff92SAndroid Build Coastguard Worker const DataType inputType = input.GetDataType();
1224*89c4ff92SAndroid Build Coastguard Worker if (IsQuantized8BitType(inputType))
1225*89c4ff92SAndroid Build Coastguard Worker {
1226*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedWeightTypes =
1227*89c4ff92SAndroid Build Coastguard Worker {
1228*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1229*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1230*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1231*89c4ff92SAndroid Build Coastguard Worker };
1232*89c4ff92SAndroid Build Coastguard Worker
1233*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1234*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: weights type not supported for "
1235*89c4ff92SAndroid Build Coastguard Worker "quantized input.");
1236*89c4ff92SAndroid Build Coastguard Worker }
1237*89c4ff92SAndroid Build Coastguard Worker else
1238*89c4ff92SAndroid Build Coastguard Worker {
1239*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1240*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: weights is not a supported type.");
1241*89c4ff92SAndroid Build Coastguard Worker
1242*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1243*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1244*89c4ff92SAndroid Build Coastguard Worker }
1245*89c4ff92SAndroid Build Coastguard Worker
1246*89c4ff92SAndroid Build Coastguard Worker if (biases.has_value())
1247*89c4ff92SAndroid Build Coastguard Worker {
1248*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,4> biasesSupportedTypes =
1249*89c4ff92SAndroid Build Coastguard Worker {
1250*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1251*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1252*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1253*89c4ff92SAndroid Build Coastguard Worker };
1254*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1255*89c4ff92SAndroid Build Coastguard Worker "Reference DepthwiseConvolution2d: biases is not a supported type.");
1256*89c4ff92SAndroid Build Coastguard Worker }
1257*89c4ff92SAndroid Build Coastguard Worker
1258*89c4ff92SAndroid Build Coastguard Worker return supported;
1259*89c4ff92SAndroid Build Coastguard Worker
1260*89c4ff92SAndroid Build Coastguard Worker }
1261*89c4ff92SAndroid Build Coastguard Worker
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1262*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1263*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1264*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1265*89c4ff92SAndroid Build Coastguard Worker {
1266*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1267*89c4ff92SAndroid Build Coastguard Worker
1268*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,5> supportedInputTypes = {
1269*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1270*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1271*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1272*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1273*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
1274*89c4ff92SAndroid Build Coastguard Worker };
1275*89c4ff92SAndroid Build Coastguard Worker
1276*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1277*89c4ff92SAndroid Build Coastguard Worker "Reference for Dequantize layer: input type not supported.");
1278*89c4ff92SAndroid Build Coastguard Worker
1279*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
1280*89c4ff92SAndroid Build Coastguard Worker "Reference for Dequantize layer: per-axis quantized input not supported.");
1281*89c4ff92SAndroid Build Coastguard Worker
1282*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,3> supportedOutputTypes = {
1283*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1284*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
1285*89c4ff92SAndroid Build Coastguard Worker };
1286*89c4ff92SAndroid Build Coastguard Worker
1287*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1288*89c4ff92SAndroid Build Coastguard Worker "Reference for Dequantize layer: output type not supported.");
1289*89c4ff92SAndroid Build Coastguard Worker
1290*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1291*89c4ff92SAndroid Build Coastguard Worker "Reference for Dequantize layer: input/output shapes have different num total "
1292*89c4ff92SAndroid Build Coastguard Worker "elements.");
1293*89c4ff92SAndroid Build Coastguard Worker
1294*89c4ff92SAndroid Build Coastguard Worker return supported;
1295*89c4ff92SAndroid Build Coastguard Worker }
1296*89c4ff92SAndroid Build Coastguard Worker
IsDetectionPostProcessSupported(const TensorInfo & boxEncodings,const TensorInfo & scores,const TensorInfo & anchors,const TensorInfo & detectionBoxes,const TensorInfo & detectionClasses,const TensorInfo & detectionScores,const TensorInfo & numDetections,const DetectionPostProcessDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1297*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1298*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& scores,
1299*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& anchors,
1300*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& detectionBoxes,
1301*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& detectionClasses,
1302*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& detectionScores,
1303*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& numDetections,
1304*89c4ff92SAndroid Build Coastguard Worker const DetectionPostProcessDescriptor& descriptor,
1305*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1306*89c4ff92SAndroid Build Coastguard Worker {
1307*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
1308*89c4ff92SAndroid Build Coastguard Worker
1309*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1310*89c4ff92SAndroid Build Coastguard Worker
1311*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedInputTypes =
1312*89c4ff92SAndroid Build Coastguard Worker {
1313*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1314*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1315*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1316*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1317*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1318*89c4ff92SAndroid Build Coastguard Worker };
1319*89c4ff92SAndroid Build Coastguard Worker
1320*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
1321*89c4ff92SAndroid Build Coastguard Worker "Reference DetectionPostProcess: input 0 is not a supported type.");
1322*89c4ff92SAndroid Build Coastguard Worker
1323*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
1324*89c4ff92SAndroid Build Coastguard Worker "Reference DetectionPostProcess: input 1 is not a supported type.");
1325*89c4ff92SAndroid Build Coastguard Worker
1326*89c4ff92SAndroid Build Coastguard Worker return supported;
1327*89c4ff92SAndroid Build Coastguard Worker }
1328*89c4ff92SAndroid Build Coastguard Worker
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1329*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1330*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1331*89c4ff92SAndroid Build Coastguard Worker const DepthwiseConvolution2dDescriptor& descriptor,
1332*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1333*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
1334*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1335*89c4ff92SAndroid Build Coastguard Worker {
1336*89c4ff92SAndroid Build Coastguard Worker return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
1337*89c4ff92SAndroid Build Coastguard Worker }
1338*89c4ff92SAndroid Build Coastguard Worker
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1339*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
1340*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1341*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1342*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1343*89c4ff92SAndroid Build Coastguard Worker {
1344*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1345*89c4ff92SAndroid Build Coastguard Worker
1346*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes = {
1347*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1348*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1349*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1350*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1351*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1352*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1353*89c4ff92SAndroid Build Coastguard Worker };
1354*89c4ff92SAndroid Build Coastguard Worker
1355*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1356*89c4ff92SAndroid Build Coastguard Worker "Reference division: input 0 is not a supported type.");
1357*89c4ff92SAndroid Build Coastguard Worker
1358*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1359*89c4ff92SAndroid Build Coastguard Worker "Reference division: input 1 is not a supported type.");
1360*89c4ff92SAndroid Build Coastguard Worker
1361*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1362*89c4ff92SAndroid Build Coastguard Worker "Reference division: output is not a supported type.");
1363*89c4ff92SAndroid Build Coastguard Worker
1364*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1365*89c4ff92SAndroid Build Coastguard Worker "Reference division: input 0 and Input 1 types are mismatched");
1366*89c4ff92SAndroid Build Coastguard Worker
1367*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1368*89c4ff92SAndroid Build Coastguard Worker "Reference division: input and output types are mismatched");
1369*89c4ff92SAndroid Build Coastguard Worker
1370*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1371*89c4ff92SAndroid Build Coastguard Worker "Reference division: shapes are not suitable for implicit broadcast.");
1372*89c4ff92SAndroid Build Coastguard Worker
1373*89c4ff92SAndroid Build Coastguard Worker return supported;
1374*89c4ff92SAndroid Build Coastguard Worker }
1375*89c4ff92SAndroid Build Coastguard Worker
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1376*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1377*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1378*89c4ff92SAndroid Build Coastguard Worker const ElementwiseUnaryDescriptor& descriptor,
1379*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1380*89c4ff92SAndroid Build Coastguard Worker {
1381*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1382*89c4ff92SAndroid Build Coastguard Worker
1383*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 7> supportedTypes =
1384*89c4ff92SAndroid Build Coastguard Worker {
1385*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1386*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1387*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1388*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1389*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1390*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1391*89c4ff92SAndroid Build Coastguard Worker };
1392*89c4ff92SAndroid Build Coastguard Worker
1393*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 1> logicalSupportedTypes =
1394*89c4ff92SAndroid Build Coastguard Worker {
1395*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean
1396*89c4ff92SAndroid Build Coastguard Worker };
1397*89c4ff92SAndroid Build Coastguard Worker
1398*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1399*89c4ff92SAndroid Build Coastguard Worker
1400*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1401*89c4ff92SAndroid Build Coastguard Worker {
1402*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1403*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input type not supported");
1404*89c4ff92SAndroid Build Coastguard Worker
1405*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1406*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: output type not supported");
1407*89c4ff92SAndroid Build Coastguard Worker }
1408*89c4ff92SAndroid Build Coastguard Worker else
1409*89c4ff92SAndroid Build Coastguard Worker {
1410*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1411*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input type not supported");
1412*89c4ff92SAndroid Build Coastguard Worker
1413*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1414*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: output type not supported");
1415*89c4ff92SAndroid Build Coastguard Worker }
1416*89c4ff92SAndroid Build Coastguard Worker
1417*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1418*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input and output types not matching");
1419*89c4ff92SAndroid Build Coastguard Worker
1420*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1421*89c4ff92SAndroid Build Coastguard Worker "Reference elementwise unary: input and output shapes"
1422*89c4ff92SAndroid Build Coastguard Worker "have different number of total elements");
1423*89c4ff92SAndroid Build Coastguard Worker
1424*89c4ff92SAndroid Build Coastguard Worker return supported;
1425*89c4ff92SAndroid Build Coastguard Worker }
1426*89c4ff92SAndroid Build Coastguard Worker
IsFakeQuantizationSupported(const TensorInfo & input,const FakeQuantizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1427*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1428*89c4ff92SAndroid Build Coastguard Worker const FakeQuantizationDescriptor& descriptor,
1429*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1430*89c4ff92SAndroid Build Coastguard Worker {
1431*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1432*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1433*89c4ff92SAndroid Build Coastguard Worker
1434*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,1> supportedTypes =
1435*89c4ff92SAndroid Build Coastguard Worker {
1436*89c4ff92SAndroid Build Coastguard Worker DataType::Float32
1437*89c4ff92SAndroid Build Coastguard Worker };
1438*89c4ff92SAndroid Build Coastguard Worker
1439*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440*89c4ff92SAndroid Build Coastguard Worker "Reference fake quantization: input type not supported.");
1441*89c4ff92SAndroid Build Coastguard Worker
1442*89c4ff92SAndroid Build Coastguard Worker return supported;
1443*89c4ff92SAndroid Build Coastguard Worker }
1444*89c4ff92SAndroid Build Coastguard Worker
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1445*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1446*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1447*89c4ff92SAndroid Build Coastguard Worker const FillDescriptor& descriptor,
1448*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1449*89c4ff92SAndroid Build Coastguard Worker {
1450*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1451*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(output);
1452*89c4ff92SAndroid Build Coastguard Worker
1453*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1454*89c4ff92SAndroid Build Coastguard Worker
1455*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,3> supportedTypes =
1456*89c4ff92SAndroid Build Coastguard Worker {
1457*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1458*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1459*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1460*89c4ff92SAndroid Build Coastguard Worker };
1461*89c4ff92SAndroid Build Coastguard Worker
1462*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
1463*89c4ff92SAndroid Build Coastguard Worker "Reference Fill: input type not supported.");
1464*89c4ff92SAndroid Build Coastguard Worker
1465*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1466*89c4ff92SAndroid Build Coastguard Worker "Reference Fill: output type not supported.");
1467*89c4ff92SAndroid Build Coastguard Worker return supported;
1468*89c4ff92SAndroid Build Coastguard Worker }
1469*89c4ff92SAndroid Build Coastguard Worker
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1470*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1471*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1472*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1473*89c4ff92SAndroid Build Coastguard Worker {
1474*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(output);
1475*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1476*89c4ff92SAndroid Build Coastguard Worker
1477*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,3> supportedTypes =
1478*89c4ff92SAndroid Build Coastguard Worker {
1479*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1480*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
1481*89c4ff92SAndroid Build Coastguard Worker };
1482*89c4ff92SAndroid Build Coastguard Worker
1483*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484*89c4ff92SAndroid Build Coastguard Worker "Reference Floor: input type not supported.");
1485*89c4ff92SAndroid Build Coastguard Worker
1486*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487*89c4ff92SAndroid Build Coastguard Worker "Reference Floor: output type not supported.");
1488*89c4ff92SAndroid Build Coastguard Worker
1489*89c4ff92SAndroid Build Coastguard Worker return supported;
1490*89c4ff92SAndroid Build Coastguard Worker }
1491*89c4ff92SAndroid Build Coastguard Worker
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1492*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1493*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1494*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1495*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biases,
1496*89c4ff92SAndroid Build Coastguard Worker const FullyConnectedDescriptor& descriptor,
1497*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1498*89c4ff92SAndroid Build Coastguard Worker {
1499*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1500*89c4ff92SAndroid Build Coastguard Worker
1501*89c4ff92SAndroid Build Coastguard Worker // Define supported types.
1502*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
1503*89c4ff92SAndroid Build Coastguard Worker {
1504*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1505*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1506*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1507*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1508*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1509*89c4ff92SAndroid Build Coastguard Worker };
1510*89c4ff92SAndroid Build Coastguard Worker
1511*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: input type not supported.");
1513*89c4ff92SAndroid Build Coastguard Worker
1514*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: output type not supported.");
1516*89c4ff92SAndroid Build Coastguard Worker
1517*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1518*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: weights type not supported.");
1519*89c4ff92SAndroid Build Coastguard Worker
1520*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1521*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: input and output types mismatched.");
1522*89c4ff92SAndroid Build Coastguard Worker
1523*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1524*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: weights is not a supported type.");
1525*89c4ff92SAndroid Build Coastguard Worker
1526*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1527*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: input and weights types mismatched.");
1528*89c4ff92SAndroid Build Coastguard Worker
1529*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_BiasEnabled)
1530*89c4ff92SAndroid Build Coastguard Worker {
1531*89c4ff92SAndroid Build Coastguard Worker // Defined supported types for bias
1532*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 5>
1533*89c4ff92SAndroid Build Coastguard Worker supportedBiasTypes =
1534*89c4ff92SAndroid Build Coastguard Worker {
1535*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1536*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1537*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
1538*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8
1539*89c4ff92SAndroid Build Coastguard Worker };
1540*89c4ff92SAndroid Build Coastguard Worker
1541*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1542*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: bias type not supported.");
1543*89c4ff92SAndroid Build Coastguard Worker
1544*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1545*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: bias and weight types mismatch.");
1546*89c4ff92SAndroid Build Coastguard Worker
1547*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1548*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: bias type inferred from weights is incompatible.");
1549*89c4ff92SAndroid Build Coastguard Worker
1550*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1551*89c4ff92SAndroid Build Coastguard Worker "Reference Fully Connected: bias must have 1 dimension.");
1552*89c4ff92SAndroid Build Coastguard Worker
1553*89c4ff92SAndroid Build Coastguard Worker }
1554*89c4ff92SAndroid Build Coastguard Worker
1555*89c4ff92SAndroid Build Coastguard Worker return supported;
1556*89c4ff92SAndroid Build Coastguard Worker }
1557*89c4ff92SAndroid Build Coastguard Worker
IsGatherNdSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const1558*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1559*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& input1,
1560*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& output,
1561*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> reasonIfUnsupported) const
1562*89c4ff92SAndroid Build Coastguard Worker {
1563*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1564*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
1565*89c4ff92SAndroid Build Coastguard Worker {
1566*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1567*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1568*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1569*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1570*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1571*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1572*89c4ff92SAndroid Build Coastguard Worker };
1573*89c4ff92SAndroid Build Coastguard Worker
1574*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1575*89c4ff92SAndroid Build Coastguard Worker "Reference GatherNd: input type not supported");
1576*89c4ff92SAndroid Build Coastguard Worker
1577*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578*89c4ff92SAndroid Build Coastguard Worker "Reference GatherNd: output type not supported");
1579*89c4ff92SAndroid Build Coastguard Worker
1580*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1581*89c4ff92SAndroid Build Coastguard Worker "Reference GatherNd: indices (input1) type not supported");
1582*89c4ff92SAndroid Build Coastguard Worker
1583*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584*89c4ff92SAndroid Build Coastguard Worker "Reference GatherNd: input and output types not matching");
1585*89c4ff92SAndroid Build Coastguard Worker
1586*89c4ff92SAndroid Build Coastguard Worker return supported;
1587*89c4ff92SAndroid Build Coastguard Worker }
1588*89c4ff92SAndroid Build Coastguard Worker
IsGatherSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,const GatherDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const1589*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1590*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& input1,
1591*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& output,
1592*89c4ff92SAndroid Build Coastguard Worker const GatherDescriptor& descriptor,
1593*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> reasonIfUnsupported) const
1594*89c4ff92SAndroid Build Coastguard Worker {
1595*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1596*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
1597*89c4ff92SAndroid Build Coastguard Worker {
1598*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1599*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1600*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1601*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1602*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1603*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1604*89c4ff92SAndroid Build Coastguard Worker };
1605*89c4ff92SAndroid Build Coastguard Worker
1606*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1607*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1608*89c4ff92SAndroid Build Coastguard Worker "Reference Gather: input type not supported");
1609*89c4ff92SAndroid Build Coastguard Worker
1610*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1611*89c4ff92SAndroid Build Coastguard Worker "Reference Gather: output type not supported");
1612*89c4ff92SAndroid Build Coastguard Worker
1613*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1614*89c4ff92SAndroid Build Coastguard Worker "Reference Gather: indices (input1) type not supported");
1615*89c4ff92SAndroid Build Coastguard Worker
1616*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1617*89c4ff92SAndroid Build Coastguard Worker "Reference Gather: input and output types not matching");
1618*89c4ff92SAndroid Build Coastguard Worker
1619*89c4ff92SAndroid Build Coastguard Worker return supported;
1620*89c4ff92SAndroid Build Coastguard Worker }
1621*89c4ff92SAndroid Build Coastguard Worker
IsInputSupported(const TensorInfo &,Optional<std::string &>) const1622*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1623*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported*/) const
1624*89c4ff92SAndroid Build Coastguard Worker {
1625*89c4ff92SAndroid Build Coastguard Worker return true;
1626*89c4ff92SAndroid Build Coastguard Worker }
1627*89c4ff92SAndroid Build Coastguard Worker
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1628*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1629*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1630*89c4ff92SAndroid Build Coastguard Worker const InstanceNormalizationDescriptor& descriptor,
1631*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1632*89c4ff92SAndroid Build Coastguard Worker {
1633*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1634*89c4ff92SAndroid Build Coastguard Worker // Define supported types
1635*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedTypes =
1636*89c4ff92SAndroid Build Coastguard Worker {
1637*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1638*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
1639*89c4ff92SAndroid Build Coastguard Worker };
1640*89c4ff92SAndroid Build Coastguard Worker
1641*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1642*89c4ff92SAndroid Build Coastguard Worker
1643*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1644*89c4ff92SAndroid Build Coastguard Worker "Reference Instance Normalization: input type not supported.");
1645*89c4ff92SAndroid Build Coastguard Worker
1646*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1647*89c4ff92SAndroid Build Coastguard Worker "Reference Instance Normalization: output type not supported.");
1648*89c4ff92SAndroid Build Coastguard Worker
1649*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1650*89c4ff92SAndroid Build Coastguard Worker "Reference Instance Normalization: input and output types mismatched.");
1651*89c4ff92SAndroid Build Coastguard Worker
1652*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1653*89c4ff92SAndroid Build Coastguard Worker "Reference Instance Normalization: input and output shapes have different "
1654*89c4ff92SAndroid Build Coastguard Worker "num total elements.");
1655*89c4ff92SAndroid Build Coastguard Worker
1656*89c4ff92SAndroid Build Coastguard Worker return supported;
1657*89c4ff92SAndroid Build Coastguard Worker }
1658*89c4ff92SAndroid Build Coastguard Worker
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1659*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1660*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1661*89c4ff92SAndroid Build Coastguard Worker const L2NormalizationDescriptor& descriptor,
1662*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1663*89c4ff92SAndroid Build Coastguard Worker {
1664*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1665*89c4ff92SAndroid Build Coastguard Worker // Define supported types
1666*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes =
1667*89c4ff92SAndroid Build Coastguard Worker {
1668*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1669*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1670*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1671*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1672*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1673*89c4ff92SAndroid Build Coastguard Worker };
1674*89c4ff92SAndroid Build Coastguard Worker
1675*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1676*89c4ff92SAndroid Build Coastguard Worker
1677*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1678*89c4ff92SAndroid Build Coastguard Worker "Reference L2normalization: input type not supported.");
1679*89c4ff92SAndroid Build Coastguard Worker
1680*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1681*89c4ff92SAndroid Build Coastguard Worker "Reference L2normalization: output type not supported.");
1682*89c4ff92SAndroid Build Coastguard Worker
1683*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1684*89c4ff92SAndroid Build Coastguard Worker "Reference L2normalization: input and output types mismatched.");
1685*89c4ff92SAndroid Build Coastguard Worker
1686*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1687*89c4ff92SAndroid Build Coastguard Worker "Reference L2normalization: input and output shapes have different "
1688*89c4ff92SAndroid Build Coastguard Worker "num total elements.");
1689*89c4ff92SAndroid Build Coastguard Worker
1690*89c4ff92SAndroid Build Coastguard Worker return supported;
1691*89c4ff92SAndroid Build Coastguard Worker }
1692*89c4ff92SAndroid Build Coastguard Worker
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1693*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1694*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1695*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1696*89c4ff92SAndroid Build Coastguard Worker const LogicalBinaryDescriptor& descriptor,
1697*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1698*89c4ff92SAndroid Build Coastguard Worker {
1699*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1700*89c4ff92SAndroid Build Coastguard Worker
1701*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 1> supportedTypes =
1702*89c4ff92SAndroid Build Coastguard Worker {
1703*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean
1704*89c4ff92SAndroid Build Coastguard Worker };
1705*89c4ff92SAndroid Build Coastguard Worker
1706*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1707*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1708*89c4ff92SAndroid Build Coastguard Worker "Reference LogicalBinary: input 0 type not supported");
1709*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1710*89c4ff92SAndroid Build Coastguard Worker "Reference LogicalBinary: input 1 type not supported");
1711*89c4ff92SAndroid Build Coastguard Worker
1712*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1713*89c4ff92SAndroid Build Coastguard Worker "Reference LogicalBinary: input and output types do not match");
1714*89c4ff92SAndroid Build Coastguard Worker
1715*89c4ff92SAndroid Build Coastguard Worker return supported;
1716*89c4ff92SAndroid Build Coastguard Worker }
1717*89c4ff92SAndroid Build Coastguard Worker
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1718*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1719*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1720*89c4ff92SAndroid Build Coastguard Worker const LogSoftmaxDescriptor& descriptor,
1721*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1722*89c4ff92SAndroid Build Coastguard Worker {
1723*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1724*89c4ff92SAndroid Build Coastguard Worker
1725*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedTypes =
1726*89c4ff92SAndroid Build Coastguard Worker {
1727*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1728*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
1729*89c4ff92SAndroid Build Coastguard Worker };
1730*89c4ff92SAndroid Build Coastguard Worker
1731*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1732*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1733*89c4ff92SAndroid Build Coastguard Worker "Reference LogSoftmax: input type not supported");
1734*89c4ff92SAndroid Build Coastguard Worker
1735*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1736*89c4ff92SAndroid Build Coastguard Worker "Reference LogSoftmax: output type not supported");
1737*89c4ff92SAndroid Build Coastguard Worker
1738*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1739*89c4ff92SAndroid Build Coastguard Worker "Reference LogSoftmax: input and output types do not match");
1740*89c4ff92SAndroid Build Coastguard Worker
1741*89c4ff92SAndroid Build Coastguard Worker return supported;
1742*89c4ff92SAndroid Build Coastguard Worker }
1743*89c4ff92SAndroid Build Coastguard Worker
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1744*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1745*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateIn,
1746*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateIn,
1747*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& scratchBuffer,
1748*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateOut,
1749*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
1750*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1751*89c4ff92SAndroid Build Coastguard Worker const LstmDescriptor& descriptor,
1752*89c4ff92SAndroid Build Coastguard Worker const LstmInputParamsInfo& paramsInfo,
1753*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1754*89c4ff92SAndroid Build Coastguard Worker {
1755*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1756*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(paramsInfo);
1757*89c4ff92SAndroid Build Coastguard Worker
1758*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1759*89c4ff92SAndroid Build Coastguard Worker
1760*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,3> supportedTypes = {
1761*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1762*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1763*89c4ff92SAndroid Build Coastguard Worker };
1764*89c4ff92SAndroid Build Coastguard Worker
1765*89c4ff92SAndroid Build Coastguard Worker // check inputs and outputs
1766*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1767*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input is not a supported type.");
1768*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1769*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and outputStateIn types are mismatched");
1770*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1771*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and cellStateIn types are mismatched");
1772*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1773*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and scratchBuffer types are mismatched");
1774*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1775*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and outputStateOut types are mismatched");
1776*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1777*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and cellStateOut types are mismatched");
1778*89c4ff92SAndroid Build Coastguard Worker
1779*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1780*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and output types are mismatched");
1781*89c4ff92SAndroid Build Coastguard Worker // check layer parameters
1782*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1783*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and InputToForgetWeights types are mismatched");
1784*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1785*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and InputToCellWeights types are mismatched");
1786*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1787*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and InputToOutputWeights types are mismatched");
1788*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1789*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1790*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1791*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1792*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1793*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1794*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1795*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and ForgetGateBias types are mismatched");
1796*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1797*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and CellBias types are mismatched");
1798*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1799*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and OutputGateBias types are mismatched");
1800*89c4ff92SAndroid Build Coastguard Worker if (!descriptor.m_CifgEnabled)
1801*89c4ff92SAndroid Build Coastguard Worker {
1802*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1803*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and InputToInputWeights types are mismatched");
1804*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1805*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1806*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1807*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1808*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and InputGateBias types are mismatched");
1809*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_PeepholeEnabled)
1810*89c4ff92SAndroid Build Coastguard Worker {
1811*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1812*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1813*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and CellToInputWeights types are mismatched");
1814*89c4ff92SAndroid Build Coastguard Worker }
1815*89c4ff92SAndroid Build Coastguard Worker }
1816*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_PeepholeEnabled)
1817*89c4ff92SAndroid Build Coastguard Worker {
1818*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1819*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and CellToForgetWeights types are mismatched");
1820*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1821*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and CellToOutputWeights types are mismatched");
1822*89c4ff92SAndroid Build Coastguard Worker }
1823*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_ProjectionEnabled)
1824*89c4ff92SAndroid Build Coastguard Worker {
1825*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1826*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and mProjectionWeights types are mismatched");
1827*89c4ff92SAndroid Build Coastguard Worker if (paramsInfo.m_ProjectionBias != nullptr)
1828*89c4ff92SAndroid Build Coastguard Worker {
1829*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1830*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and ProjectionBias types are mismatched");
1831*89c4ff92SAndroid Build Coastguard Worker }
1832*89c4ff92SAndroid Build Coastguard Worker }
1833*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_LayerNormEnabled)
1834*89c4ff92SAndroid Build Coastguard Worker {
1835*89c4ff92SAndroid Build Coastguard Worker if (!descriptor.m_CifgEnabled)
1836*89c4ff92SAndroid Build Coastguard Worker {
1837*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1838*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1839*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1840*89c4ff92SAndroid Build Coastguard Worker }
1841*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1842*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1843*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1844*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1845*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1846*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1847*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1848*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1849*89c4ff92SAndroid Build Coastguard Worker "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1850*89c4ff92SAndroid Build Coastguard Worker }
1851*89c4ff92SAndroid Build Coastguard Worker
1852*89c4ff92SAndroid Build Coastguard Worker return supported;
1853*89c4ff92SAndroid Build Coastguard Worker }
1854*89c4ff92SAndroid Build Coastguard Worker
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1855*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1856*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1857*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1858*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1859*89c4ff92SAndroid Build Coastguard Worker {
1860*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1861*89c4ff92SAndroid Build Coastguard Worker
1862*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes = {
1863*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1864*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1865*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1866*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1867*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1868*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1869*89c4ff92SAndroid Build Coastguard Worker };
1870*89c4ff92SAndroid Build Coastguard Worker
1871*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1872*89c4ff92SAndroid Build Coastguard Worker "Reference maximum: input 0 is not a supported type.");
1873*89c4ff92SAndroid Build Coastguard Worker
1874*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1875*89c4ff92SAndroid Build Coastguard Worker "Reference maximum: input 1 is not a supported type.");
1876*89c4ff92SAndroid Build Coastguard Worker
1877*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1878*89c4ff92SAndroid Build Coastguard Worker "Reference maximum: output is not a supported type.");
1879*89c4ff92SAndroid Build Coastguard Worker
1880*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1881*89c4ff92SAndroid Build Coastguard Worker "Reference maximum: input 0 and Input 1 types are mismatched");
1882*89c4ff92SAndroid Build Coastguard Worker
1883*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1884*89c4ff92SAndroid Build Coastguard Worker "Reference maximum: input and output types are mismatched");
1885*89c4ff92SAndroid Build Coastguard Worker
1886*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1887*89c4ff92SAndroid Build Coastguard Worker "Reference maximum: shapes are not suitable for implicit broadcast.");
1888*89c4ff92SAndroid Build Coastguard Worker
1889*89c4ff92SAndroid Build Coastguard Worker return supported;
1890*89c4ff92SAndroid Build Coastguard Worker }
1891*89c4ff92SAndroid Build Coastguard Worker
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1892*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1893*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1894*89c4ff92SAndroid Build Coastguard Worker const MeanDescriptor& descriptor,
1895*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1896*89c4ff92SAndroid Build Coastguard Worker {
1897*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1898*89c4ff92SAndroid Build Coastguard Worker std::string meanLayerStr = "Mean";
1899*89c4ff92SAndroid Build Coastguard Worker std::string outputTensorStr = "output";
1900*89c4ff92SAndroid Build Coastguard Worker
1901*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
1902*89c4ff92SAndroid Build Coastguard Worker {
1903*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1904*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1905*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1906*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1907*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1908*89c4ff92SAndroid Build Coastguard Worker };
1909*89c4ff92SAndroid Build Coastguard Worker
1910*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1911*89c4ff92SAndroid Build Coastguard Worker "Reference Mean: input type not supported.");
1912*89c4ff92SAndroid Build Coastguard Worker
1913*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1914*89c4ff92SAndroid Build Coastguard Worker "Reference Mean: input and output types are mismatched");
1915*89c4ff92SAndroid Build Coastguard Worker
1916*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_KeepDims)
1917*89c4ff92SAndroid Build Coastguard Worker {
1918*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1919*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1920*89c4ff92SAndroid Build Coastguard Worker CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1921*89c4ff92SAndroid Build Coastguard Worker output.GetNumDimensions(),
1922*89c4ff92SAndroid Build Coastguard Worker meanLayerStr, outputTensorStr).data());
1923*89c4ff92SAndroid Build Coastguard Worker }
1924*89c4ff92SAndroid Build Coastguard Worker else if (descriptor.m_Axis.empty())
1925*89c4ff92SAndroid Build Coastguard Worker {
1926*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1927*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1928*89c4ff92SAndroid Build Coastguard Worker CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1929*89c4ff92SAndroid Build Coastguard Worker meanLayerStr, outputTensorStr).data());
1930*89c4ff92SAndroid Build Coastguard Worker }
1931*89c4ff92SAndroid Build Coastguard Worker else
1932*89c4ff92SAndroid Build Coastguard Worker {
1933*89c4ff92SAndroid Build Coastguard Worker auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1934*89c4ff92SAndroid Build Coastguard Worker
1935*89c4ff92SAndroid Build Coastguard Worker if (outputDim > 0)
1936*89c4ff92SAndroid Build Coastguard Worker {
1937*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1938*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1939*89c4ff92SAndroid Build Coastguard Worker CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1940*89c4ff92SAndroid Build Coastguard Worker meanLayerStr, outputTensorStr).data());
1941*89c4ff92SAndroid Build Coastguard Worker }
1942*89c4ff92SAndroid Build Coastguard Worker else
1943*89c4ff92SAndroid Build Coastguard Worker {
1944*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1945*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1946*89c4ff92SAndroid Build Coastguard Worker CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1947*89c4ff92SAndroid Build Coastguard Worker meanLayerStr, outputTensorStr).data());
1948*89c4ff92SAndroid Build Coastguard Worker }
1949*89c4ff92SAndroid Build Coastguard Worker }
1950*89c4ff92SAndroid Build Coastguard Worker
1951*89c4ff92SAndroid Build Coastguard Worker return supported;
1952*89c4ff92SAndroid Build Coastguard Worker }
1953*89c4ff92SAndroid Build Coastguard Worker
IsMemCopySupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1954*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1955*89c4ff92SAndroid Build Coastguard Worker const TensorInfo &output,
1956*89c4ff92SAndroid Build Coastguard Worker Optional<std::string &> reasonIfUnsupported) const
1957*89c4ff92SAndroid Build Coastguard Worker {
1958*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1959*89c4ff92SAndroid Build Coastguard Worker
1960*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
1961*89c4ff92SAndroid Build Coastguard Worker {
1962*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1963*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1964*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1965*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1966*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1967*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1968*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean
1969*89c4ff92SAndroid Build Coastguard Worker };
1970*89c4ff92SAndroid Build Coastguard Worker
1971*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1972*89c4ff92SAndroid Build Coastguard Worker "Reference MemCopy: input type not supported");
1973*89c4ff92SAndroid Build Coastguard Worker
1974*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1975*89c4ff92SAndroid Build Coastguard Worker "Reference MemCopy: output type not supported");
1976*89c4ff92SAndroid Build Coastguard Worker
1977*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1978*89c4ff92SAndroid Build Coastguard Worker "Reference MemCopy: input and output types are mismatched");
1979*89c4ff92SAndroid Build Coastguard Worker
1980*89c4ff92SAndroid Build Coastguard Worker return supported;
1981*89c4ff92SAndroid Build Coastguard Worker }
1982*89c4ff92SAndroid Build Coastguard Worker
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1983*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1984*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1985*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1986*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1987*89c4ff92SAndroid Build Coastguard Worker {
1988*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
1989*89c4ff92SAndroid Build Coastguard Worker
1990*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes = {
1991*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1992*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1993*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1994*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1995*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1996*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1997*89c4ff92SAndroid Build Coastguard Worker };
1998*89c4ff92SAndroid Build Coastguard Worker
1999*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2000*89c4ff92SAndroid Build Coastguard Worker "Reference minimum: input 0 is not a supported type.");
2001*89c4ff92SAndroid Build Coastguard Worker
2002*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2003*89c4ff92SAndroid Build Coastguard Worker "Reference minimum: input 1 is not a supported type.");
2004*89c4ff92SAndroid Build Coastguard Worker
2005*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2006*89c4ff92SAndroid Build Coastguard Worker "Reference minimum: output is not a supported type.");
2007*89c4ff92SAndroid Build Coastguard Worker
2008*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2009*89c4ff92SAndroid Build Coastguard Worker "Reference minimum: input 0 and Input 1 types are mismatched");
2010*89c4ff92SAndroid Build Coastguard Worker
2011*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2012*89c4ff92SAndroid Build Coastguard Worker "Reference minimum: input and output types are mismatched");
2013*89c4ff92SAndroid Build Coastguard Worker
2014*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2015*89c4ff92SAndroid Build Coastguard Worker "Reference minimum: shapes are not suitable for implicit broadcast.");
2016*89c4ff92SAndroid Build Coastguard Worker
2017*89c4ff92SAndroid Build Coastguard Worker return supported;
2018*89c4ff92SAndroid Build Coastguard Worker }
2019*89c4ff92SAndroid Build Coastguard Worker
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2020*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2021*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
2022*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2023*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2024*89c4ff92SAndroid Build Coastguard Worker {
2025*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2026*89c4ff92SAndroid Build Coastguard Worker
2027*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes = {
2028*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2029*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2030*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2031*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2032*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2033*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2034*89c4ff92SAndroid Build Coastguard Worker };
2035*89c4ff92SAndroid Build Coastguard Worker
2036*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2037*89c4ff92SAndroid Build Coastguard Worker "Reference multiplication: input 0 is not a supported type.");
2038*89c4ff92SAndroid Build Coastguard Worker
2039*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2040*89c4ff92SAndroid Build Coastguard Worker "Reference multiplication: input 1 is not a supported type.");
2041*89c4ff92SAndroid Build Coastguard Worker
2042*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2043*89c4ff92SAndroid Build Coastguard Worker "Reference multiplication: output is not a supported type.");
2044*89c4ff92SAndroid Build Coastguard Worker
2045*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2046*89c4ff92SAndroid Build Coastguard Worker "Reference multiplication: input 0 and Input 1 types are mismatched");
2047*89c4ff92SAndroid Build Coastguard Worker
2048*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2049*89c4ff92SAndroid Build Coastguard Worker "Reference multiplication: input and output types are mismatched");
2050*89c4ff92SAndroid Build Coastguard Worker
2051*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2052*89c4ff92SAndroid Build Coastguard Worker "Reference multiplication: shapes are not suitable for implicit broadcast.");
2053*89c4ff92SAndroid Build Coastguard Worker
2054*89c4ff92SAndroid Build Coastguard Worker return supported;
2055*89c4ff92SAndroid Build Coastguard Worker }
2056*89c4ff92SAndroid Build Coastguard Worker
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2057*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2058*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2059*89c4ff92SAndroid Build Coastguard Worker const NormalizationDescriptor& descriptor,
2060*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2061*89c4ff92SAndroid Build Coastguard Worker {
2062*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2063*89c4ff92SAndroid Build Coastguard Worker
2064*89c4ff92SAndroid Build Coastguard Worker // Define supported types
2065*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes =
2066*89c4ff92SAndroid Build Coastguard Worker {
2067*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2068*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2069*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2070*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2071*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2072*89c4ff92SAndroid Build Coastguard Worker };
2073*89c4ff92SAndroid Build Coastguard Worker
2074*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2075*89c4ff92SAndroid Build Coastguard Worker
2076*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2077*89c4ff92SAndroid Build Coastguard Worker "Reference normalization: input type not supported.");
2078*89c4ff92SAndroid Build Coastguard Worker
2079*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2080*89c4ff92SAndroid Build Coastguard Worker "Reference normalization: output type not supported.");
2081*89c4ff92SAndroid Build Coastguard Worker
2082*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2083*89c4ff92SAndroid Build Coastguard Worker "Reference normalization: input and output shapes have different "
2084*89c4ff92SAndroid Build Coastguard Worker "num total elements.");
2085*89c4ff92SAndroid Build Coastguard Worker
2086*89c4ff92SAndroid Build Coastguard Worker return supported;
2087*89c4ff92SAndroid Build Coastguard Worker }
2088*89c4ff92SAndroid Build Coastguard Worker
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const2089*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2090*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported*/) const
2091*89c4ff92SAndroid Build Coastguard Worker {
2092*89c4ff92SAndroid Build Coastguard Worker return true;
2093*89c4ff92SAndroid Build Coastguard Worker }
2094*89c4ff92SAndroid Build Coastguard Worker
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2095*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2096*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2097*89c4ff92SAndroid Build Coastguard Worker const PadDescriptor& descriptor,
2098*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2099*89c4ff92SAndroid Build Coastguard Worker {
2100*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2101*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2102*89c4ff92SAndroid Build Coastguard Worker
2103*89c4ff92SAndroid Build Coastguard Worker // Define supported output and inputs types.
2104*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2105*89c4ff92SAndroid Build Coastguard Worker {
2106*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2107*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2108*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2109*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2110*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2111*89c4ff92SAndroid Build Coastguard Worker };
2112*89c4ff92SAndroid Build Coastguard Worker
2113*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2114*89c4ff92SAndroid Build Coastguard Worker "Reference pad: input is not a supported type.");
2115*89c4ff92SAndroid Build Coastguard Worker
2116*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2117*89c4ff92SAndroid Build Coastguard Worker "Reference pad: output is not a supported type.");
2118*89c4ff92SAndroid Build Coastguard Worker
2119*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2120*89c4ff92SAndroid Build Coastguard Worker "Reference pad: input and output types are mismatched.");
2121*89c4ff92SAndroid Build Coastguard Worker
2122*89c4ff92SAndroid Build Coastguard Worker return supported;
2123*89c4ff92SAndroid Build Coastguard Worker }
2124*89c4ff92SAndroid Build Coastguard Worker
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2125*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2126*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2127*89c4ff92SAndroid Build Coastguard Worker const PermuteDescriptor& descriptor,
2128*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2129*89c4ff92SAndroid Build Coastguard Worker {
2130*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2131*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2132*89c4ff92SAndroid Build Coastguard Worker
2133*89c4ff92SAndroid Build Coastguard Worker // Define supported output and inputs types.
2134*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes =
2135*89c4ff92SAndroid Build Coastguard Worker {
2136*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2137*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2138*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2139*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2140*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2141*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2142*89c4ff92SAndroid Build Coastguard Worker };
2143*89c4ff92SAndroid Build Coastguard Worker
2144*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145*89c4ff92SAndroid Build Coastguard Worker "Reference permute: input is not a supported type.");
2146*89c4ff92SAndroid Build Coastguard Worker
2147*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148*89c4ff92SAndroid Build Coastguard Worker "Reference permute: output is not a supported type.");
2149*89c4ff92SAndroid Build Coastguard Worker
2150*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151*89c4ff92SAndroid Build Coastguard Worker "Reference permute: input and output types are mismatched.");
2152*89c4ff92SAndroid Build Coastguard Worker
2153*89c4ff92SAndroid Build Coastguard Worker return supported;
2154*89c4ff92SAndroid Build Coastguard Worker }
2155*89c4ff92SAndroid Build Coastguard Worker
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2156*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2157*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2158*89c4ff92SAndroid Build Coastguard Worker const Pooling2dDescriptor& descriptor,
2159*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2160*89c4ff92SAndroid Build Coastguard Worker {
2161*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2162*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2163*89c4ff92SAndroid Build Coastguard Worker
2164*89c4ff92SAndroid Build Coastguard Worker // Define supported output and inputs types.
2165*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2166*89c4ff92SAndroid Build Coastguard Worker {
2167*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2168*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2169*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2170*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2171*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2172*89c4ff92SAndroid Build Coastguard Worker };
2173*89c4ff92SAndroid Build Coastguard Worker
2174*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2175*89c4ff92SAndroid Build Coastguard Worker "Reference poolind2d: input is not a supported type.");
2176*89c4ff92SAndroid Build Coastguard Worker
2177*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2178*89c4ff92SAndroid Build Coastguard Worker "Reference poolind2d: output is not a supported type.");
2179*89c4ff92SAndroid Build Coastguard Worker
2180*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2181*89c4ff92SAndroid Build Coastguard Worker "Reference poolind2d: input and output types are mismatched.");
2182*89c4ff92SAndroid Build Coastguard Worker
2183*89c4ff92SAndroid Build Coastguard Worker return supported;
2184*89c4ff92SAndroid Build Coastguard Worker }
2185*89c4ff92SAndroid Build Coastguard Worker
IsPooling3dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling3dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2186*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2187*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2188*89c4ff92SAndroid Build Coastguard Worker const Pooling3dDescriptor& descriptor,
2189*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2190*89c4ff92SAndroid Build Coastguard Worker {
2191*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2192*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2193*89c4ff92SAndroid Build Coastguard Worker
2194*89c4ff92SAndroid Build Coastguard Worker // Define supported output and inputs types.
2195*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2196*89c4ff92SAndroid Build Coastguard Worker {
2197*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2198*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2199*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2200*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2201*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2202*89c4ff92SAndroid Build Coastguard Worker };
2203*89c4ff92SAndroid Build Coastguard Worker
2204*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2205*89c4ff92SAndroid Build Coastguard Worker "Reference poolind3d: input is not a supported type.");
2206*89c4ff92SAndroid Build Coastguard Worker
2207*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2208*89c4ff92SAndroid Build Coastguard Worker "Reference poolind3d: output is not a supported type.");
2209*89c4ff92SAndroid Build Coastguard Worker
2210*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2211*89c4ff92SAndroid Build Coastguard Worker "Reference poolind3d: input and output types are mismatched.");
2212*89c4ff92SAndroid Build Coastguard Worker
2213*89c4ff92SAndroid Build Coastguard Worker return supported;
2214*89c4ff92SAndroid Build Coastguard Worker }
2215*89c4ff92SAndroid Build Coastguard Worker
2216*89c4ff92SAndroid Build Coastguard Worker
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2217*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2218*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& previousOutputIn,
2219*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& previousCellStateIn,
2220*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateOut,
2221*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
2222*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2223*89c4ff92SAndroid Build Coastguard Worker const QLstmDescriptor& descriptor,
2224*89c4ff92SAndroid Build Coastguard Worker const LstmInputParamsInfo& paramsInfo,
2225*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2226*89c4ff92SAndroid Build Coastguard Worker {
2227*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(input);
2228*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(previousOutputIn);
2229*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(previousCellStateIn);
2230*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(outputStateOut);
2231*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(cellStateOut);
2232*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(output);
2233*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2234*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(paramsInfo);
2235*89c4ff92SAndroid Build Coastguard Worker
2236*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(reasonIfUnsupported);
2237*89c4ff92SAndroid Build Coastguard Worker
2238*89c4ff92SAndroid Build Coastguard Worker return true;
2239*89c4ff92SAndroid Build Coastguard Worker }
2240*89c4ff92SAndroid Build Coastguard Worker
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2241*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2242*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2243*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2244*89c4ff92SAndroid Build Coastguard Worker {
2245*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2246*89c4ff92SAndroid Build Coastguard Worker
2247*89c4ff92SAndroid Build Coastguard Worker // Define supported input types.
2248*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedInputTypes = {
2249*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2250*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2251*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2252*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2253*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2254*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2255*89c4ff92SAndroid Build Coastguard Worker };
2256*89c4ff92SAndroid Build Coastguard Worker
2257*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2258*89c4ff92SAndroid Build Coastguard Worker "Reference quantize: input type not supported.");
2259*89c4ff92SAndroid Build Coastguard Worker
2260*89c4ff92SAndroid Build Coastguard Worker // Define supported output types.
2261*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,4> supportedOutputTypes = {
2262*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2263*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2264*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2265*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2266*89c4ff92SAndroid Build Coastguard Worker };
2267*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2268*89c4ff92SAndroid Build Coastguard Worker "Reference quantize: output type not supported.");
2269*89c4ff92SAndroid Build Coastguard Worker
2270*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2271*89c4ff92SAndroid Build Coastguard Worker "Reference quantize: input and output shapes have different num total elements.");
2272*89c4ff92SAndroid Build Coastguard Worker
2273*89c4ff92SAndroid Build Coastguard Worker return supported;
2274*89c4ff92SAndroid Build Coastguard Worker }
2275*89c4ff92SAndroid Build Coastguard Worker
IsRankSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2276*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2277*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2278*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2279*89c4ff92SAndroid Build Coastguard Worker {
2280*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(input);
2281*89c4ff92SAndroid Build Coastguard Worker // Define supported output types.
2282*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,1> supportedOutputTypes =
2283*89c4ff92SAndroid Build Coastguard Worker {
2284*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
2285*89c4ff92SAndroid Build Coastguard Worker };
2286*89c4ff92SAndroid Build Coastguard Worker
2287*89c4ff92SAndroid Build Coastguard Worker return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2288*89c4ff92SAndroid Build Coastguard Worker "Reference rank: input type not supported.");
2289*89c4ff92SAndroid Build Coastguard Worker }
2290*89c4ff92SAndroid Build Coastguard Worker
IsReduceSupported(const TensorInfo & input,const TensorInfo & output,const ReduceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2291*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2292*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2293*89c4ff92SAndroid Build Coastguard Worker const ReduceDescriptor& descriptor,
2294*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2295*89c4ff92SAndroid Build Coastguard Worker {
2296*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2297*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2298*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
2299*89c4ff92SAndroid Build Coastguard Worker {
2300*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2301*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2302*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2303*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2304*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2305*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2306*89c4ff92SAndroid Build Coastguard Worker };
2307*89c4ff92SAndroid Build Coastguard Worker
2308*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2309*89c4ff92SAndroid Build Coastguard Worker "Reference Reduce: input type not supported");
2310*89c4ff92SAndroid Build Coastguard Worker
2311*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2312*89c4ff92SAndroid Build Coastguard Worker "Reference Reduce: output type not supported");
2313*89c4ff92SAndroid Build Coastguard Worker
2314*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2315*89c4ff92SAndroid Build Coastguard Worker "Reference Reduce: input and output types not matching");
2316*89c4ff92SAndroid Build Coastguard Worker
2317*89c4ff92SAndroid Build Coastguard Worker return supported;
2318*89c4ff92SAndroid Build Coastguard Worker }
2319*89c4ff92SAndroid Build Coastguard Worker
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2320*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
2321*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2322*89c4ff92SAndroid Build Coastguard Worker const ReshapeDescriptor& descriptor,
2323*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2324*89c4ff92SAndroid Build Coastguard Worker {
2325*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(output);
2326*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2327*89c4ff92SAndroid Build Coastguard Worker // Define supported output types.
2328*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,8> supportedOutputTypes =
2329*89c4ff92SAndroid Build Coastguard Worker {
2330*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2331*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2332*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2333*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
2334*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2335*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2336*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2337*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean
2338*89c4ff92SAndroid Build Coastguard Worker };
2339*89c4ff92SAndroid Build Coastguard Worker
2340*89c4ff92SAndroid Build Coastguard Worker return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2341*89c4ff92SAndroid Build Coastguard Worker "Reference reshape: input type not supported.");
2342*89c4ff92SAndroid Build Coastguard Worker }
2343*89c4ff92SAndroid Build Coastguard Worker
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2344*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2345*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2346*89c4ff92SAndroid Build Coastguard Worker const ResizeDescriptor& descriptor,
2347*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2348*89c4ff92SAndroid Build Coastguard Worker {
2349*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2350*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2351*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2352*89c4ff92SAndroid Build Coastguard Worker {
2353*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2354*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2355*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2356*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2357*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2358*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2359*89c4ff92SAndroid Build Coastguard Worker };
2360*89c4ff92SAndroid Build Coastguard Worker
2361*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2362*89c4ff92SAndroid Build Coastguard Worker "Reference Resize: input type not supported");
2363*89c4ff92SAndroid Build Coastguard Worker
2364*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2365*89c4ff92SAndroid Build Coastguard Worker "Reference Resize: output type not supported");
2366*89c4ff92SAndroid Build Coastguard Worker
2367*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2368*89c4ff92SAndroid Build Coastguard Worker "Reference Resize: input and output types not matching");
2369*89c4ff92SAndroid Build Coastguard Worker
2370*89c4ff92SAndroid Build Coastguard Worker return supported;
2371*89c4ff92SAndroid Build Coastguard Worker }
2372*89c4ff92SAndroid Build Coastguard Worker
IsShapeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2373*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2374*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2375*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2376*89c4ff92SAndroid Build Coastguard Worker {
2377*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(input);
2378*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2379*89c4ff92SAndroid Build Coastguard Worker
2380*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 1> supportedTypes =
2381*89c4ff92SAndroid Build Coastguard Worker {
2382*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2383*89c4ff92SAndroid Build Coastguard Worker };
2384*89c4ff92SAndroid Build Coastguard Worker
2385*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2386*89c4ff92SAndroid Build Coastguard Worker "Reference Shape: output type not supported");
2387*89c4ff92SAndroid Build Coastguard Worker
2388*89c4ff92SAndroid Build Coastguard Worker return supported;
2389*89c4ff92SAndroid Build Coastguard Worker }
2390*89c4ff92SAndroid Build Coastguard Worker
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2391*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2392*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2393*89c4ff92SAndroid Build Coastguard Worker const SliceDescriptor& descriptor,
2394*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2395*89c4ff92SAndroid Build Coastguard Worker {
2396*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2397*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2398*89c4ff92SAndroid Build Coastguard Worker
2399*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 5> supportedTypes =
2400*89c4ff92SAndroid Build Coastguard Worker {
2401*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2402*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2403*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2404*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2405*89c4ff92SAndroid Build Coastguard Worker };
2406*89c4ff92SAndroid Build Coastguard Worker
2407*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2408*89c4ff92SAndroid Build Coastguard Worker "Reference Slice: input type not supported");
2409*89c4ff92SAndroid Build Coastguard Worker
2410*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2411*89c4ff92SAndroid Build Coastguard Worker "Reference Slice: output type not supported");
2412*89c4ff92SAndroid Build Coastguard Worker
2413*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2414*89c4ff92SAndroid Build Coastguard Worker "Reference Slice: input and output types are mismatched");
2415*89c4ff92SAndroid Build Coastguard Worker
2416*89c4ff92SAndroid Build Coastguard Worker return supported;
2417*89c4ff92SAndroid Build Coastguard Worker }
2418*89c4ff92SAndroid Build Coastguard Worker
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2419*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2420*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2421*89c4ff92SAndroid Build Coastguard Worker const SoftmaxDescriptor& descriptor,
2422*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2423*89c4ff92SAndroid Build Coastguard Worker {
2424*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2425*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2426*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
2427*89c4ff92SAndroid Build Coastguard Worker {
2428*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2429*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2430*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2431*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2432*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2433*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2434*89c4ff92SAndroid Build Coastguard Worker };
2435*89c4ff92SAndroid Build Coastguard Worker
2436*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2437*89c4ff92SAndroid Build Coastguard Worker "Reference Softmax: output type not supported");
2438*89c4ff92SAndroid Build Coastguard Worker
2439*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2440*89c4ff92SAndroid Build Coastguard Worker "Reference Softmax: input type not supported");
2441*89c4ff92SAndroid Build Coastguard Worker
2442*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2443*89c4ff92SAndroid Build Coastguard Worker "Reference Softmax: input type not supported");
2444*89c4ff92SAndroid Build Coastguard Worker
2445*89c4ff92SAndroid Build Coastguard Worker return supported;
2446*89c4ff92SAndroid Build Coastguard Worker }
2447*89c4ff92SAndroid Build Coastguard Worker
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2448*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2449*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2450*89c4ff92SAndroid Build Coastguard Worker const SpaceToBatchNdDescriptor& descriptor,
2451*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2452*89c4ff92SAndroid Build Coastguard Worker {
2453*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2454*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2455*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2456*89c4ff92SAndroid Build Coastguard Worker {
2457*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2458*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2459*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2460*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2461*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2462*89c4ff92SAndroid Build Coastguard Worker };
2463*89c4ff92SAndroid Build Coastguard Worker
2464*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2465*89c4ff92SAndroid Build Coastguard Worker "Reference SpaceToBatchNd: input type not supported");
2466*89c4ff92SAndroid Build Coastguard Worker
2467*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2468*89c4ff92SAndroid Build Coastguard Worker "Reference SpaceToBatchNd: output type not supported");
2469*89c4ff92SAndroid Build Coastguard Worker
2470*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2471*89c4ff92SAndroid Build Coastguard Worker "Reference SpaceToBatchNd: input and output types are mismatched");
2472*89c4ff92SAndroid Build Coastguard Worker
2473*89c4ff92SAndroid Build Coastguard Worker return supported;
2474*89c4ff92SAndroid Build Coastguard Worker }
2475*89c4ff92SAndroid Build Coastguard Worker
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2476*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
2477*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2478*89c4ff92SAndroid Build Coastguard Worker const SpaceToDepthDescriptor& descriptor,
2479*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2480*89c4ff92SAndroid Build Coastguard Worker {
2481*89c4ff92SAndroid Build Coastguard Worker
2482*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2483*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2484*89c4ff92SAndroid Build Coastguard Worker
2485*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2486*89c4ff92SAndroid Build Coastguard Worker {
2487*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2488*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2489*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2490*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2491*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2492*89c4ff92SAndroid Build Coastguard Worker };
2493*89c4ff92SAndroid Build Coastguard Worker
2494*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2495*89c4ff92SAndroid Build Coastguard Worker "Reference SpaceToDepth: input type not supported");
2496*89c4ff92SAndroid Build Coastguard Worker
2497*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2498*89c4ff92SAndroid Build Coastguard Worker "Reference SpaceToDepth: output type not supported");
2499*89c4ff92SAndroid Build Coastguard Worker
2500*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2501*89c4ff92SAndroid Build Coastguard Worker "Reference SpaceToDepth: input and output types are mismatched");
2502*89c4ff92SAndroid Build Coastguard Worker
2503*89c4ff92SAndroid Build Coastguard Worker return supported;
2504*89c4ff92SAndroid Build Coastguard Worker }
2505*89c4ff92SAndroid Build Coastguard Worker
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2506*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2507*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2508*89c4ff92SAndroid Build Coastguard Worker const ViewsDescriptor& descriptor,
2509*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2510*89c4ff92SAndroid Build Coastguard Worker {
2511*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2512*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2513*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,6> supportedTypes =
2514*89c4ff92SAndroid Build Coastguard Worker {
2515*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2516*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2517*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2518*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2519*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2520*89c4ff92SAndroid Build Coastguard Worker };
2521*89c4ff92SAndroid Build Coastguard Worker
2522*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2523*89c4ff92SAndroid Build Coastguard Worker "Reference splitter: output type not supported");
2524*89c4ff92SAndroid Build Coastguard Worker for (const TensorInfo& output : outputs)
2525*89c4ff92SAndroid Build Coastguard Worker {
2526*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2527*89c4ff92SAndroid Build Coastguard Worker "Reference splitter: input type not supported");
2528*89c4ff92SAndroid Build Coastguard Worker
2529*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2530*89c4ff92SAndroid Build Coastguard Worker "Reference splitter: input and output types mismatched.");
2531*89c4ff92SAndroid Build Coastguard Worker }
2532*89c4ff92SAndroid Build Coastguard Worker
2533*89c4ff92SAndroid Build Coastguard Worker return supported;
2534*89c4ff92SAndroid Build Coastguard Worker }
2535*89c4ff92SAndroid Build Coastguard Worker
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2536*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2537*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2538*89c4ff92SAndroid Build Coastguard Worker const StackDescriptor& descriptor,
2539*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2540*89c4ff92SAndroid Build Coastguard Worker {
2541*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2542*89c4ff92SAndroid Build Coastguard Worker
2543*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2544*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
2545*89c4ff92SAndroid Build Coastguard Worker {
2546*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2547*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2548*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2549*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2550*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2551*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2552*89c4ff92SAndroid Build Coastguard Worker };
2553*89c4ff92SAndroid Build Coastguard Worker
2554*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2555*89c4ff92SAndroid Build Coastguard Worker "Reference stack: output type not supported");
2556*89c4ff92SAndroid Build Coastguard Worker for (const TensorInfo* input : inputs)
2557*89c4ff92SAndroid Build Coastguard Worker {
2558*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(input != nullptr);
2559*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2560*89c4ff92SAndroid Build Coastguard Worker "Reference stack: input type not supported");
2561*89c4ff92SAndroid Build Coastguard Worker
2562*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2563*89c4ff92SAndroid Build Coastguard Worker "Reference stack: input and output types mismatched.");
2564*89c4ff92SAndroid Build Coastguard Worker }
2565*89c4ff92SAndroid Build Coastguard Worker
2566*89c4ff92SAndroid Build Coastguard Worker return supported;
2567*89c4ff92SAndroid Build Coastguard Worker }
2568*89c4ff92SAndroid Build Coastguard Worker
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2569*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2570*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2571*89c4ff92SAndroid Build Coastguard Worker const StridedSliceDescriptor& descriptor,
2572*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2573*89c4ff92SAndroid Build Coastguard Worker {
2574*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2575*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2576*89c4ff92SAndroid Build Coastguard Worker
2577*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,5> supportedTypes =
2578*89c4ff92SAndroid Build Coastguard Worker {
2579*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2580*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2581*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2582*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2583*89c4ff92SAndroid Build Coastguard Worker };
2584*89c4ff92SAndroid Build Coastguard Worker
2585*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2586*89c4ff92SAndroid Build Coastguard Worker "Reference StridedSlice: input type not supported");
2587*89c4ff92SAndroid Build Coastguard Worker
2588*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2589*89c4ff92SAndroid Build Coastguard Worker "Reference StridedSlice: output type not supported");
2590*89c4ff92SAndroid Build Coastguard Worker
2591*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2592*89c4ff92SAndroid Build Coastguard Worker "Reference StridedSlice: input and output types are mismatched");
2593*89c4ff92SAndroid Build Coastguard Worker
2594*89c4ff92SAndroid Build Coastguard Worker return supported;
2595*89c4ff92SAndroid Build Coastguard Worker }
2596*89c4ff92SAndroid Build Coastguard Worker
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2597*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2598*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
2599*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2600*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2601*89c4ff92SAndroid Build Coastguard Worker {
2602*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2603*89c4ff92SAndroid Build Coastguard Worker
2604*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes = {
2605*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2606*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2607*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2608*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2609*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2610*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2611*89c4ff92SAndroid Build Coastguard Worker };
2612*89c4ff92SAndroid Build Coastguard Worker
2613*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2614*89c4ff92SAndroid Build Coastguard Worker "Reference subtraction: input 0 is not a supported type.");
2615*89c4ff92SAndroid Build Coastguard Worker
2616*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2617*89c4ff92SAndroid Build Coastguard Worker "Reference subtraction: input 1 is not a supported type.");
2618*89c4ff92SAndroid Build Coastguard Worker
2619*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2620*89c4ff92SAndroid Build Coastguard Worker "Reference subtraction: output is not a supported type.");
2621*89c4ff92SAndroid Build Coastguard Worker
2622*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2623*89c4ff92SAndroid Build Coastguard Worker "Reference subtraction: input 0 and Input 1 types are mismatched");
2624*89c4ff92SAndroid Build Coastguard Worker
2625*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2626*89c4ff92SAndroid Build Coastguard Worker "Reference subtraction: input and output types are mismatched");
2627*89c4ff92SAndroid Build Coastguard Worker
2628*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2629*89c4ff92SAndroid Build Coastguard Worker "Reference subtraction: shapes are not suitable for implicit broadcast.");
2630*89c4ff92SAndroid Build Coastguard Worker
2631*89c4ff92SAndroid Build Coastguard Worker return supported;
2632*89c4ff92SAndroid Build Coastguard Worker }
2633*89c4ff92SAndroid Build Coastguard Worker
IsPreluSupported(const TensorInfo & input,const TensorInfo & alpha,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2634*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2635*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& alpha,
2636*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2637*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2638*89c4ff92SAndroid Build Coastguard Worker {
2639*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2640*89c4ff92SAndroid Build Coastguard Worker
2641*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes
2642*89c4ff92SAndroid Build Coastguard Worker {
2643*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2644*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2645*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2646*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2647*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2648*89c4ff92SAndroid Build Coastguard Worker };
2649*89c4ff92SAndroid Build Coastguard Worker
2650*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2651*89c4ff92SAndroid Build Coastguard Worker "PReLU: input is not a supported type.");
2652*89c4ff92SAndroid Build Coastguard Worker
2653*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2654*89c4ff92SAndroid Build Coastguard Worker "PReLU: alpha is not a supported type.");
2655*89c4ff92SAndroid Build Coastguard Worker
2656*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2657*89c4ff92SAndroid Build Coastguard Worker "PReLU: output is not a supported type.");
2658*89c4ff92SAndroid Build Coastguard Worker
2659*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2660*89c4ff92SAndroid Build Coastguard Worker "PReLU: input, alpha and output types are mismatched");
2661*89c4ff92SAndroid Build Coastguard Worker
2662*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2663*89c4ff92SAndroid Build Coastguard Worker "PReLU: shapes are not suitable for implicit broadcast");
2664*89c4ff92SAndroid Build Coastguard Worker
2665*89c4ff92SAndroid Build Coastguard Worker return supported;
2666*89c4ff92SAndroid Build Coastguard Worker }
2667*89c4ff92SAndroid Build Coastguard Worker
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const2668*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2669*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2670*89c4ff92SAndroid Build Coastguard Worker const TransposeConvolution2dDescriptor& descriptor,
2671*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
2672*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
2673*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2674*89c4ff92SAndroid Build Coastguard Worker {
2675*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2676*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2677*89c4ff92SAndroid Build Coastguard Worker
2678*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,7> supportedTypes =
2679*89c4ff92SAndroid Build Coastguard Worker {
2680*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2681*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2682*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2683*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2684*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2685*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2686*89c4ff92SAndroid Build Coastguard Worker };
2687*89c4ff92SAndroid Build Coastguard Worker
2688*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2689*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: input is not a supported type.");
2690*89c4ff92SAndroid Build Coastguard Worker
2691*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2692*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: output is not a supported type.");
2693*89c4ff92SAndroid Build Coastguard Worker
2694*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2695*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: input and output types mismatched.");
2696*89c4ff92SAndroid Build Coastguard Worker
2697*89c4ff92SAndroid Build Coastguard Worker
2698*89c4ff92SAndroid Build Coastguard Worker const DataType inputType = input.GetDataType();
2699*89c4ff92SAndroid Build Coastguard Worker if (IsQuantized8BitType(inputType))
2700*89c4ff92SAndroid Build Coastguard Worker {
2701*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedWeightTypes =
2702*89c4ff92SAndroid Build Coastguard Worker {
2703*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2704*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2705*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
2706*89c4ff92SAndroid Build Coastguard Worker };
2707*89c4ff92SAndroid Build Coastguard Worker
2708*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2709*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: weights type not supported for "
2710*89c4ff92SAndroid Build Coastguard Worker "quantized input.");
2711*89c4ff92SAndroid Build Coastguard Worker }
2712*89c4ff92SAndroid Build Coastguard Worker else
2713*89c4ff92SAndroid Build Coastguard Worker {
2714*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2715*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: weights is not a supported type.");
2716*89c4ff92SAndroid Build Coastguard Worker
2717*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2718*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: input and weights types mismatched.");
2719*89c4ff92SAndroid Build Coastguard Worker }
2720*89c4ff92SAndroid Build Coastguard Worker
2721*89c4ff92SAndroid Build Coastguard Worker if (biases.has_value())
2722*89c4ff92SAndroid Build Coastguard Worker {
2723*89c4ff92SAndroid Build Coastguard Worker std::array<DataType,4> biasesSupportedTypes =
2724*89c4ff92SAndroid Build Coastguard Worker {
2725*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2726*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2727*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2728*89c4ff92SAndroid Build Coastguard Worker };
2729*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2730*89c4ff92SAndroid Build Coastguard Worker "Reference TransposeConvolution2d: biases is not a supported type.");
2731*89c4ff92SAndroid Build Coastguard Worker }
2732*89c4ff92SAndroid Build Coastguard Worker
2733*89c4ff92SAndroid Build Coastguard Worker return supported;
2734*89c4ff92SAndroid Build Coastguard Worker }
2735*89c4ff92SAndroid Build Coastguard Worker
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2736*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2737*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2738*89c4ff92SAndroid Build Coastguard Worker const TransposeDescriptor& descriptor,
2739*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2740*89c4ff92SAndroid Build Coastguard Worker {
2741*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2742*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2743*89c4ff92SAndroid Build Coastguard Worker
2744*89c4ff92SAndroid Build Coastguard Worker // Define supported output and inputs types.
2745*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 6> supportedTypes =
2746*89c4ff92SAndroid Build Coastguard Worker {
2747*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2748*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2749*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2750*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2751*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2752*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2753*89c4ff92SAndroid Build Coastguard Worker };
2754*89c4ff92SAndroid Build Coastguard Worker
2755*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2756*89c4ff92SAndroid Build Coastguard Worker "Reference transpose: input is not a supported type.");
2757*89c4ff92SAndroid Build Coastguard Worker
2758*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2759*89c4ff92SAndroid Build Coastguard Worker "Reference transpose: output is not a supported type.");
2760*89c4ff92SAndroid Build Coastguard Worker
2761*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2762*89c4ff92SAndroid Build Coastguard Worker "Reference transpose: input and output types are mismatched.");
2763*89c4ff92SAndroid Build Coastguard Worker
2764*89c4ff92SAndroid Build Coastguard Worker return supported;
2765*89c4ff92SAndroid Build Coastguard Worker }
2766*89c4ff92SAndroid Build Coastguard Worker
IsUnidirectionalSequenceLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const UnidirectionalSequenceLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2767*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2768*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input,
2769*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateIn,
2770*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateIn,
2771*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateOut,
2772*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
2773*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
2774*89c4ff92SAndroid Build Coastguard Worker const UnidirectionalSequenceLstmDescriptor& descriptor,
2775*89c4ff92SAndroid Build Coastguard Worker const LstmInputParamsInfo& paramsInfo,
2776*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
2777*89c4ff92SAndroid Build Coastguard Worker {
2778*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
2779*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(paramsInfo);
2780*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(outputStateIn);
2781*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(cellStateIn);
2782*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(outputStateOut);
2783*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(cellStateOut);
2784*89c4ff92SAndroid Build Coastguard Worker bool supported = true;
2785*89c4ff92SAndroid Build Coastguard Worker
2786*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 2> supportedTypes =
2787*89c4ff92SAndroid Build Coastguard Worker {
2788*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2789*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8
2790*89c4ff92SAndroid Build Coastguard Worker };
2791*89c4ff92SAndroid Build Coastguard Worker
2792*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 2> supportedWeightTypes =
2793*89c4ff92SAndroid Build Coastguard Worker {
2794*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2795*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8
2796*89c4ff92SAndroid Build Coastguard Worker };
2797*89c4ff92SAndroid Build Coastguard Worker
2798*89c4ff92SAndroid Build Coastguard Worker std::array<DataType, 3> supportedBiasTypes =
2799*89c4ff92SAndroid Build Coastguard Worker {
2800*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2801*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2802*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2803*89c4ff92SAndroid Build Coastguard Worker };
2804*89c4ff92SAndroid Build Coastguard Worker
2805*89c4ff92SAndroid Build Coastguard Worker // check inputs and outputs
2806*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2807*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2808*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2809*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: output is not a supported type.");
2810*89c4ff92SAndroid Build Coastguard Worker
2811*89c4ff92SAndroid Build Coastguard Worker // check layer parameters
2812*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2813*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2814*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2815*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2816*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2817*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2818*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2819*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2820*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2821*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2822*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2823*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2824*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2825*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2826*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2827*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2828*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2829*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2830*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2831*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2832*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2833*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2834*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2835*89c4ff92SAndroid Build Coastguard Worker
2836*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2837*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2838*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2839*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2840*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2841*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
2842*89c4ff92SAndroid Build Coastguard Worker if (!descriptor.m_CifgEnabled)
2843*89c4ff92SAndroid Build Coastguard Worker {
2844*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2845*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2846*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2847*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2848*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2849*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2850*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2851*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2852*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2853*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
2854*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_PeepholeEnabled)
2855*89c4ff92SAndroid Build Coastguard Worker {
2856*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2857*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2858*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2859*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2860*89c4ff92SAndroid Build Coastguard Worker }
2861*89c4ff92SAndroid Build Coastguard Worker }
2862*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_PeepholeEnabled)
2863*89c4ff92SAndroid Build Coastguard Worker {
2864*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2865*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2866*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2867*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2868*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2869*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2870*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2871*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2872*89c4ff92SAndroid Build Coastguard Worker }
2873*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_ProjectionEnabled)
2874*89c4ff92SAndroid Build Coastguard Worker {
2875*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2876*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2877*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2878*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2879*89c4ff92SAndroid Build Coastguard Worker if (paramsInfo.m_ProjectionBias != nullptr)
2880*89c4ff92SAndroid Build Coastguard Worker {
2881*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2882*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2883*89c4ff92SAndroid Build Coastguard Worker "are mismatched");
2884*89c4ff92SAndroid Build Coastguard Worker }
2885*89c4ff92SAndroid Build Coastguard Worker }
2886*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_LayerNormEnabled)
2887*89c4ff92SAndroid Build Coastguard Worker {
2888*89c4ff92SAndroid Build Coastguard Worker if (!descriptor.m_CifgEnabled)
2889*89c4ff92SAndroid Build Coastguard Worker {
2890*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2891*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2892*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2893*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2894*89c4ff92SAndroid Build Coastguard Worker }
2895*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2896*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2897*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2898*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2899*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2900*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2901*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2902*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2903*89c4ff92SAndroid Build Coastguard Worker supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2904*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
2905*89c4ff92SAndroid Build Coastguard Worker "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2906*89c4ff92SAndroid Build Coastguard Worker "is not a supported type.");
2907*89c4ff92SAndroid Build Coastguard Worker }
2908*89c4ff92SAndroid Build Coastguard Worker
2909*89c4ff92SAndroid Build Coastguard Worker return supported;
2910*89c4ff92SAndroid Build Coastguard Worker }
2911*89c4ff92SAndroid Build Coastguard Worker
2912*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
2913