xref: /aosp_15_r20/external/armnn/src/backends/reference/RefLayerSupport.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "RefLayerSupport.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <LayerSupportCommon.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/LayerSupportRules.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <vector>
18*89c4ff92SAndroid Build Coastguard Worker #include <array>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker namespace armnn
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker namespace
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker template<typename Float32Func, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeRef(Optional<std::string &> reasonIfUnsupported,DataType dataType,Float32Func floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)27*89c4ff92SAndroid Build Coastguard Worker bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28*89c4ff92SAndroid Build Coastguard Worker                                DataType dataType,
29*89c4ff92SAndroid Build Coastguard Worker                                Float32Func floatFuncPtr,
30*89c4ff92SAndroid Build Coastguard Worker                                Uint8Func uint8FuncPtr,
31*89c4ff92SAndroid Build Coastguard Worker                                Params&&... params)
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker     return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34*89c4ff92SAndroid Build Coastguard Worker                                          dataType,
35*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFunc<Params...>,
36*89c4ff92SAndroid Build Coastguard Worker                                          floatFuncPtr,
37*89c4ff92SAndroid Build Coastguard Worker                                          uint8FuncPtr,
38*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFunc<Params...>,
39*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFunc<Params...>,
40*89c4ff92SAndroid Build Coastguard Worker                                          std::forward<Params>(params)...);
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker namespace
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker 
// Builds the message reported when a tensor does not have the number of dimensions a
// reference layer expects.
//
// expected    Number of dimensions the layer requires.
// actual      Number of dimensions the tensor actually has.
// layerStr    Layer name, inserted right after the "Reference " prefix.
// tensorName  Name of the offending tensor, quoted in the message.
//
// Returns a string of the form:
//   "Reference <layerStr>: Expected <expected> dimensions but got <actual> dimensions instead,
//    for the '<tensorName>' tensor."
//
// Both strings are only read, so they are taken by const reference; this keeps existing callers
// working and additionally allows literals and temporaries to be passed.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
}
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
60*89c4ff92SAndroid Build Coastguard Worker 
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmInputParamsInfo,Optional<std::string &> reasonIfUnsupported) const61*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62*89c4ff92SAndroid Build Coastguard Worker                                        const std::vector<TensorInfo>& infos,
63*89c4ff92SAndroid Build Coastguard Worker                                        const BaseDescriptor& descriptor,
64*89c4ff92SAndroid Build Coastguard Worker                                        const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65*89c4ff92SAndroid Build Coastguard Worker                                        const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker     switch (type)
69*89c4ff92SAndroid Build Coastguard Worker     {
70*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Activation:
71*89c4ff92SAndroid Build Coastguard Worker             return IsActivationSupported(infos[0],
72*89c4ff92SAndroid Build Coastguard Worker                                          infos[1],
73*89c4ff92SAndroid Build Coastguard Worker                                          *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74*89c4ff92SAndroid Build Coastguard Worker                                          reasonIfUnsupported);
75*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Addition:
76*89c4ff92SAndroid Build Coastguard Worker             return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ArgMinMax:
78*89c4ff92SAndroid Build Coastguard Worker             return IsArgMinMaxSupported(infos[0],
79*89c4ff92SAndroid Build Coastguard Worker                                         infos[1],
80*89c4ff92SAndroid Build Coastguard Worker                                         *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81*89c4ff92SAndroid Build Coastguard Worker                                         reasonIfUnsupported);
82*89c4ff92SAndroid Build Coastguard Worker         case LayerType::BatchMatMul:
83*89c4ff92SAndroid Build Coastguard Worker             return IsBatchMatMulSupported(infos[0],
84*89c4ff92SAndroid Build Coastguard Worker                                           infos[1],
85*89c4ff92SAndroid Build Coastguard Worker                                           infos[2],
86*89c4ff92SAndroid Build Coastguard Worker                                           *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported);
88*89c4ff92SAndroid Build Coastguard Worker         case LayerType::BatchNormalization:
89*89c4ff92SAndroid Build Coastguard Worker             return IsBatchNormalizationSupported(infos[0],
90*89c4ff92SAndroid Build Coastguard Worker                                                  infos[1],
91*89c4ff92SAndroid Build Coastguard Worker                                                  infos[2],
92*89c4ff92SAndroid Build Coastguard Worker                                                  infos[3],
93*89c4ff92SAndroid Build Coastguard Worker                                                  infos[4],
94*89c4ff92SAndroid Build Coastguard Worker                                                  infos[5],
95*89c4ff92SAndroid Build Coastguard Worker                                                  *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96*89c4ff92SAndroid Build Coastguard Worker                                                      (&descriptor)),
97*89c4ff92SAndroid Build Coastguard Worker                                                  reasonIfUnsupported);
98*89c4ff92SAndroid Build Coastguard Worker         case LayerType::BatchToSpaceNd:
99*89c4ff92SAndroid Build Coastguard Worker             return IsBatchToSpaceNdSupported(infos[0],
100*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
101*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
103*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Comparison:
104*89c4ff92SAndroid Build Coastguard Worker             return IsComparisonSupported(infos[0],
105*89c4ff92SAndroid Build Coastguard Worker                                          infos[1],
106*89c4ff92SAndroid Build Coastguard Worker                                          infos[2],
107*89c4ff92SAndroid Build Coastguard Worker                                          *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108*89c4ff92SAndroid Build Coastguard Worker                                          reasonIfUnsupported);
109*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Concat:
110*89c4ff92SAndroid Build Coastguard Worker         {
111*89c4ff92SAndroid Build Coastguard Worker             std::vector<const TensorInfo*> inputInfos;
112*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 0; i < (infos.size() - 1); i++)
113*89c4ff92SAndroid Build Coastguard Worker             {
114*89c4ff92SAndroid Build Coastguard Worker                 inputInfos.push_back(&infos[i]);
115*89c4ff92SAndroid Build Coastguard Worker             }
116*89c4ff92SAndroid Build Coastguard Worker             return IsConcatSupported(inputInfos,
117*89c4ff92SAndroid Build Coastguard Worker                                      infos[infos.size() - 1],
118*89c4ff92SAndroid Build Coastguard Worker                                      *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119*89c4ff92SAndroid Build Coastguard Worker                                      reasonIfUnsupported);
120*89c4ff92SAndroid Build Coastguard Worker         }
121*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Constant:
122*89c4ff92SAndroid Build Coastguard Worker             return IsConstantSupported(infos[0], reasonIfUnsupported);
123*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ConvertFp16ToFp32:
124*89c4ff92SAndroid Build Coastguard Worker             return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
125*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ConvertFp32ToFp16:
126*89c4ff92SAndroid Build Coastguard Worker             return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Convolution2d:
128*89c4ff92SAndroid Build Coastguard Worker         {
129*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
130*89c4ff92SAndroid Build Coastguard Worker             {
131*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
133*89c4ff92SAndroid Build Coastguard Worker             }
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
137*89c4ff92SAndroid Build Coastguard Worker             {
138*89c4ff92SAndroid Build Coastguard Worker                 return IsConvolution2dSupported(infos[0],
139*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
140*89c4ff92SAndroid Build Coastguard Worker                                                 desc,
141*89c4ff92SAndroid Build Coastguard Worker                                                 infos[2],
142*89c4ff92SAndroid Build Coastguard Worker                                                 EmptyOptional(),
143*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
144*89c4ff92SAndroid Build Coastguard Worker             }
145*89c4ff92SAndroid Build Coastguard Worker             else
146*89c4ff92SAndroid Build Coastguard Worker             {
147*89c4ff92SAndroid Build Coastguard Worker                 return IsConvolution2dSupported(infos[0],
148*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
149*89c4ff92SAndroid Build Coastguard Worker                                                 desc,
150*89c4ff92SAndroid Build Coastguard Worker                                                 infos[2],
151*89c4ff92SAndroid Build Coastguard Worker                                                 infos[3],
152*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
153*89c4ff92SAndroid Build Coastguard Worker             }
154*89c4ff92SAndroid Build Coastguard Worker         }
155*89c4ff92SAndroid Build Coastguard Worker         case LayerType::DepthToSpace:
156*89c4ff92SAndroid Build Coastguard Worker             return IsDepthToSpaceSupported(infos[0],
157*89c4ff92SAndroid Build Coastguard Worker                                            infos[1],
158*89c4ff92SAndroid Build Coastguard Worker                                            *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported);
160*89c4ff92SAndroid Build Coastguard Worker         case LayerType::DepthwiseConvolution2d:
161*89c4ff92SAndroid Build Coastguard Worker         {
162*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
163*89c4ff92SAndroid Build Coastguard Worker             {
164*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
166*89c4ff92SAndroid Build Coastguard Worker             }
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
170*89c4ff92SAndroid Build Coastguard Worker             {
171*89c4ff92SAndroid Build Coastguard Worker                 return IsDepthwiseConvolutionSupported(infos[0],
172*89c4ff92SAndroid Build Coastguard Worker                                                        infos[1],
173*89c4ff92SAndroid Build Coastguard Worker                                                        desc,
174*89c4ff92SAndroid Build Coastguard Worker                                                        infos[2],
175*89c4ff92SAndroid Build Coastguard Worker                                                        EmptyOptional(),
176*89c4ff92SAndroid Build Coastguard Worker                                                        reasonIfUnsupported);
177*89c4ff92SAndroid Build Coastguard Worker             }
178*89c4ff92SAndroid Build Coastguard Worker             else
179*89c4ff92SAndroid Build Coastguard Worker             {
180*89c4ff92SAndroid Build Coastguard Worker                 return IsDepthwiseConvolutionSupported(infos[0],
181*89c4ff92SAndroid Build Coastguard Worker                                                        infos[1],
182*89c4ff92SAndroid Build Coastguard Worker                                                        desc,
183*89c4ff92SAndroid Build Coastguard Worker                                                        infos[2],
184*89c4ff92SAndroid Build Coastguard Worker                                                        infos[3],
185*89c4ff92SAndroid Build Coastguard Worker                                                        reasonIfUnsupported);
186*89c4ff92SAndroid Build Coastguard Worker             }
187*89c4ff92SAndroid Build Coastguard Worker         }
188*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Dequantize:
189*89c4ff92SAndroid Build Coastguard Worker             return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Division:
191*89c4ff92SAndroid Build Coastguard Worker             return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
192*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ElementwiseBinary:
193*89c4ff92SAndroid Build Coastguard Worker         {
194*89c4ff92SAndroid Build Coastguard Worker             std::array<DataType, 7> supportedTypes =
195*89c4ff92SAndroid Build Coastguard Worker                     {
196*89c4ff92SAndroid Build Coastguard Worker                             DataType::Float32,
197*89c4ff92SAndroid Build Coastguard Worker                             DataType::Float16,
198*89c4ff92SAndroid Build Coastguard Worker                             DataType::QAsymmS8,
199*89c4ff92SAndroid Build Coastguard Worker                             DataType::QAsymmU8,
200*89c4ff92SAndroid Build Coastguard Worker                             DataType::QSymmS16,
201*89c4ff92SAndroid Build Coastguard Worker                             DataType::Signed32
202*89c4ff92SAndroid Build Coastguard Worker                     };
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker             bool supported = true;
205*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206*89c4ff92SAndroid Build Coastguard Worker                                           "Reference elementwise unary: input type not supported");
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209*89c4ff92SAndroid Build Coastguard Worker                                           "Reference elementwise unary: input type not supported");
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212*89c4ff92SAndroid Build Coastguard Worker                                           "Reference elementwise unary: output type not supported");
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215*89c4ff92SAndroid Build Coastguard Worker                                           "Reference elementwise unary: input types not matching");
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218*89c4ff92SAndroid Build Coastguard Worker                                           "Reference elementwise unary: input and output types not matching");
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker             return supported;
221*89c4ff92SAndroid Build Coastguard Worker         }
222*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ElementwiseUnary:
223*89c4ff92SAndroid Build Coastguard Worker             return IsElementwiseUnarySupported(infos[0],
224*89c4ff92SAndroid Build Coastguard Worker                                                infos[1],
225*89c4ff92SAndroid Build Coastguard Worker                                                *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226*89c4ff92SAndroid Build Coastguard Worker                                                reasonIfUnsupported);
227*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Fill:
228*89c4ff92SAndroid Build Coastguard Worker             return IsFillSupported(infos[0],
229*89c4ff92SAndroid Build Coastguard Worker                                    infos[1],
230*89c4ff92SAndroid Build Coastguard Worker                                    *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported);
232*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Floor:
233*89c4ff92SAndroid Build Coastguard Worker             return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234*89c4ff92SAndroid Build Coastguard Worker         case LayerType::FullyConnected:
235*89c4ff92SAndroid Build Coastguard Worker             return IsFullyConnectedSupported(infos[0],
236*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
237*89c4ff92SAndroid Build Coastguard Worker                                              infos[2],
238*89c4ff92SAndroid Build Coastguard Worker                                              infos[3],
239*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
241*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Gather:
242*89c4ff92SAndroid Build Coastguard Worker             return IsGatherSupported(infos[0],
243*89c4ff92SAndroid Build Coastguard Worker                                      infos[1],
244*89c4ff92SAndroid Build Coastguard Worker                                      infos[2],
245*89c4ff92SAndroid Build Coastguard Worker                                      *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246*89c4ff92SAndroid Build Coastguard Worker                                      reasonIfUnsupported);
247*89c4ff92SAndroid Build Coastguard Worker         case LayerType::GatherNd:
248*89c4ff92SAndroid Build Coastguard Worker             return IsGatherNdSupported(infos[0],
249*89c4ff92SAndroid Build Coastguard Worker                                        infos[1],
250*89c4ff92SAndroid Build Coastguard Worker                                        infos[2],
251*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported);
252*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Input:
253*89c4ff92SAndroid Build Coastguard Worker             return IsInputSupported(infos[0], reasonIfUnsupported);
254*89c4ff92SAndroid Build Coastguard Worker         case LayerType::InstanceNormalization:
255*89c4ff92SAndroid Build Coastguard Worker             return IsInstanceNormalizationSupported(infos[0],
256*89c4ff92SAndroid Build Coastguard Worker                                                     infos[1],
257*89c4ff92SAndroid Build Coastguard Worker                                                     *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258*89c4ff92SAndroid Build Coastguard Worker                                                         (&descriptor)),
259*89c4ff92SAndroid Build Coastguard Worker                                                     reasonIfUnsupported);
260*89c4ff92SAndroid Build Coastguard Worker         case LayerType::L2Normalization:
261*89c4ff92SAndroid Build Coastguard Worker             return IsL2NormalizationSupported(infos[0],
262*89c4ff92SAndroid Build Coastguard Worker                                               infos[1],
263*89c4ff92SAndroid Build Coastguard Worker                                               *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264*89c4ff92SAndroid Build Coastguard Worker                                               reasonIfUnsupported);
265*89c4ff92SAndroid Build Coastguard Worker         case LayerType::LogicalBinary:
266*89c4ff92SAndroid Build Coastguard Worker             return IsLogicalBinarySupported(infos[0],
267*89c4ff92SAndroid Build Coastguard Worker                                             infos[1],
268*89c4ff92SAndroid Build Coastguard Worker                                             infos[2],
269*89c4ff92SAndroid Build Coastguard Worker                                             *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported);
271*89c4ff92SAndroid Build Coastguard Worker         case LayerType::LogSoftmax:
272*89c4ff92SAndroid Build Coastguard Worker             return IsLogSoftmaxSupported(infos[0],
273*89c4ff92SAndroid Build Coastguard Worker                                          infos[1],
274*89c4ff92SAndroid Build Coastguard Worker                                          *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275*89c4ff92SAndroid Build Coastguard Worker                                          reasonIfUnsupported);
276*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Lstm:
277*89c4ff92SAndroid Build Coastguard Worker             return IsLstmSupported(infos[0],
278*89c4ff92SAndroid Build Coastguard Worker                                    infos[1],
279*89c4ff92SAndroid Build Coastguard Worker                                    infos[2],
280*89c4ff92SAndroid Build Coastguard Worker                                    infos[3],
281*89c4ff92SAndroid Build Coastguard Worker                                    infos[4],
282*89c4ff92SAndroid Build Coastguard Worker                                    infos[5],
283*89c4ff92SAndroid Build Coastguard Worker                                    infos[6],
284*89c4ff92SAndroid Build Coastguard Worker                                    *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285*89c4ff92SAndroid Build Coastguard Worker                                    lstmParamsInfo.value(),
286*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported);
287*89c4ff92SAndroid Build Coastguard Worker         case LayerType::QLstm:
288*89c4ff92SAndroid Build Coastguard Worker             return IsQLstmSupported(infos[0],
289*89c4ff92SAndroid Build Coastguard Worker                                     infos[1],
290*89c4ff92SAndroid Build Coastguard Worker                                     infos[2],
291*89c4ff92SAndroid Build Coastguard Worker                                     infos[3],
292*89c4ff92SAndroid Build Coastguard Worker                                     infos[4],
293*89c4ff92SAndroid Build Coastguard Worker                                     infos[5],
294*89c4ff92SAndroid Build Coastguard Worker                                     *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295*89c4ff92SAndroid Build Coastguard Worker                                     lstmParamsInfo.value(),
296*89c4ff92SAndroid Build Coastguard Worker                                     reasonIfUnsupported);
297*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Maximum:
298*89c4ff92SAndroid Build Coastguard Worker             return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Mean:
300*89c4ff92SAndroid Build Coastguard Worker             return IsMeanSupported(infos[0],
301*89c4ff92SAndroid Build Coastguard Worker                                    infos[1],
302*89c4ff92SAndroid Build Coastguard Worker                                    *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported);
304*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Minimum:
305*89c4ff92SAndroid Build Coastguard Worker             return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Multiplication:
307*89c4ff92SAndroid Build Coastguard Worker             return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Normalization:
309*89c4ff92SAndroid Build Coastguard Worker             return IsNormalizationSupported(infos[0],
310*89c4ff92SAndroid Build Coastguard Worker                                             infos[1],
311*89c4ff92SAndroid Build Coastguard Worker                                             *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported);
313*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Output:
314*89c4ff92SAndroid Build Coastguard Worker             return IsOutputSupported(infos[0], reasonIfUnsupported);
315*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Pad:
316*89c4ff92SAndroid Build Coastguard Worker             return IsPadSupported(infos[0],
317*89c4ff92SAndroid Build Coastguard Worker                                   infos[1],
318*89c4ff92SAndroid Build Coastguard Worker                                   *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported);
320*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Permute:
321*89c4ff92SAndroid Build Coastguard Worker             return IsPermuteSupported(infos[0],
322*89c4ff92SAndroid Build Coastguard Worker                                       infos[1],
323*89c4ff92SAndroid Build Coastguard Worker                                       *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported);
325*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Pooling2d:
326*89c4ff92SAndroid Build Coastguard Worker             return IsPooling2dSupported(infos[0],
327*89c4ff92SAndroid Build Coastguard Worker                                         infos[1],
328*89c4ff92SAndroid Build Coastguard Worker                                         *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329*89c4ff92SAndroid Build Coastguard Worker                                         reasonIfUnsupported);
330*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Prelu:
331*89c4ff92SAndroid Build Coastguard Worker             return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Quantize:
333*89c4ff92SAndroid Build Coastguard Worker             return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Reshape:
335*89c4ff92SAndroid Build Coastguard Worker             return IsReshapeSupported(infos[0],
336*89c4ff92SAndroid Build Coastguard Worker                                       infos[1],
337*89c4ff92SAndroid Build Coastguard Worker                                       *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported);
339*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Resize:
340*89c4ff92SAndroid Build Coastguard Worker             return IsResizeSupported(infos[0],
341*89c4ff92SAndroid Build Coastguard Worker                                      infos[1],
342*89c4ff92SAndroid Build Coastguard Worker                                      *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343*89c4ff92SAndroid Build Coastguard Worker                                      reasonIfUnsupported);
344*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Reduce:
345*89c4ff92SAndroid Build Coastguard Worker             return IsReduceSupported(infos[0],
346*89c4ff92SAndroid Build Coastguard Worker                                      infos[1],
347*89c4ff92SAndroid Build Coastguard Worker                                      *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
348*89c4ff92SAndroid Build Coastguard Worker                                      reasonIfUnsupported);
349*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Slice:
350*89c4ff92SAndroid Build Coastguard Worker             return IsSliceSupported(infos[0],
351*89c4ff92SAndroid Build Coastguard Worker                                     infos[1],
352*89c4ff92SAndroid Build Coastguard Worker                                     *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
353*89c4ff92SAndroid Build Coastguard Worker                                     reasonIfUnsupported);
354*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Softmax:
355*89c4ff92SAndroid Build Coastguard Worker             return IsSoftmaxSupported(infos[0],
356*89c4ff92SAndroid Build Coastguard Worker                                       infos[1],
357*89c4ff92SAndroid Build Coastguard Worker                                       *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
358*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported);
359*89c4ff92SAndroid Build Coastguard Worker         case LayerType::SpaceToBatchNd:
360*89c4ff92SAndroid Build Coastguard Worker             return IsSpaceToBatchNdSupported(infos[0],
361*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
362*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
363*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
364*89c4ff92SAndroid Build Coastguard Worker         case LayerType::SpaceToDepth:
365*89c4ff92SAndroid Build Coastguard Worker             return IsSpaceToDepthSupported(infos[0],
366*89c4ff92SAndroid Build Coastguard Worker                                            infos[1],
367*89c4ff92SAndroid Build Coastguard Worker                                            *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
368*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported);
369*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Splitter:
370*89c4ff92SAndroid Build Coastguard Worker         {
371*89c4ff92SAndroid Build Coastguard Worker             std::vector<TensorInfo> outputInfos;
372*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 1; i < infos.size(); i++)
373*89c4ff92SAndroid Build Coastguard Worker             {
374*89c4ff92SAndroid Build Coastguard Worker                 outputInfos.push_back(infos[i]);
375*89c4ff92SAndroid Build Coastguard Worker             }
376*89c4ff92SAndroid Build Coastguard Worker             return IsSplitterSupported(infos[0],
377*89c4ff92SAndroid Build Coastguard Worker                                        {outputInfos.begin(), outputInfos.end()},
378*89c4ff92SAndroid Build Coastguard Worker                                        *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
379*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported);
380*89c4ff92SAndroid Build Coastguard Worker         }
381*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Stack:
382*89c4ff92SAndroid Build Coastguard Worker         {
383*89c4ff92SAndroid Build Coastguard Worker             std::vector<const TensorInfo*> inputInfos;
384*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 0; i < infos.size() - 1; i++)
385*89c4ff92SAndroid Build Coastguard Worker             {
386*89c4ff92SAndroid Build Coastguard Worker                 inputInfos.push_back(&infos[i]);
387*89c4ff92SAndroid Build Coastguard Worker             }
388*89c4ff92SAndroid Build Coastguard Worker             return IsStackSupported(inputInfos,
389*89c4ff92SAndroid Build Coastguard Worker                                     infos[infos.size() - 1],
390*89c4ff92SAndroid Build Coastguard Worker                                     *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
391*89c4ff92SAndroid Build Coastguard Worker                                     reasonIfUnsupported);
392*89c4ff92SAndroid Build Coastguard Worker         }
393*89c4ff92SAndroid Build Coastguard Worker         case LayerType::StridedSlice:
394*89c4ff92SAndroid Build Coastguard Worker             return IsStridedSliceSupported(infos[0],
395*89c4ff92SAndroid Build Coastguard Worker                                            infos[1],
396*89c4ff92SAndroid Build Coastguard Worker                                            *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
397*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported);
398*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Subtraction:
399*89c4ff92SAndroid Build Coastguard Worker             return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
400*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Transpose:
401*89c4ff92SAndroid Build Coastguard Worker             return IsTransposeSupported(infos[0],
402*89c4ff92SAndroid Build Coastguard Worker                                         infos[1],
403*89c4ff92SAndroid Build Coastguard Worker                                         *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
404*89c4ff92SAndroid Build Coastguard Worker                                         reasonIfUnsupported);
405*89c4ff92SAndroid Build Coastguard Worker         case LayerType::TransposeConvolution2d:
406*89c4ff92SAndroid Build Coastguard Worker         {
407*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
408*89c4ff92SAndroid Build Coastguard Worker             {
409*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
410*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
411*89c4ff92SAndroid Build Coastguard Worker             }
412*89c4ff92SAndroid Build Coastguard Worker 
413*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
414*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
415*89c4ff92SAndroid Build Coastguard Worker             {
416*89c4ff92SAndroid Build Coastguard Worker                 return IsTransposeConvolution2dSupported(infos[0],
417*89c4ff92SAndroid Build Coastguard Worker                                                          infos[1],
418*89c4ff92SAndroid Build Coastguard Worker                                                          desc,
419*89c4ff92SAndroid Build Coastguard Worker                                                          infos[2],
420*89c4ff92SAndroid Build Coastguard Worker                                                          EmptyOptional(),
421*89c4ff92SAndroid Build Coastguard Worker                                                          reasonIfUnsupported);
422*89c4ff92SAndroid Build Coastguard Worker             }
423*89c4ff92SAndroid Build Coastguard Worker             else
424*89c4ff92SAndroid Build Coastguard Worker             {
425*89c4ff92SAndroid Build Coastguard Worker                 return IsTransposeConvolution2dSupported(infos[0],
426*89c4ff92SAndroid Build Coastguard Worker                                                          infos[1],
427*89c4ff92SAndroid Build Coastguard Worker                                                          desc,
428*89c4ff92SAndroid Build Coastguard Worker                                                          infos[2],
429*89c4ff92SAndroid Build Coastguard Worker                                                          infos[3],
430*89c4ff92SAndroid Build Coastguard Worker                                                          reasonIfUnsupported);
431*89c4ff92SAndroid Build Coastguard Worker             }
432*89c4ff92SAndroid Build Coastguard Worker         }
433*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Cast:
434*89c4ff92SAndroid Build Coastguard Worker             return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
435*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ChannelShuffle:
436*89c4ff92SAndroid Build Coastguard Worker             return IsChannelShuffleSupported(infos[0],
437*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
438*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
439*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
440*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Convolution3d:
441*89c4ff92SAndroid Build Coastguard Worker         {
442*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
443*89c4ff92SAndroid Build Coastguard Worker             {
444*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
445*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
446*89c4ff92SAndroid Build Coastguard Worker             }
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
449*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
450*89c4ff92SAndroid Build Coastguard Worker             {
451*89c4ff92SAndroid Build Coastguard Worker                 return IsConvolution3dSupported(infos[0],
452*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
453*89c4ff92SAndroid Build Coastguard Worker                                                 desc,
454*89c4ff92SAndroid Build Coastguard Worker                                                 infos[2],
455*89c4ff92SAndroid Build Coastguard Worker                                                 EmptyOptional(),
456*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
457*89c4ff92SAndroid Build Coastguard Worker             }
458*89c4ff92SAndroid Build Coastguard Worker             else
459*89c4ff92SAndroid Build Coastguard Worker             {
460*89c4ff92SAndroid Build Coastguard Worker                 return IsConvolution3dSupported(infos[0],
461*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
462*89c4ff92SAndroid Build Coastguard Worker                                                 desc,
463*89c4ff92SAndroid Build Coastguard Worker                                                 infos[2],
464*89c4ff92SAndroid Build Coastguard Worker                                                 infos[3],
465*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
466*89c4ff92SAndroid Build Coastguard Worker             }
467*89c4ff92SAndroid Build Coastguard Worker         }
468*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Debug:
469*89c4ff92SAndroid Build Coastguard Worker             return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
470*89c4ff92SAndroid Build Coastguard Worker         case LayerType::DetectionPostProcess:
471*89c4ff92SAndroid Build Coastguard Worker             return IsDetectionPostProcessSupported(infos[0],
472*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
473*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2],
474*89c4ff92SAndroid Build Coastguard Worker                                                    infos[3],
475*89c4ff92SAndroid Build Coastguard Worker                                                    infos[4],
476*89c4ff92SAndroid Build Coastguard Worker                                                    infos[5],
477*89c4ff92SAndroid Build Coastguard Worker                                                    infos[6],
478*89c4ff92SAndroid Build Coastguard Worker                                                    *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
479*89c4ff92SAndroid Build Coastguard Worker                                                        (&descriptor)),
480*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported);
481*89c4ff92SAndroid Build Coastguard Worker         case LayerType::FakeQuantization:
482*89c4ff92SAndroid Build Coastguard Worker             return IsFakeQuantizationSupported(infos[0],
483*89c4ff92SAndroid Build Coastguard Worker                                                *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
484*89c4ff92SAndroid Build Coastguard Worker                                                reasonIfUnsupported);
485*89c4ff92SAndroid Build Coastguard Worker         case LayerType::MemCopy:
486*89c4ff92SAndroid Build Coastguard Worker             return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
487*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Rank:
488*89c4ff92SAndroid Build Coastguard Worker             return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
489*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Shape:
490*89c4ff92SAndroid Build Coastguard Worker             return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
491*89c4ff92SAndroid Build Coastguard Worker         case LayerType::UnidirectionalSequenceLstm:
492*89c4ff92SAndroid Build Coastguard Worker         {
493*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 6)
494*89c4ff92SAndroid Build Coastguard Worker             {
495*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
496*89c4ff92SAndroid Build Coastguard Worker                                                "should be of format: {input, outputStateIn, cellStateIn, "
497*89c4ff92SAndroid Build Coastguard Worker                                                "hiddenStateOutputVal, cellStateOutputVal, output}");
498*89c4ff92SAndroid Build Coastguard Worker             }
499*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
500*89c4ff92SAndroid Build Coastguard Worker             return IsUnidirectionalSequenceLstmSupported(infos[0],
501*89c4ff92SAndroid Build Coastguard Worker                                                          infos[1],
502*89c4ff92SAndroid Build Coastguard Worker                                                          infos[2],
503*89c4ff92SAndroid Build Coastguard Worker                                                          infos[3],
504*89c4ff92SAndroid Build Coastguard Worker                                                          infos[4],
505*89c4ff92SAndroid Build Coastguard Worker                                                          infos[5],
506*89c4ff92SAndroid Build Coastguard Worker                                                          desc,
507*89c4ff92SAndroid Build Coastguard Worker                                                          lstmParamsInfo.value(),
508*89c4ff92SAndroid Build Coastguard Worker                                                          reasonIfUnsupported);
509*89c4ff92SAndroid Build Coastguard Worker         }
510*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Pooling3d:
511*89c4ff92SAndroid Build Coastguard Worker             return IsPooling3dSupported(infos[0],
512*89c4ff92SAndroid Build Coastguard Worker                                         infos[1],
513*89c4ff92SAndroid Build Coastguard Worker                                         *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
514*89c4ff92SAndroid Build Coastguard Worker                                         reasonIfUnsupported);
515*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Map:
516*89c4ff92SAndroid Build Coastguard Worker             return true;
517*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Unmap:
518*89c4ff92SAndroid Build Coastguard Worker             return true;
519*89c4ff92SAndroid Build Coastguard Worker         case LayerType::MemImport:
520*89c4ff92SAndroid Build Coastguard Worker             return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
521*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Merge:
522*89c4ff92SAndroid Build Coastguard Worker             return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
523*89c4ff92SAndroid Build Coastguard Worker         case LayerType::QuantizedLstm:
524*89c4ff92SAndroid Build Coastguard Worker             return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
525*89c4ff92SAndroid Build Coastguard Worker                                                               infos[1],
526*89c4ff92SAndroid Build Coastguard Worker                                                               infos[2],
527*89c4ff92SAndroid Build Coastguard Worker                                                               infos[3],
528*89c4ff92SAndroid Build Coastguard Worker                                                               infos[4],
529*89c4ff92SAndroid Build Coastguard Worker                                                               quantizedLstmInputParamsInfo.value(),
530*89c4ff92SAndroid Build Coastguard Worker                                                               reasonIfUnsupported);
531*89c4ff92SAndroid Build Coastguard Worker         default:
532*89c4ff92SAndroid Build Coastguard Worker             // layers not supported in the reference backend by default:
533*89c4ff92SAndroid Build Coastguard Worker             // precompiled, standin, switch
534*89c4ff92SAndroid Build Coastguard Worker             return false;
535*89c4ff92SAndroid Build Coastguard Worker     }
536*89c4ff92SAndroid Build Coastguard Worker }
537*89c4ff92SAndroid Build Coastguard Worker 
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const538*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
539*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
540*89c4ff92SAndroid Build Coastguard Worker                                             const ActivationDescriptor& descriptor,
541*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
542*89c4ff92SAndroid Build Coastguard Worker {
543*89c4ff92SAndroid Build Coastguard Worker    bool supported = true;
544*89c4ff92SAndroid Build Coastguard Worker 
545*89c4ff92SAndroid Build Coastguard Worker     // Define supported types.
546*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes = {
547*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
548*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
549*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
550*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
551*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
552*89c4ff92SAndroid Build Coastguard Worker     };
553*89c4ff92SAndroid Build Coastguard Worker 
554*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
555*89c4ff92SAndroid Build Coastguard Worker                                   "Reference activation: input type not supported.");
556*89c4ff92SAndroid Build Coastguard Worker 
557*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
558*89c4ff92SAndroid Build Coastguard Worker                                   "Reference activation: output type not supported.");
559*89c4ff92SAndroid Build Coastguard Worker 
560*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
561*89c4ff92SAndroid Build Coastguard Worker                                   "Reference activation: input and output types mismatched.");
562*89c4ff92SAndroid Build Coastguard Worker 
563*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
564*89c4ff92SAndroid Build Coastguard Worker                                   "Reference activation: input and output shapes are of different rank.");
565*89c4ff92SAndroid Build Coastguard Worker 
566*89c4ff92SAndroid Build Coastguard Worker 
567*89c4ff92SAndroid Build Coastguard Worker     struct ActivationFunctionSupported : public Rule
568*89c4ff92SAndroid Build Coastguard Worker     {
569*89c4ff92SAndroid Build Coastguard Worker         ActivationFunctionSupported(const ActivationDescriptor& desc)
570*89c4ff92SAndroid Build Coastguard Worker         {
571*89c4ff92SAndroid Build Coastguard Worker             switch(desc.m_Function)
572*89c4ff92SAndroid Build Coastguard Worker             {
573*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::Abs:
574*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::BoundedReLu:
575*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::Elu:
576*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::HardSwish:
577*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::LeakyReLu:
578*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::Linear:
579*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::ReLu:
580*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::Sigmoid:
581*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::SoftReLu:
582*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::Sqrt:
583*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::Square:
584*89c4ff92SAndroid Build Coastguard Worker                 case ActivationFunction::TanH:
585*89c4ff92SAndroid Build Coastguard Worker                 {
586*89c4ff92SAndroid Build Coastguard Worker                     m_Res = true;
587*89c4ff92SAndroid Build Coastguard Worker                     break;
588*89c4ff92SAndroid Build Coastguard Worker                 }
589*89c4ff92SAndroid Build Coastguard Worker                 default:
590*89c4ff92SAndroid Build Coastguard Worker                 {
591*89c4ff92SAndroid Build Coastguard Worker                     m_Res = false;
592*89c4ff92SAndroid Build Coastguard Worker                     break;
593*89c4ff92SAndroid Build Coastguard Worker                 }
594*89c4ff92SAndroid Build Coastguard Worker             }
595*89c4ff92SAndroid Build Coastguard Worker         }
596*89c4ff92SAndroid Build Coastguard Worker     };
597*89c4ff92SAndroid Build Coastguard Worker 
598*89c4ff92SAndroid Build Coastguard Worker     // Function is supported
599*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
600*89c4ff92SAndroid Build Coastguard Worker                                   "Reference activation: function not supported.");
601*89c4ff92SAndroid Build Coastguard Worker 
602*89c4ff92SAndroid Build Coastguard Worker     return supported;
603*89c4ff92SAndroid Build Coastguard Worker }
604*89c4ff92SAndroid Build Coastguard Worker 
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const605*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
606*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& input1,
607*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
608*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
609*89c4ff92SAndroid Build Coastguard Worker {
610*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
611*89c4ff92SAndroid Build Coastguard Worker 
612*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes = {
613*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
614*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
615*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
616*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
617*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
618*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
619*89c4ff92SAndroid Build Coastguard Worker     };
620*89c4ff92SAndroid Build Coastguard Worker 
621*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
622*89c4ff92SAndroid Build Coastguard Worker                                   "Reference addition: input 0 is not a supported type.");
623*89c4ff92SAndroid Build Coastguard Worker 
624*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
625*89c4ff92SAndroid Build Coastguard Worker                                   "Reference addition: input 1 is not a supported type.");
626*89c4ff92SAndroid Build Coastguard Worker 
627*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
628*89c4ff92SAndroid Build Coastguard Worker                                   "Reference addition: output is not a supported type.");
629*89c4ff92SAndroid Build Coastguard Worker 
630*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
631*89c4ff92SAndroid Build Coastguard Worker                                   "Reference addition: input 0 and Input 1 types are mismatched");
632*89c4ff92SAndroid Build Coastguard Worker 
633*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
634*89c4ff92SAndroid Build Coastguard Worker                                   "Reference addition: input and output types are mismatched");
635*89c4ff92SAndroid Build Coastguard Worker 
636*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
637*89c4ff92SAndroid Build Coastguard Worker                                   "Reference addition: shapes are not suitable for implicit broadcast.");
638*89c4ff92SAndroid Build Coastguard Worker 
639*89c4ff92SAndroid Build Coastguard Worker     return supported;
640*89c4ff92SAndroid Build Coastguard Worker }
641*89c4ff92SAndroid Build Coastguard Worker 
IsArgMinMaxSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & output,const armnn::ArgMinMaxDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const642*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
643*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::ArgMinMaxDescriptor &descriptor,
644*89c4ff92SAndroid Build Coastguard Worker                                            armnn::Optional<std::string &> reasonIfUnsupported) const
645*89c4ff92SAndroid Build Coastguard Worker {
646*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
647*89c4ff92SAndroid Build Coastguard Worker 
648*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 8> supportedInputTypes =
649*89c4ff92SAndroid Build Coastguard Worker     {
650*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
651*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
652*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
653*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
654*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
655*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32,
656*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed64
657*89c4ff92SAndroid Build Coastguard Worker     };
658*89c4ff92SAndroid Build Coastguard Worker 
659*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,2> supportedOutputTypes = {
660*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32,
661*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed64
662*89c4ff92SAndroid Build Coastguard Worker     };
663*89c4ff92SAndroid Build Coastguard Worker 
664*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
665*89c4ff92SAndroid Build Coastguard Worker 
666*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
667*89c4ff92SAndroid Build Coastguard Worker                                   "Reference ArgMinMax: input is not a supported type.");
668*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
669*89c4ff92SAndroid Build Coastguard Worker                                   "Reference ArgMinMax: output type not supported");
670*89c4ff92SAndroid Build Coastguard Worker 
671*89c4ff92SAndroid Build Coastguard Worker     return supported;
672*89c4ff92SAndroid Build Coastguard Worker }
673*89c4ff92SAndroid Build Coastguard Worker 
IsBatchMatMulSupported(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const674*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
675*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& inputY,
676*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& output,
677*89c4ff92SAndroid Build Coastguard Worker                                              const BatchMatMulDescriptor& descriptor,
678*89c4ff92SAndroid Build Coastguard Worker                                              Optional<std::string &> reasonIfUnsupported) const
679*89c4ff92SAndroid Build Coastguard Worker {
680*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
681*89c4ff92SAndroid Build Coastguard Worker 
682*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes =
683*89c4ff92SAndroid Build Coastguard Worker     {
684*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
685*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
686*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
687*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
688*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
689*89c4ff92SAndroid Build Coastguard Worker     };
690*89c4ff92SAndroid Build Coastguard Worker 
691*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
692*89c4ff92SAndroid Build Coastguard Worker 
693*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
694*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: input X is not a supported type");
695*89c4ff92SAndroid Build Coastguard Worker 
696*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
697*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: input Y is not a supported type");
698*89c4ff92SAndroid Build Coastguard Worker 
699*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
700*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: output is not a supported type");
701*89c4ff92SAndroid Build Coastguard Worker 
702*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
703*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: input X and input Y types are mismatched");
704*89c4ff92SAndroid Build Coastguard Worker 
705*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
706*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: inputs and output types are mismatched");
707*89c4ff92SAndroid Build Coastguard Worker 
708*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
709*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
710*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: input X is not of rank 2 or greater");
711*89c4ff92SAndroid Build Coastguard Worker 
712*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
713*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
714*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
715*89c4ff92SAndroid Build Coastguard Worker 
716*89c4ff92SAndroid Build Coastguard Worker     return supported;
717*89c4ff92SAndroid Build Coastguard Worker }
718*89c4ff92SAndroid Build Coastguard Worker 
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & variance,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const719*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
720*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& output,
721*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& mean,
722*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& variance,
723*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& beta,
724*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& gamma,
725*89c4ff92SAndroid Build Coastguard Worker                                                     const BatchNormalizationDescriptor& descriptor,
726*89c4ff92SAndroid Build Coastguard Worker                                                     Optional<std::string&> reasonIfUnsupported) const
727*89c4ff92SAndroid Build Coastguard Worker {
728*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
729*89c4ff92SAndroid Build Coastguard Worker 
730*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes =
731*89c4ff92SAndroid Build Coastguard Worker     {
732*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
733*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
734*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
735*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
736*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
737*89c4ff92SAndroid Build Coastguard Worker     };
738*89c4ff92SAndroid Build Coastguard Worker 
739*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
740*89c4ff92SAndroid Build Coastguard Worker 
741*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: input is not a supported type.");
743*89c4ff92SAndroid Build Coastguard Worker 
744*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
745*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: output is not a supported type.");
746*89c4ff92SAndroid Build Coastguard Worker 
747*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
748*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: input and output types are mismatched");
749*89c4ff92SAndroid Build Coastguard Worker 
750*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
751*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: mean is not a supported type.");
752*89c4ff92SAndroid Build Coastguard Worker 
753*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
754*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: variance is not a supported type.");
755*89c4ff92SAndroid Build Coastguard Worker 
756*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
757*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: beta is not a supported type.");
758*89c4ff92SAndroid Build Coastguard Worker 
759*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
760*89c4ff92SAndroid Build Coastguard Worker                                   "Reference batch normalization: gamma is not a supported type.");
761*89c4ff92SAndroid Build Coastguard Worker 
762*89c4ff92SAndroid Build Coastguard Worker     return supported;
763*89c4ff92SAndroid Build Coastguard Worker }
764*89c4ff92SAndroid Build Coastguard Worker 
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const765*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
766*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
767*89c4ff92SAndroid Build Coastguard Worker                                                 const BatchToSpaceNdDescriptor& descriptor,
768*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
769*89c4ff92SAndroid Build Coastguard Worker {
770*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
771*89c4ff92SAndroid Build Coastguard Worker 
772*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
773*89c4ff92SAndroid Build Coastguard Worker 
774*89c4ff92SAndroid Build Coastguard Worker     std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
775*89c4ff92SAndroid Build Coastguard Worker     std::string inputTensorStr = "input";
776*89c4ff92SAndroid Build Coastguard Worker     std::string outputTensorStr = "output";
777*89c4ff92SAndroid Build Coastguard Worker 
778*89c4ff92SAndroid Build Coastguard Worker     // Define supported types.
779*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
780*89c4ff92SAndroid Build Coastguard Worker     {
781*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
782*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
783*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
784*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
785*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
786*89c4ff92SAndroid Build Coastguard Worker     };
787*89c4ff92SAndroid Build Coastguard Worker 
788*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789*89c4ff92SAndroid Build Coastguard Worker                                   "Reference BatchToSpaceNd: input type not supported.");
790*89c4ff92SAndroid Build Coastguard Worker 
791*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792*89c4ff92SAndroid Build Coastguard Worker                                   "Reference BatchToSpaceNd: output type not supported.");
793*89c4ff92SAndroid Build Coastguard Worker 
794*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795*89c4ff92SAndroid Build Coastguard Worker                                   "Reference BatchToSpaceNd: input and output types mismatched.");
796*89c4ff92SAndroid Build Coastguard Worker 
797*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
798*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
799*89c4ff92SAndroid Build Coastguard Worker                                   CreateIncorrectDimensionsErrorMsg(4,
800*89c4ff92SAndroid Build Coastguard Worker                                                                     output.GetNumDimensions(),
801*89c4ff92SAndroid Build Coastguard Worker                                                                     batchToSpaceNdLayerStr,
802*89c4ff92SAndroid Build Coastguard Worker                                                                     outputTensorStr).data());
803*89c4ff92SAndroid Build Coastguard Worker 
804*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
805*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
806*89c4ff92SAndroid Build Coastguard Worker                                   CreateIncorrectDimensionsErrorMsg(4,
807*89c4ff92SAndroid Build Coastguard Worker                                                                     input.GetNumDimensions(),
808*89c4ff92SAndroid Build Coastguard Worker                                                                     batchToSpaceNdLayerStr,
809*89c4ff92SAndroid Build Coastguard Worker                                                                     inputTensorStr).data());
810*89c4ff92SAndroid Build Coastguard Worker 
811*89c4ff92SAndroid Build Coastguard Worker     return supported;
812*89c4ff92SAndroid Build Coastguard Worker }
813*89c4ff92SAndroid Build Coastguard Worker 
IsCastSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const814*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
815*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& output,
816*89c4ff92SAndroid Build Coastguard Worker                                       Optional<std::string&> reasonIfUnsupported) const
817*89c4ff92SAndroid Build Coastguard Worker {
818*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 9> supportedInputTypes =
819*89c4ff92SAndroid Build Coastguard Worker             {
820*89c4ff92SAndroid Build Coastguard Worker                     DataType::Float32,
821*89c4ff92SAndroid Build Coastguard Worker                     DataType::Float16,
822*89c4ff92SAndroid Build Coastguard Worker                     DataType::QSymmS8,
823*89c4ff92SAndroid Build Coastguard Worker                     DataType::QAsymmS8,
824*89c4ff92SAndroid Build Coastguard Worker                     DataType::QAsymmU8,
825*89c4ff92SAndroid Build Coastguard Worker                     DataType::QSymmS16,
826*89c4ff92SAndroid Build Coastguard Worker                     DataType::Signed32
827*89c4ff92SAndroid Build Coastguard Worker             };
828*89c4ff92SAndroid Build Coastguard Worker 
829*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
830*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
831*89c4ff92SAndroid Build Coastguard Worker                                   "Reference cast: input is not a supported type");
832*89c4ff92SAndroid Build Coastguard Worker 
833*89c4ff92SAndroid Build Coastguard Worker 
834*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
835*89c4ff92SAndroid Build Coastguard Worker                                   "Reference cast: output is not a supported type");
836*89c4ff92SAndroid Build Coastguard Worker 
837*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838*89c4ff92SAndroid Build Coastguard Worker                                   "Reference cast: input and output shapes have different number of total elements");
839*89c4ff92SAndroid Build Coastguard Worker 
840*89c4ff92SAndroid Build Coastguard Worker     return supported;
841*89c4ff92SAndroid Build Coastguard Worker }
842*89c4ff92SAndroid Build Coastguard Worker 
IsChannelShuffleSupported(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const843*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
844*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
845*89c4ff92SAndroid Build Coastguard Worker                                                 const ChannelShuffleDescriptor& descriptor,
846*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
847*89c4ff92SAndroid Build Coastguard Worker {
848*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
849*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
850*89c4ff92SAndroid Build Coastguard Worker 
851*89c4ff92SAndroid Build Coastguard Worker     // Define supported output and inputs types.
852*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 7> supportedTypes =
853*89c4ff92SAndroid Build Coastguard Worker     {
854*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
855*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
856*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
857*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
858*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
859*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
860*89c4ff92SAndroid Build Coastguard Worker     };
861*89c4ff92SAndroid Build Coastguard Worker 
862*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863*89c4ff92SAndroid Build Coastguard Worker                                   "Reference ChannelShuffle: input is not a supported type.");
864*89c4ff92SAndroid Build Coastguard Worker 
865*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866*89c4ff92SAndroid Build Coastguard Worker                                   "Reference ChannelShuffle: output is not a supported type.");
867*89c4ff92SAndroid Build Coastguard Worker 
868*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
869*89c4ff92SAndroid Build Coastguard Worker                                   "Reference ChannelShuffle: input and output types are mismatched.");
870*89c4ff92SAndroid Build Coastguard Worker 
871*89c4ff92SAndroid Build Coastguard Worker     return supported;
872*89c4ff92SAndroid Build Coastguard Worker }
873*89c4ff92SAndroid Build Coastguard Worker 
874*89c4ff92SAndroid Build Coastguard Worker 
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const875*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
876*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& input1,
877*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
878*89c4ff92SAndroid Build Coastguard Worker                                             const ComparisonDescriptor& descriptor,
879*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
880*89c4ff92SAndroid Build Coastguard Worker {
881*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
882*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 8> supportedInputTypes =
883*89c4ff92SAndroid Build Coastguard Worker     {
884*89c4ff92SAndroid Build Coastguard Worker         DataType::Boolean,
885*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
886*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
887*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
888*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
889*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
890*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
891*89c4ff92SAndroid Build Coastguard Worker     };
892*89c4ff92SAndroid Build Coastguard Worker 
893*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
894*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
895*89c4ff92SAndroid Build Coastguard Worker                                   "Reference comparison: input 0 is not a supported type");
896*89c4ff92SAndroid Build Coastguard Worker 
897*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
898*89c4ff92SAndroid Build Coastguard Worker                                   "Reference comparison: input 0 and Input 1 types are mismatched");
899*89c4ff92SAndroid Build Coastguard Worker 
900*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
901*89c4ff92SAndroid Build Coastguard Worker                                   "Reference comparison: output is not of type Boolean");
902*89c4ff92SAndroid Build Coastguard Worker 
903*89c4ff92SAndroid Build Coastguard Worker     return supported;
904*89c4ff92SAndroid Build Coastguard Worker }
905*89c4ff92SAndroid Build Coastguard Worker 
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const OriginsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const906*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
907*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
908*89c4ff92SAndroid Build Coastguard Worker                                         const OriginsDescriptor& descriptor,
909*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
910*89c4ff92SAndroid Build Coastguard Worker {
911*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
912*89c4ff92SAndroid Build Coastguard Worker 
913*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
914*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
915*89c4ff92SAndroid Build Coastguard Worker     {
916*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
917*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
918*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
919*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
920*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
921*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
922*89c4ff92SAndroid Build Coastguard Worker     };
923*89c4ff92SAndroid Build Coastguard Worker 
924*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
925*89c4ff92SAndroid Build Coastguard Worker                                   "Reference concatenation: output type not supported");
926*89c4ff92SAndroid Build Coastguard Worker     for (const TensorInfo* input : inputs)
927*89c4ff92SAndroid Build Coastguard Worker     {
928*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(input != nullptr);
929*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
930*89c4ff92SAndroid Build Coastguard Worker             "Reference concatenation: input type not supported");
931*89c4ff92SAndroid Build Coastguard Worker 
932*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
933*89c4ff92SAndroid Build Coastguard Worker             "Reference concatenation: input and output types mismatched.");
934*89c4ff92SAndroid Build Coastguard Worker     }
935*89c4ff92SAndroid Build Coastguard Worker 
936*89c4ff92SAndroid Build Coastguard Worker     return supported;
937*89c4ff92SAndroid Build Coastguard Worker }
938*89c4ff92SAndroid Build Coastguard Worker 
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const939*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
940*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
941*89c4ff92SAndroid Build Coastguard Worker {
942*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,8> supportedTypes =
943*89c4ff92SAndroid Build Coastguard Worker     {
944*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
945*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
946*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
947*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
948*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
949*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
950*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
951*89c4ff92SAndroid Build Coastguard Worker     };
952*89c4ff92SAndroid Build Coastguard Worker 
953*89c4ff92SAndroid Build Coastguard Worker     return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
954*89c4ff92SAndroid Build Coastguard Worker                                   "Reference constant: output is not a supported type.");
955*89c4ff92SAndroid Build Coastguard Worker }
956*89c4ff92SAndroid Build Coastguard Worker 
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const957*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
958*89c4ff92SAndroid Build Coastguard Worker                                                    const TensorInfo& output,
959*89c4ff92SAndroid Build Coastguard Worker                                                    Optional<std::string&> reasonIfUnsupported) const
960*89c4ff92SAndroid Build Coastguard Worker {
961*89c4ff92SAndroid Build Coastguard Worker     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
962*89c4ff92SAndroid Build Coastguard Worker                                           input.GetDataType(),
963*89c4ff92SAndroid Build Coastguard Worker                                           &TrueFunc<>,
964*89c4ff92SAndroid Build Coastguard Worker                                           &FalseInputFuncF32<>,
965*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>,
966*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncI32<>,
967*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>) &&
968*89c4ff92SAndroid Build Coastguard Worker             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
969*89c4ff92SAndroid Build Coastguard Worker                                           output.GetDataType(),
970*89c4ff92SAndroid Build Coastguard Worker                                           &FalseOutputFuncF16<>,
971*89c4ff92SAndroid Build Coastguard Worker                                           &TrueFunc<>,
972*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>,
973*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncI32<>,
974*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>));
975*89c4ff92SAndroid Build Coastguard Worker }
976*89c4ff92SAndroid Build Coastguard Worker 
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const977*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
978*89c4ff92SAndroid Build Coastguard Worker                                                    const TensorInfo& output,
979*89c4ff92SAndroid Build Coastguard Worker                                                    Optional<std::string&> reasonIfUnsupported) const
980*89c4ff92SAndroid Build Coastguard Worker {
981*89c4ff92SAndroid Build Coastguard Worker     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
982*89c4ff92SAndroid Build Coastguard Worker                                           input.GetDataType(),
983*89c4ff92SAndroid Build Coastguard Worker                                           &FalseInputFuncF16<>,
984*89c4ff92SAndroid Build Coastguard Worker                                           &TrueFunc<>,
985*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>,
986*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncI32<>,
987*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>) &&
988*89c4ff92SAndroid Build Coastguard Worker             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
989*89c4ff92SAndroid Build Coastguard Worker                                           output.GetDataType(),
990*89c4ff92SAndroid Build Coastguard Worker                                           &TrueFunc<>,
991*89c4ff92SAndroid Build Coastguard Worker                                           &FalseOutputFuncF32<>,
992*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>,
993*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncI32<>,
994*89c4ff92SAndroid Build Coastguard Worker                                           &FalseFuncU8<>));
995*89c4ff92SAndroid Build Coastguard Worker }
996*89c4ff92SAndroid Build Coastguard Worker 
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const997*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
998*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
999*89c4ff92SAndroid Build Coastguard Worker                                                const Convolution2dDescriptor& descriptor,
1000*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& weights,
1001*89c4ff92SAndroid Build Coastguard Worker                                                const Optional<TensorInfo>& biases,
1002*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
1003*89c4ff92SAndroid Build Coastguard Worker {
1004*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1005*89c4ff92SAndroid Build Coastguard Worker 
1006*89c4ff92SAndroid Build Coastguard Worker     // Define supported types.
1007*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
1008*89c4ff92SAndroid Build Coastguard Worker     {
1009*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1010*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1011*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1012*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1013*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
1014*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1015*89c4ff92SAndroid Build Coastguard Worker     };
1016*89c4ff92SAndroid Build Coastguard Worker 
1017*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1018*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Convolution2d: input is not a supported type.");
1019*89c4ff92SAndroid Build Coastguard Worker 
1020*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1021*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Convolution2d: output is not a supported type.");
1022*89c4ff92SAndroid Build Coastguard Worker 
1023*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1024*89c4ff92SAndroid Build Coastguard Worker                               "Reference Convolution2d: input and output types mismatched.");
1025*89c4ff92SAndroid Build Coastguard Worker 
1026*89c4ff92SAndroid Build Coastguard Worker 
1027*89c4ff92SAndroid Build Coastguard Worker     const DataType inputType = input.GetDataType();
1028*89c4ff92SAndroid Build Coastguard Worker     if (IsQuantized8BitType(inputType))
1029*89c4ff92SAndroid Build Coastguard Worker     {
1030*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType, 3> supportedWeightTypes =
1031*89c4ff92SAndroid Build Coastguard Worker         {
1032*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmS8,
1033*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmU8,
1034*89c4ff92SAndroid Build Coastguard Worker             DataType::QSymmS8
1035*89c4ff92SAndroid Build Coastguard Worker         };
1036*89c4ff92SAndroid Build Coastguard Worker 
1037*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1038*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution2d: weights type not supported for quantized input.");
1039*89c4ff92SAndroid Build Coastguard Worker     }
1040*89c4ff92SAndroid Build Coastguard Worker     else
1041*89c4ff92SAndroid Build Coastguard Worker     {
1042*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1043*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution2d: weights is not a supported type.");
1044*89c4ff92SAndroid Build Coastguard Worker 
1045*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1046*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution2d: input and weights types mismatched.");
1047*89c4ff92SAndroid Build Coastguard Worker     }
1048*89c4ff92SAndroid Build Coastguard Worker 
1049*89c4ff92SAndroid Build Coastguard Worker     if (biases.has_value())
1050*89c4ff92SAndroid Build Coastguard Worker     {
1051*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType,4> biasesSupportedTypes =
1052*89c4ff92SAndroid Build Coastguard Worker         {
1053*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
1054*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16,
1055*89c4ff92SAndroid Build Coastguard Worker             DataType::Signed32
1056*89c4ff92SAndroid Build Coastguard Worker         };
1057*89c4ff92SAndroid Build Coastguard Worker 
1058*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1059*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution2d: biases is not a supported type.");
1060*89c4ff92SAndroid Build Coastguard Worker     }
1061*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1062*89c4ff92SAndroid Build Coastguard Worker 
1063*89c4ff92SAndroid Build Coastguard Worker     return supported;
1064*89c4ff92SAndroid Build Coastguard Worker }
1065*89c4ff92SAndroid Build Coastguard Worker 
IsConvolution3dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution3dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1066*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1067*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
1068*89c4ff92SAndroid Build Coastguard Worker                                                const Convolution3dDescriptor& descriptor,
1069*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& weights,
1070*89c4ff92SAndroid Build Coastguard Worker                                                const Optional<TensorInfo>& biases,
1071*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
1072*89c4ff92SAndroid Build Coastguard Worker {
1073*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1074*89c4ff92SAndroid Build Coastguard Worker 
1075*89c4ff92SAndroid Build Coastguard Worker     // Define supported types.
1076*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
1077*89c4ff92SAndroid Build Coastguard Worker     {
1078*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1079*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1080*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1081*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1082*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
1083*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1084*89c4ff92SAndroid Build Coastguard Worker     };
1085*89c4ff92SAndroid Build Coastguard Worker 
1086*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1087*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Convolution3d: input is not a supported type.");
1088*89c4ff92SAndroid Build Coastguard Worker 
1089*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1090*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Convolution3d: output is not a supported type.");
1091*89c4ff92SAndroid Build Coastguard Worker 
1092*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1093*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Convolution3d: input and output types mismatched.");
1094*89c4ff92SAndroid Build Coastguard Worker 
1095*89c4ff92SAndroid Build Coastguard Worker     const DataType inputType = input.GetDataType();
1096*89c4ff92SAndroid Build Coastguard Worker     if (IsQuantized8BitType(inputType))
1097*89c4ff92SAndroid Build Coastguard Worker     {
1098*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType, 3> supportedWeightTypes =
1099*89c4ff92SAndroid Build Coastguard Worker         {
1100*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmS8,
1101*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmU8,
1102*89c4ff92SAndroid Build Coastguard Worker             DataType::QSymmS8
1103*89c4ff92SAndroid Build Coastguard Worker         };
1104*89c4ff92SAndroid Build Coastguard Worker 
1105*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1106*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution3d: weights type not supported for quantized input.");
1107*89c4ff92SAndroid Build Coastguard Worker     }
1108*89c4ff92SAndroid Build Coastguard Worker     else
1109*89c4ff92SAndroid Build Coastguard Worker     {
1110*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1111*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution3d: weights is not a supported type.");
1112*89c4ff92SAndroid Build Coastguard Worker 
1113*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1114*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution3d: input and weights types mismatched.");
1115*89c4ff92SAndroid Build Coastguard Worker     }
1116*89c4ff92SAndroid Build Coastguard Worker 
1117*89c4ff92SAndroid Build Coastguard Worker     if (biases.has_value())
1118*89c4ff92SAndroid Build Coastguard Worker     {
1119*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType,4> biasesSupportedTypes =
1120*89c4ff92SAndroid Build Coastguard Worker         {
1121*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
1122*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16,
1123*89c4ff92SAndroid Build Coastguard Worker             DataType::Signed32
1124*89c4ff92SAndroid Build Coastguard Worker         };
1125*89c4ff92SAndroid Build Coastguard Worker 
1126*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1127*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Convolution3d: biases is not a supported type.");
1128*89c4ff92SAndroid Build Coastguard Worker     }
1129*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1130*89c4ff92SAndroid Build Coastguard Worker 
1131*89c4ff92SAndroid Build Coastguard Worker     return supported;
1132*89c4ff92SAndroid Build Coastguard Worker }
1133*89c4ff92SAndroid Build Coastguard Worker 
IsDebugSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1134*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1135*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
1136*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
1137*89c4ff92SAndroid Build Coastguard Worker {
1138*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1139*89c4ff92SAndroid Build Coastguard Worker 
1140*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 8> supportedTypes =
1141*89c4ff92SAndroid Build Coastguard Worker     {
1142*89c4ff92SAndroid Build Coastguard Worker         DataType::BFloat16,
1143*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1144*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1145*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1146*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1147*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
1148*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1149*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1150*89c4ff92SAndroid Build Coastguard Worker     };
1151*89c4ff92SAndroid Build Coastguard Worker 
1152*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1153*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Debug layer: input type not supported");
1154*89c4ff92SAndroid Build Coastguard Worker 
1155*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1156*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Debug layer: output type not supported");
1157*89c4ff92SAndroid Build Coastguard Worker 
1158*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1159*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Debug layer: input and output types are mismatched");
1160*89c4ff92SAndroid Build Coastguard Worker 
1161*89c4ff92SAndroid Build Coastguard Worker     return supported;
1162*89c4ff92SAndroid Build Coastguard Worker }
1163*89c4ff92SAndroid Build Coastguard Worker 
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1164*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1165*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& output,
1166*89c4ff92SAndroid Build Coastguard Worker                                               const DepthToSpaceDescriptor& descriptor,
1167*89c4ff92SAndroid Build Coastguard Worker                                               Optional<std::string&> reasonIfUnsupported) const
1168*89c4ff92SAndroid Build Coastguard Worker {
1169*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1170*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1171*89c4ff92SAndroid Build Coastguard Worker 
1172*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
1173*89c4ff92SAndroid Build Coastguard Worker     {
1174*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1175*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1176*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1177*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1178*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1179*89c4ff92SAndroid Build Coastguard Worker     };
1180*89c4ff92SAndroid Build Coastguard Worker 
1181*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1182*89c4ff92SAndroid Build Coastguard Worker         "Reference DepthToSpace: input type not supported");
1183*89c4ff92SAndroid Build Coastguard Worker 
1184*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1185*89c4ff92SAndroid Build Coastguard Worker         "Reference DepthToSpace: output type not supported");
1186*89c4ff92SAndroid Build Coastguard Worker 
1187*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1188*89c4ff92SAndroid Build Coastguard Worker         "Reference DepthToSpace: input and output types are mismatched");
1189*89c4ff92SAndroid Build Coastguard Worker 
1190*89c4ff92SAndroid Build Coastguard Worker     return supported;
1191*89c4ff92SAndroid Build Coastguard Worker }
1192*89c4ff92SAndroid Build Coastguard Worker 
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1193*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1194*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& output,
1195*89c4ff92SAndroid Build Coastguard Worker                                                       const DepthwiseConvolution2dDescriptor& descriptor,
1196*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& weights,
1197*89c4ff92SAndroid Build Coastguard Worker                                                       const Optional<TensorInfo>& biases,
1198*89c4ff92SAndroid Build Coastguard Worker                                                       Optional<std::string&> reasonIfUnsupported) const
1199*89c4ff92SAndroid Build Coastguard Worker {
1200*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1201*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1202*89c4ff92SAndroid Build Coastguard Worker 
1203*89c4ff92SAndroid Build Coastguard Worker     // Define supported types.
1204*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
1205*89c4ff92SAndroid Build Coastguard Worker     {
1206*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1207*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1208*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1209*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1210*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
1211*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1212*89c4ff92SAndroid Build Coastguard Worker     };
1213*89c4ff92SAndroid Build Coastguard Worker 
1214*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1215*89c4ff92SAndroid Build Coastguard Worker                                   "Reference DepthwiseConvolution2d: input is not a supported type.");
1216*89c4ff92SAndroid Build Coastguard Worker 
1217*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1218*89c4ff92SAndroid Build Coastguard Worker                                   "Reference DepthwiseConvolution2d: output is not a supported type.");
1219*89c4ff92SAndroid Build Coastguard Worker 
1220*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1221*89c4ff92SAndroid Build Coastguard Worker                                   "Reference DepthwiseConvolution2d: input and output types mismatched.");
1222*89c4ff92SAndroid Build Coastguard Worker 
1223*89c4ff92SAndroid Build Coastguard Worker     const DataType inputType = input.GetDataType();
1224*89c4ff92SAndroid Build Coastguard Worker     if (IsQuantized8BitType(inputType))
1225*89c4ff92SAndroid Build Coastguard Worker     {
1226*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType, 3> supportedWeightTypes =
1227*89c4ff92SAndroid Build Coastguard Worker                 {
1228*89c4ff92SAndroid Build Coastguard Worker                         DataType::QAsymmS8,
1229*89c4ff92SAndroid Build Coastguard Worker                         DataType::QAsymmU8,
1230*89c4ff92SAndroid Build Coastguard Worker                         DataType::QSymmS8,
1231*89c4ff92SAndroid Build Coastguard Worker                 };
1232*89c4ff92SAndroid Build Coastguard Worker 
1233*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1234*89c4ff92SAndroid Build Coastguard Worker                                        "Reference DepthwiseConvolution2d: weights type not supported for "
1235*89c4ff92SAndroid Build Coastguard Worker                                        "quantized input.");
1236*89c4ff92SAndroid Build Coastguard Worker     }
1237*89c4ff92SAndroid Build Coastguard Worker     else
1238*89c4ff92SAndroid Build Coastguard Worker     {
1239*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1240*89c4ff92SAndroid Build Coastguard Worker                                       "Reference DepthwiseConvolution2d: weights is not a supported type.");
1241*89c4ff92SAndroid Build Coastguard Worker 
1242*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1243*89c4ff92SAndroid Build Coastguard Worker                                       "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1244*89c4ff92SAndroid Build Coastguard Worker     }
1245*89c4ff92SAndroid Build Coastguard Worker 
1246*89c4ff92SAndroid Build Coastguard Worker     if (biases.has_value())
1247*89c4ff92SAndroid Build Coastguard Worker     {
1248*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType,4> biasesSupportedTypes =
1249*89c4ff92SAndroid Build Coastguard Worker         {
1250*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
1251*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16,
1252*89c4ff92SAndroid Build Coastguard Worker             DataType::Signed32
1253*89c4ff92SAndroid Build Coastguard Worker         };
1254*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1255*89c4ff92SAndroid Build Coastguard Worker                                       "Reference DepthwiseConvolution2d: biases is not a supported type.");
1256*89c4ff92SAndroid Build Coastguard Worker     }
1257*89c4ff92SAndroid Build Coastguard Worker 
1258*89c4ff92SAndroid Build Coastguard Worker     return supported;
1259*89c4ff92SAndroid Build Coastguard Worker 
1260*89c4ff92SAndroid Build Coastguard Worker }
1261*89c4ff92SAndroid Build Coastguard Worker 
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1262*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1263*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
1264*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
1265*89c4ff92SAndroid Build Coastguard Worker {
1266*89c4ff92SAndroid Build Coastguard Worker    bool supported = true;
1267*89c4ff92SAndroid Build Coastguard Worker 
1268*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,5> supportedInputTypes = {
1269*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1270*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1271*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
1272*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1273*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16
1274*89c4ff92SAndroid Build Coastguard Worker     };
1275*89c4ff92SAndroid Build Coastguard Worker 
1276*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1277*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Dequantize layer: input type not supported.");
1278*89c4ff92SAndroid Build Coastguard Worker 
1279*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
1280*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Dequantize layer: per-axis quantized input not supported.");
1281*89c4ff92SAndroid Build Coastguard Worker 
1282*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,3> supportedOutputTypes = {
1283*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1284*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16
1285*89c4ff92SAndroid Build Coastguard Worker     };
1286*89c4ff92SAndroid Build Coastguard Worker 
1287*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1288*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Dequantize layer: output type not supported.");
1289*89c4ff92SAndroid Build Coastguard Worker 
1290*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1291*89c4ff92SAndroid Build Coastguard Worker                                   "Reference for Dequantize layer: input/output shapes have different num total "
1292*89c4ff92SAndroid Build Coastguard Worker                                   "elements.");
1293*89c4ff92SAndroid Build Coastguard Worker 
1294*89c4ff92SAndroid Build Coastguard Worker     return supported;
1295*89c4ff92SAndroid Build Coastguard Worker }
1296*89c4ff92SAndroid Build Coastguard Worker 
IsDetectionPostProcessSupported(const TensorInfo & boxEncodings,const TensorInfo & scores,const TensorInfo & anchors,const TensorInfo & detectionBoxes,const TensorInfo & detectionClasses,const TensorInfo & detectionScores,const TensorInfo & numDetections,const DetectionPostProcessDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1297*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1298*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& scores,
1299*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& anchors,
1300*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& detectionBoxes,
1301*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& detectionClasses,
1302*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& detectionScores,
1303*89c4ff92SAndroid Build Coastguard Worker                                                       const TensorInfo& numDetections,
1304*89c4ff92SAndroid Build Coastguard Worker                                                       const DetectionPostProcessDescriptor& descriptor,
1305*89c4ff92SAndroid Build Coastguard Worker                                                       Optional<std::string&> reasonIfUnsupported) const
1306*89c4ff92SAndroid Build Coastguard Worker {
1307*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
1308*89c4ff92SAndroid Build Coastguard Worker 
1309*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1310*89c4ff92SAndroid Build Coastguard Worker 
1311*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedInputTypes =
1312*89c4ff92SAndroid Build Coastguard Worker     {
1313*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1314*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1315*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1316*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1317*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1318*89c4ff92SAndroid Build Coastguard Worker     };
1319*89c4ff92SAndroid Build Coastguard Worker 
1320*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
1321*89c4ff92SAndroid Build Coastguard Worker                                   "Reference DetectionPostProcess: input 0 is not a supported type.");
1322*89c4ff92SAndroid Build Coastguard Worker 
1323*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
1324*89c4ff92SAndroid Build Coastguard Worker                                   "Reference DetectionPostProcess: input 1 is not a supported type.");
1325*89c4ff92SAndroid Build Coastguard Worker 
1326*89c4ff92SAndroid Build Coastguard Worker     return supported;
1327*89c4ff92SAndroid Build Coastguard Worker }
1328*89c4ff92SAndroid Build Coastguard Worker 
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1329*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1330*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& output,
1331*89c4ff92SAndroid Build Coastguard Worker                                                              const DepthwiseConvolution2dDescriptor& descriptor,
1332*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& weights,
1333*89c4ff92SAndroid Build Coastguard Worker                                                              const Optional<TensorInfo>& biases,
1334*89c4ff92SAndroid Build Coastguard Worker                                                              Optional<std::string&> reasonIfUnsupported) const
1335*89c4ff92SAndroid Build Coastguard Worker {
1336*89c4ff92SAndroid Build Coastguard Worker     return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
1337*89c4ff92SAndroid Build Coastguard Worker }
1338*89c4ff92SAndroid Build Coastguard Worker 
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1339*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
1340*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& input1,
1341*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
1342*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
1343*89c4ff92SAndroid Build Coastguard Worker {
1344*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1345*89c4ff92SAndroid Build Coastguard Worker 
1346*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes = {
1347*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1348*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1349*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1350*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1351*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1352*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1353*89c4ff92SAndroid Build Coastguard Worker     };
1354*89c4ff92SAndroid Build Coastguard Worker 
1355*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1356*89c4ff92SAndroid Build Coastguard Worker                                   "Reference division: input 0 is not a supported type.");
1357*89c4ff92SAndroid Build Coastguard Worker 
1358*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1359*89c4ff92SAndroid Build Coastguard Worker                                   "Reference division: input 1 is not a supported type.");
1360*89c4ff92SAndroid Build Coastguard Worker 
1361*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1362*89c4ff92SAndroid Build Coastguard Worker                                   "Reference division: output is not a supported type.");
1363*89c4ff92SAndroid Build Coastguard Worker 
1364*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1365*89c4ff92SAndroid Build Coastguard Worker                                   "Reference division: input 0 and Input 1 types are mismatched");
1366*89c4ff92SAndroid Build Coastguard Worker 
1367*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1368*89c4ff92SAndroid Build Coastguard Worker                                   "Reference division: input and output types are mismatched");
1369*89c4ff92SAndroid Build Coastguard Worker 
1370*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1371*89c4ff92SAndroid Build Coastguard Worker                                   "Reference division: shapes are not suitable for implicit broadcast.");
1372*89c4ff92SAndroid Build Coastguard Worker 
1373*89c4ff92SAndroid Build Coastguard Worker     return supported;
1374*89c4ff92SAndroid Build Coastguard Worker }
1375*89c4ff92SAndroid Build Coastguard Worker 
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1376*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1377*89c4ff92SAndroid Build Coastguard Worker                                                   const TensorInfo& output,
1378*89c4ff92SAndroid Build Coastguard Worker                                                   const ElementwiseUnaryDescriptor& descriptor,
1379*89c4ff92SAndroid Build Coastguard Worker                                                   Optional<std::string&> reasonIfUnsupported) const
1380*89c4ff92SAndroid Build Coastguard Worker {
1381*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1382*89c4ff92SAndroid Build Coastguard Worker 
1383*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 7> supportedTypes =
1384*89c4ff92SAndroid Build Coastguard Worker     {
1385*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1386*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1387*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1388*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1389*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1390*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1391*89c4ff92SAndroid Build Coastguard Worker     };
1392*89c4ff92SAndroid Build Coastguard Worker 
1393*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 1> logicalSupportedTypes =
1394*89c4ff92SAndroid Build Coastguard Worker     {
1395*89c4ff92SAndroid Build Coastguard Worker         DataType::Boolean
1396*89c4ff92SAndroid Build Coastguard Worker     };
1397*89c4ff92SAndroid Build Coastguard Worker 
1398*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1399*89c4ff92SAndroid Build Coastguard Worker 
1400*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1401*89c4ff92SAndroid Build Coastguard Worker     {
1402*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1403*89c4ff92SAndroid Build Coastguard Worker                                       "Reference elementwise unary: input type not supported");
1404*89c4ff92SAndroid Build Coastguard Worker 
1405*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1406*89c4ff92SAndroid Build Coastguard Worker                                       "Reference elementwise unary: output type not supported");
1407*89c4ff92SAndroid Build Coastguard Worker     }
1408*89c4ff92SAndroid Build Coastguard Worker     else
1409*89c4ff92SAndroid Build Coastguard Worker     {
1410*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1411*89c4ff92SAndroid Build Coastguard Worker                                       "Reference elementwise unary: input type not supported");
1412*89c4ff92SAndroid Build Coastguard Worker 
1413*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1414*89c4ff92SAndroid Build Coastguard Worker                                       "Reference elementwise unary: output type not supported");
1415*89c4ff92SAndroid Build Coastguard Worker     }
1416*89c4ff92SAndroid Build Coastguard Worker 
1417*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1418*89c4ff92SAndroid Build Coastguard Worker                                   "Reference elementwise unary: input and output types not matching");
1419*89c4ff92SAndroid Build Coastguard Worker 
1420*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1421*89c4ff92SAndroid Build Coastguard Worker                                   "Reference elementwise unary: input and output shapes"
1422*89c4ff92SAndroid Build Coastguard Worker                                   "have different number of total elements");
1423*89c4ff92SAndroid Build Coastguard Worker 
1424*89c4ff92SAndroid Build Coastguard Worker     return supported;
1425*89c4ff92SAndroid Build Coastguard Worker }
1426*89c4ff92SAndroid Build Coastguard Worker 
IsFakeQuantizationSupported(const TensorInfo & input,const FakeQuantizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1427*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1428*89c4ff92SAndroid Build Coastguard Worker                                                   const FakeQuantizationDescriptor& descriptor,
1429*89c4ff92SAndroid Build Coastguard Worker                                                   Optional<std::string&> reasonIfUnsupported) const
1430*89c4ff92SAndroid Build Coastguard Worker {
1431*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1432*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1433*89c4ff92SAndroid Build Coastguard Worker 
1434*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,1> supportedTypes =
1435*89c4ff92SAndroid Build Coastguard Worker     {
1436*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32
1437*89c4ff92SAndroid Build Coastguard Worker     };
1438*89c4ff92SAndroid Build Coastguard Worker 
1439*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440*89c4ff92SAndroid Build Coastguard Worker                                   "Reference fake quantization: input type not supported.");
1441*89c4ff92SAndroid Build Coastguard Worker 
1442*89c4ff92SAndroid Build Coastguard Worker     return supported;
1443*89c4ff92SAndroid Build Coastguard Worker }
1444*89c4ff92SAndroid Build Coastguard Worker 
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1445*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1446*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& output,
1447*89c4ff92SAndroid Build Coastguard Worker                                       const FillDescriptor& descriptor,
1448*89c4ff92SAndroid Build Coastguard Worker                                       Optional<std::string&> reasonIfUnsupported) const
1449*89c4ff92SAndroid Build Coastguard Worker {
1450*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1451*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(output);
1452*89c4ff92SAndroid Build Coastguard Worker 
1453*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1454*89c4ff92SAndroid Build Coastguard Worker 
1455*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,3> supportedTypes =
1456*89c4ff92SAndroid Build Coastguard Worker     {
1457*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1458*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1459*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1460*89c4ff92SAndroid Build Coastguard Worker     };
1461*89c4ff92SAndroid Build Coastguard Worker 
1462*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
1463*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fill: input type not supported.");
1464*89c4ff92SAndroid Build Coastguard Worker 
1465*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1466*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fill: output type not supported.");
1467*89c4ff92SAndroid Build Coastguard Worker     return supported;
1468*89c4ff92SAndroid Build Coastguard Worker }
1469*89c4ff92SAndroid Build Coastguard Worker 
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1470*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1471*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
1472*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
1473*89c4ff92SAndroid Build Coastguard Worker {
1474*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(output);
1475*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1476*89c4ff92SAndroid Build Coastguard Worker 
1477*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,3> supportedTypes =
1478*89c4ff92SAndroid Build Coastguard Worker     {
1479*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1480*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16
1481*89c4ff92SAndroid Build Coastguard Worker     };
1482*89c4ff92SAndroid Build Coastguard Worker 
1483*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Floor: input type not supported.");
1485*89c4ff92SAndroid Build Coastguard Worker 
1486*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Floor: output type not supported.");
1488*89c4ff92SAndroid Build Coastguard Worker 
1489*89c4ff92SAndroid Build Coastguard Worker     return supported;
1490*89c4ff92SAndroid Build Coastguard Worker }
1491*89c4ff92SAndroid Build Coastguard Worker 
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1492*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1493*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
1494*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& weights,
1495*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& biases,
1496*89c4ff92SAndroid Build Coastguard Worker                                                 const FullyConnectedDescriptor& descriptor,
1497*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
1498*89c4ff92SAndroid Build Coastguard Worker {
1499*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1500*89c4ff92SAndroid Build Coastguard Worker 
1501*89c4ff92SAndroid Build Coastguard Worker     // Define supported types.
1502*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
1503*89c4ff92SAndroid Build Coastguard Worker     {
1504*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1505*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1506*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1507*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1508*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1509*89c4ff92SAndroid Build Coastguard Worker     };
1510*89c4ff92SAndroid Build Coastguard Worker 
1511*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fully Connected: input type not supported.");
1513*89c4ff92SAndroid Build Coastguard Worker 
1514*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fully Connected: output type not supported.");
1516*89c4ff92SAndroid Build Coastguard Worker 
1517*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1518*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fully Connected: weights type not supported.");
1519*89c4ff92SAndroid Build Coastguard Worker 
1520*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1521*89c4ff92SAndroid Build Coastguard Worker                               "Reference Fully Connected: input and output types mismatched.");
1522*89c4ff92SAndroid Build Coastguard Worker 
1523*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1524*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fully Connected: weights is not a supported type.");
1525*89c4ff92SAndroid Build Coastguard Worker 
1526*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1527*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Fully Connected: input and weights types mismatched.");
1528*89c4ff92SAndroid Build Coastguard Worker 
1529*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_BiasEnabled)
1530*89c4ff92SAndroid Build Coastguard Worker     {
1531*89c4ff92SAndroid Build Coastguard Worker         // Defined supported types for bias
1532*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType, 5>
1533*89c4ff92SAndroid Build Coastguard Worker         supportedBiasTypes =
1534*89c4ff92SAndroid Build Coastguard Worker         {
1535*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
1536*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16,
1537*89c4ff92SAndroid Build Coastguard Worker             DataType::Signed32,
1538*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmS8
1539*89c4ff92SAndroid Build Coastguard Worker         };
1540*89c4ff92SAndroid Build Coastguard Worker 
1541*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1542*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Fully Connected: bias type not supported.");
1543*89c4ff92SAndroid Build Coastguard Worker 
1544*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1545*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Fully Connected: bias and weight types mismatch.");
1546*89c4ff92SAndroid Build Coastguard Worker 
1547*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1548*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Fully Connected: bias type inferred from weights is incompatible.");
1549*89c4ff92SAndroid Build Coastguard Worker 
1550*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1551*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Fully Connected: bias must have 1 dimension.");
1552*89c4ff92SAndroid Build Coastguard Worker 
1553*89c4ff92SAndroid Build Coastguard Worker     }
1554*89c4ff92SAndroid Build Coastguard Worker 
1555*89c4ff92SAndroid Build Coastguard Worker     return supported;
1556*89c4ff92SAndroid Build Coastguard Worker }
1557*89c4ff92SAndroid Build Coastguard Worker 
IsGatherNdSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const1558*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1559*89c4ff92SAndroid Build Coastguard Worker                                           const armnn::TensorInfo& input1,
1560*89c4ff92SAndroid Build Coastguard Worker                                           const armnn::TensorInfo& output,
1561*89c4ff92SAndroid Build Coastguard Worker                                           armnn::Optional<std::string&> reasonIfUnsupported) const
1562*89c4ff92SAndroid Build Coastguard Worker {
1563*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1564*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
1565*89c4ff92SAndroid Build Coastguard Worker     {
1566*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
1567*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16,
1568*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmS8,
1569*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmU8,
1570*89c4ff92SAndroid Build Coastguard Worker             DataType::QSymmS16,
1571*89c4ff92SAndroid Build Coastguard Worker             DataType::Signed32
1572*89c4ff92SAndroid Build Coastguard Worker     };
1573*89c4ff92SAndroid Build Coastguard Worker 
1574*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1575*89c4ff92SAndroid Build Coastguard Worker                                   "Reference GatherNd: input type not supported");
1576*89c4ff92SAndroid Build Coastguard Worker 
1577*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578*89c4ff92SAndroid Build Coastguard Worker                                   "Reference GatherNd: output type not supported");
1579*89c4ff92SAndroid Build Coastguard Worker 
1580*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1581*89c4ff92SAndroid Build Coastguard Worker                                   "Reference GatherNd: indices (input1) type not supported");
1582*89c4ff92SAndroid Build Coastguard Worker 
1583*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584*89c4ff92SAndroid Build Coastguard Worker                                   "Reference GatherNd: input and output types not matching");
1585*89c4ff92SAndroid Build Coastguard Worker 
1586*89c4ff92SAndroid Build Coastguard Worker     return supported;
1587*89c4ff92SAndroid Build Coastguard Worker }
1588*89c4ff92SAndroid Build Coastguard Worker 
IsGatherSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,const GatherDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const1589*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1590*89c4ff92SAndroid Build Coastguard Worker                                         const armnn::TensorInfo& input1,
1591*89c4ff92SAndroid Build Coastguard Worker                                         const armnn::TensorInfo& output,
1592*89c4ff92SAndroid Build Coastguard Worker                                         const GatherDescriptor& descriptor,
1593*89c4ff92SAndroid Build Coastguard Worker                                         armnn::Optional<std::string&> reasonIfUnsupported) const
1594*89c4ff92SAndroid Build Coastguard Worker {
1595*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1596*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
1597*89c4ff92SAndroid Build Coastguard Worker     {
1598*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1599*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1600*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1601*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1602*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1603*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1604*89c4ff92SAndroid Build Coastguard Worker     };
1605*89c4ff92SAndroid Build Coastguard Worker 
1606*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1607*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1608*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Gather: input type not supported");
1609*89c4ff92SAndroid Build Coastguard Worker 
1610*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1611*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Gather: output type not supported");
1612*89c4ff92SAndroid Build Coastguard Worker 
1613*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1614*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Gather: indices (input1) type not supported");
1615*89c4ff92SAndroid Build Coastguard Worker 
1616*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1617*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Gather: input and output types not matching");
1618*89c4ff92SAndroid Build Coastguard Worker 
1619*89c4ff92SAndroid Build Coastguard Worker     return supported;
1620*89c4ff92SAndroid Build Coastguard Worker }
1621*89c4ff92SAndroid Build Coastguard Worker 
IsInputSupported(const TensorInfo &,Optional<std::string &>) const1622*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1623*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> /*reasonIfUnsupported*/) const
1624*89c4ff92SAndroid Build Coastguard Worker {
1625*89c4ff92SAndroid Build Coastguard Worker     return true;
1626*89c4ff92SAndroid Build Coastguard Worker }
1627*89c4ff92SAndroid Build Coastguard Worker 
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1628*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1629*89c4ff92SAndroid Build Coastguard Worker                                                        const TensorInfo& output,
1630*89c4ff92SAndroid Build Coastguard Worker                                                        const InstanceNormalizationDescriptor& descriptor,
1631*89c4ff92SAndroid Build Coastguard Worker                                                        Optional<std::string&> reasonIfUnsupported) const
1632*89c4ff92SAndroid Build Coastguard Worker {
1633*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1634*89c4ff92SAndroid Build Coastguard Worker     // Define supported types
1635*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 3> supportedTypes =
1636*89c4ff92SAndroid Build Coastguard Worker         {
1637*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
1638*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16
1639*89c4ff92SAndroid Build Coastguard Worker         };
1640*89c4ff92SAndroid Build Coastguard Worker 
1641*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1642*89c4ff92SAndroid Build Coastguard Worker 
1643*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1644*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Instance Normalization: input type not supported.");
1645*89c4ff92SAndroid Build Coastguard Worker 
1646*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1647*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Instance Normalization: output type not supported.");
1648*89c4ff92SAndroid Build Coastguard Worker 
1649*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1650*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Instance Normalization: input and output types mismatched.");
1651*89c4ff92SAndroid Build Coastguard Worker 
1652*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1653*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Instance Normalization: input and output shapes have different "
1654*89c4ff92SAndroid Build Coastguard Worker                                   "num total elements.");
1655*89c4ff92SAndroid Build Coastguard Worker 
1656*89c4ff92SAndroid Build Coastguard Worker     return supported;
1657*89c4ff92SAndroid Build Coastguard Worker }
1658*89c4ff92SAndroid Build Coastguard Worker 
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1659*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1660*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& output,
1661*89c4ff92SAndroid Build Coastguard Worker                                                  const L2NormalizationDescriptor& descriptor,
1662*89c4ff92SAndroid Build Coastguard Worker                                                  Optional<std::string&> reasonIfUnsupported) const
1663*89c4ff92SAndroid Build Coastguard Worker {
1664*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1665*89c4ff92SAndroid Build Coastguard Worker     // Define supported types
1666*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes =
1667*89c4ff92SAndroid Build Coastguard Worker     {
1668*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1669*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1670*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1671*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1672*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1673*89c4ff92SAndroid Build Coastguard Worker     };
1674*89c4ff92SAndroid Build Coastguard Worker 
1675*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1676*89c4ff92SAndroid Build Coastguard Worker 
1677*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1678*89c4ff92SAndroid Build Coastguard Worker                                   "Reference L2normalization: input type not supported.");
1679*89c4ff92SAndroid Build Coastguard Worker 
1680*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1681*89c4ff92SAndroid Build Coastguard Worker                                   "Reference L2normalization: output type not supported.");
1682*89c4ff92SAndroid Build Coastguard Worker 
1683*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1684*89c4ff92SAndroid Build Coastguard Worker                                   "Reference L2normalization: input and output types mismatched.");
1685*89c4ff92SAndroid Build Coastguard Worker 
1686*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1687*89c4ff92SAndroid Build Coastguard Worker                                   "Reference L2normalization: input and output shapes have different "
1688*89c4ff92SAndroid Build Coastguard Worker                                   "num total elements.");
1689*89c4ff92SAndroid Build Coastguard Worker 
1690*89c4ff92SAndroid Build Coastguard Worker     return supported;
1691*89c4ff92SAndroid Build Coastguard Worker }
1692*89c4ff92SAndroid Build Coastguard Worker 
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1693*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1694*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& input1,
1695*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
1696*89c4ff92SAndroid Build Coastguard Worker                                                const LogicalBinaryDescriptor& descriptor,
1697*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
1698*89c4ff92SAndroid Build Coastguard Worker {
1699*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1700*89c4ff92SAndroid Build Coastguard Worker 
1701*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 1> supportedTypes =
1702*89c4ff92SAndroid Build Coastguard Worker     {
1703*89c4ff92SAndroid Build Coastguard Worker         DataType::Boolean
1704*89c4ff92SAndroid Build Coastguard Worker     };
1705*89c4ff92SAndroid Build Coastguard Worker 
1706*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1707*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1708*89c4ff92SAndroid Build Coastguard Worker                                   "Reference LogicalBinary: input 0 type not supported");
1709*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1710*89c4ff92SAndroid Build Coastguard Worker                                   "Reference LogicalBinary: input 1 type not supported");
1711*89c4ff92SAndroid Build Coastguard Worker 
1712*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1713*89c4ff92SAndroid Build Coastguard Worker                                   "Reference LogicalBinary: input and output types do not match");
1714*89c4ff92SAndroid Build Coastguard Worker 
1715*89c4ff92SAndroid Build Coastguard Worker     return supported;
1716*89c4ff92SAndroid Build Coastguard Worker }
1717*89c4ff92SAndroid Build Coastguard Worker 
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1718*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1719*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
1720*89c4ff92SAndroid Build Coastguard Worker                                             const LogSoftmaxDescriptor& descriptor,
1721*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
1722*89c4ff92SAndroid Build Coastguard Worker {
1723*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1724*89c4ff92SAndroid Build Coastguard Worker 
1725*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 3> supportedTypes =
1726*89c4ff92SAndroid Build Coastguard Worker     {
1727*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1728*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16
1729*89c4ff92SAndroid Build Coastguard Worker     };
1730*89c4ff92SAndroid Build Coastguard Worker 
1731*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1732*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1733*89c4ff92SAndroid Build Coastguard Worker                                   "Reference LogSoftmax: input type not supported");
1734*89c4ff92SAndroid Build Coastguard Worker 
1735*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1736*89c4ff92SAndroid Build Coastguard Worker                                   "Reference LogSoftmax: output type not supported");
1737*89c4ff92SAndroid Build Coastguard Worker 
1738*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1739*89c4ff92SAndroid Build Coastguard Worker                                   "Reference LogSoftmax: input and output types do not match");
1740*89c4ff92SAndroid Build Coastguard Worker 
1741*89c4ff92SAndroid Build Coastguard Worker     return supported;
1742*89c4ff92SAndroid Build Coastguard Worker }
1743*89c4ff92SAndroid Build Coastguard Worker 
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1744*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1745*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& outputStateIn,
1746*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& cellStateIn,
1747*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& scratchBuffer,
1748*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& outputStateOut,
1749*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& cellStateOut,
1750*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& output,
1751*89c4ff92SAndroid Build Coastguard Worker                                       const LstmDescriptor& descriptor,
1752*89c4ff92SAndroid Build Coastguard Worker                                       const LstmInputParamsInfo& paramsInfo,
1753*89c4ff92SAndroid Build Coastguard Worker                                       Optional<std::string&> reasonIfUnsupported) const
1754*89c4ff92SAndroid Build Coastguard Worker {
1755*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1756*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(paramsInfo);
1757*89c4ff92SAndroid Build Coastguard Worker 
1758*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1759*89c4ff92SAndroid Build Coastguard Worker 
1760*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,3> supportedTypes = {
1761*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1762*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1763*89c4ff92SAndroid Build Coastguard Worker     };
1764*89c4ff92SAndroid Build Coastguard Worker 
1765*89c4ff92SAndroid Build Coastguard Worker     // check inputs and outputs
1766*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1767*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input is not a supported type.");
1768*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1769*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and outputStateIn types are mismatched");
1770*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1771*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and cellStateIn types are mismatched");
1772*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1773*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and scratchBuffer types are mismatched");
1774*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1775*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and outputStateOut types are mismatched");
1776*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1777*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and cellStateOut types are mismatched");
1778*89c4ff92SAndroid Build Coastguard Worker 
1779*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1780*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and output types are mismatched");
1781*89c4ff92SAndroid Build Coastguard Worker     // check layer parameters
1782*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1783*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and InputToForgetWeights types are mismatched");
1784*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1785*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and InputToCellWeights types are mismatched");
1786*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1787*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and InputToOutputWeights types are mismatched");
1788*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1789*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1790*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1791*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1792*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1793*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1794*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1795*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and ForgetGateBias types are mismatched");
1796*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1797*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and CellBias types are mismatched");
1798*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1799*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Lstm: input and OutputGateBias types are mismatched");
1800*89c4ff92SAndroid Build Coastguard Worker     if (!descriptor.m_CifgEnabled)
1801*89c4ff92SAndroid Build Coastguard Worker     {
1802*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1803*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and InputToInputWeights types are mismatched");
1804*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1805*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
1806*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1807*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1808*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and InputGateBias types are mismatched");
1809*89c4ff92SAndroid Build Coastguard Worker         if (descriptor.m_PeepholeEnabled)
1810*89c4ff92SAndroid Build Coastguard Worker         {
1811*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1812*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported,
1813*89c4ff92SAndroid Build Coastguard Worker                                           "Reference Lstm: input and CellToInputWeights types are mismatched");
1814*89c4ff92SAndroid Build Coastguard Worker         }
1815*89c4ff92SAndroid Build Coastguard Worker     }
1816*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_PeepholeEnabled)
1817*89c4ff92SAndroid Build Coastguard Worker     {
1818*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1819*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and CellToForgetWeights types are mismatched");
1820*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1821*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and CellToOutputWeights types are mismatched");
1822*89c4ff92SAndroid Build Coastguard Worker     }
1823*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_ProjectionEnabled)
1824*89c4ff92SAndroid Build Coastguard Worker     {
1825*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1826*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and mProjectionWeights types are mismatched");
1827*89c4ff92SAndroid Build Coastguard Worker         if (paramsInfo.m_ProjectionBias != nullptr)
1828*89c4ff92SAndroid Build Coastguard Worker         {
1829*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1830*89c4ff92SAndroid Build Coastguard Worker                                           "Reference Lstm: input and ProjectionBias types are mismatched");
1831*89c4ff92SAndroid Build Coastguard Worker         }
1832*89c4ff92SAndroid Build Coastguard Worker     }
1833*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_LayerNormEnabled)
1834*89c4ff92SAndroid Build Coastguard Worker     {
1835*89c4ff92SAndroid Build Coastguard Worker         if (!descriptor.m_CifgEnabled)
1836*89c4ff92SAndroid Build Coastguard Worker         {
1837*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1838*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported,
1839*89c4ff92SAndroid Build Coastguard Worker                                           "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1840*89c4ff92SAndroid Build Coastguard Worker         }
1841*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1842*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
1843*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1844*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1845*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
1846*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1847*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1848*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
1849*89c4ff92SAndroid Build Coastguard Worker                                       "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1850*89c4ff92SAndroid Build Coastguard Worker     }
1851*89c4ff92SAndroid Build Coastguard Worker 
1852*89c4ff92SAndroid Build Coastguard Worker     return supported;
1853*89c4ff92SAndroid Build Coastguard Worker }
1854*89c4ff92SAndroid Build Coastguard Worker 
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1855*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1856*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& input1,
1857*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
1858*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
1859*89c4ff92SAndroid Build Coastguard Worker {
1860*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1861*89c4ff92SAndroid Build Coastguard Worker 
1862*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes = {
1863*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1864*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1865*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1866*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1867*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1868*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1869*89c4ff92SAndroid Build Coastguard Worker     };
1870*89c4ff92SAndroid Build Coastguard Worker 
1871*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1872*89c4ff92SAndroid Build Coastguard Worker                                   "Reference maximum: input 0 is not a supported type.");
1873*89c4ff92SAndroid Build Coastguard Worker 
1874*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1875*89c4ff92SAndroid Build Coastguard Worker                                   "Reference maximum: input 1 is not a supported type.");
1876*89c4ff92SAndroid Build Coastguard Worker 
1877*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1878*89c4ff92SAndroid Build Coastguard Worker                                   "Reference maximum: output is not a supported type.");
1879*89c4ff92SAndroid Build Coastguard Worker 
1880*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1881*89c4ff92SAndroid Build Coastguard Worker                                   "Reference maximum: input 0 and Input 1 types are mismatched");
1882*89c4ff92SAndroid Build Coastguard Worker 
1883*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1884*89c4ff92SAndroid Build Coastguard Worker                                   "Reference maximum: input and output types are mismatched");
1885*89c4ff92SAndroid Build Coastguard Worker 
1886*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1887*89c4ff92SAndroid Build Coastguard Worker                                   "Reference maximum: shapes are not suitable for implicit broadcast.");
1888*89c4ff92SAndroid Build Coastguard Worker 
1889*89c4ff92SAndroid Build Coastguard Worker     return supported;
1890*89c4ff92SAndroid Build Coastguard Worker }
1891*89c4ff92SAndroid Build Coastguard Worker 
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1892*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1893*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& output,
1894*89c4ff92SAndroid Build Coastguard Worker                                       const MeanDescriptor& descriptor,
1895*89c4ff92SAndroid Build Coastguard Worker                                       Optional<std::string&> reasonIfUnsupported) const
1896*89c4ff92SAndroid Build Coastguard Worker {
1897*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1898*89c4ff92SAndroid Build Coastguard Worker     std::string meanLayerStr = "Mean";
1899*89c4ff92SAndroid Build Coastguard Worker     std::string outputTensorStr = "output";
1900*89c4ff92SAndroid Build Coastguard Worker 
1901*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
1902*89c4ff92SAndroid Build Coastguard Worker     {
1903*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1904*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1905*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1906*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1907*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
1908*89c4ff92SAndroid Build Coastguard Worker     };
1909*89c4ff92SAndroid Build Coastguard Worker 
1910*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1911*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Mean: input type not supported.");
1912*89c4ff92SAndroid Build Coastguard Worker 
1913*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1914*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Mean: input and output types are mismatched");
1915*89c4ff92SAndroid Build Coastguard Worker 
1916*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_KeepDims)
1917*89c4ff92SAndroid Build Coastguard Worker     {
1918*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1919*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
1920*89c4ff92SAndroid Build Coastguard Worker                                       CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1921*89c4ff92SAndroid Build Coastguard Worker                                                                         output.GetNumDimensions(),
1922*89c4ff92SAndroid Build Coastguard Worker                                                                         meanLayerStr, outputTensorStr).data());
1923*89c4ff92SAndroid Build Coastguard Worker     }
1924*89c4ff92SAndroid Build Coastguard Worker     else if (descriptor.m_Axis.empty())
1925*89c4ff92SAndroid Build Coastguard Worker     {
1926*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1927*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
1928*89c4ff92SAndroid Build Coastguard Worker                                       CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1929*89c4ff92SAndroid Build Coastguard Worker                                                                         meanLayerStr, outputTensorStr).data());
1930*89c4ff92SAndroid Build Coastguard Worker     }
1931*89c4ff92SAndroid Build Coastguard Worker     else
1932*89c4ff92SAndroid Build Coastguard Worker     {
1933*89c4ff92SAndroid Build Coastguard Worker         auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1934*89c4ff92SAndroid Build Coastguard Worker 
1935*89c4ff92SAndroid Build Coastguard Worker         if (outputDim > 0)
1936*89c4ff92SAndroid Build Coastguard Worker         {
1937*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1938*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported,
1939*89c4ff92SAndroid Build Coastguard Worker                                           CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1940*89c4ff92SAndroid Build Coastguard Worker                                                                             meanLayerStr, outputTensorStr).data());
1941*89c4ff92SAndroid Build Coastguard Worker         }
1942*89c4ff92SAndroid Build Coastguard Worker         else
1943*89c4ff92SAndroid Build Coastguard Worker         {
1944*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1945*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported,
1946*89c4ff92SAndroid Build Coastguard Worker                                           CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1947*89c4ff92SAndroid Build Coastguard Worker                                                                             meanLayerStr, outputTensorStr).data());
1948*89c4ff92SAndroid Build Coastguard Worker         }
1949*89c4ff92SAndroid Build Coastguard Worker     }
1950*89c4ff92SAndroid Build Coastguard Worker 
1951*89c4ff92SAndroid Build Coastguard Worker     return supported;
1952*89c4ff92SAndroid Build Coastguard Worker }
1953*89c4ff92SAndroid Build Coastguard Worker 
IsMemCopySupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1954*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1955*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo &output,
1956*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string &> reasonIfUnsupported) const
1957*89c4ff92SAndroid Build Coastguard Worker {
1958*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1959*89c4ff92SAndroid Build Coastguard Worker 
1960*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
1961*89c4ff92SAndroid Build Coastguard Worker     {
1962*89c4ff92SAndroid Build Coastguard Worker         DataType::BFloat16,
1963*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1964*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1965*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1966*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1967*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1968*89c4ff92SAndroid Build Coastguard Worker         DataType::Boolean
1969*89c4ff92SAndroid Build Coastguard Worker     };
1970*89c4ff92SAndroid Build Coastguard Worker 
1971*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1972*89c4ff92SAndroid Build Coastguard Worker                                   "Reference MemCopy: input type not supported");
1973*89c4ff92SAndroid Build Coastguard Worker 
1974*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1975*89c4ff92SAndroid Build Coastguard Worker                                   "Reference MemCopy: output type not supported");
1976*89c4ff92SAndroid Build Coastguard Worker 
1977*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1978*89c4ff92SAndroid Build Coastguard Worker                                   "Reference MemCopy: input and output types are mismatched");
1979*89c4ff92SAndroid Build Coastguard Worker 
1980*89c4ff92SAndroid Build Coastguard Worker     return supported;
1981*89c4ff92SAndroid Build Coastguard Worker }
1982*89c4ff92SAndroid Build Coastguard Worker 
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1983*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1984*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& input1,
1985*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
1986*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
1987*89c4ff92SAndroid Build Coastguard Worker {
1988*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
1989*89c4ff92SAndroid Build Coastguard Worker 
1990*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes = {
1991*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
1992*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
1993*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
1994*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
1995*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
1996*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
1997*89c4ff92SAndroid Build Coastguard Worker     };
1998*89c4ff92SAndroid Build Coastguard Worker 
1999*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2000*89c4ff92SAndroid Build Coastguard Worker                                   "Reference minimum: input 0 is not a supported type.");
2001*89c4ff92SAndroid Build Coastguard Worker 
2002*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2003*89c4ff92SAndroid Build Coastguard Worker                                   "Reference minimum: input 1 is not a supported type.");
2004*89c4ff92SAndroid Build Coastguard Worker 
2005*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2006*89c4ff92SAndroid Build Coastguard Worker                                   "Reference minimum: output is not a supported type.");
2007*89c4ff92SAndroid Build Coastguard Worker 
2008*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2009*89c4ff92SAndroid Build Coastguard Worker                                   "Reference minimum: input 0 and Input 1 types are mismatched");
2010*89c4ff92SAndroid Build Coastguard Worker 
2011*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2012*89c4ff92SAndroid Build Coastguard Worker                                   "Reference minimum: input and output types are mismatched");
2013*89c4ff92SAndroid Build Coastguard Worker 
2014*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2015*89c4ff92SAndroid Build Coastguard Worker                                   "Reference minimum: shapes are not suitable for implicit broadcast.");
2016*89c4ff92SAndroid Build Coastguard Worker 
2017*89c4ff92SAndroid Build Coastguard Worker     return supported;
2018*89c4ff92SAndroid Build Coastguard Worker }
2019*89c4ff92SAndroid Build Coastguard Worker 
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2020*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2021*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& input1,
2022*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
2023*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
2024*89c4ff92SAndroid Build Coastguard Worker {
2025*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2026*89c4ff92SAndroid Build Coastguard Worker 
2027*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes = {
2028*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2029*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2030*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2031*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2032*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
2033*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
2034*89c4ff92SAndroid Build Coastguard Worker     };
2035*89c4ff92SAndroid Build Coastguard Worker 
2036*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2037*89c4ff92SAndroid Build Coastguard Worker                                   "Reference multiplication: input 0 is not a supported type.");
2038*89c4ff92SAndroid Build Coastguard Worker 
2039*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2040*89c4ff92SAndroid Build Coastguard Worker                                   "Reference multiplication: input 1 is not a supported type.");
2041*89c4ff92SAndroid Build Coastguard Worker 
2042*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2043*89c4ff92SAndroid Build Coastguard Worker                                   "Reference multiplication: output is not a supported type.");
2044*89c4ff92SAndroid Build Coastguard Worker 
2045*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2046*89c4ff92SAndroid Build Coastguard Worker                                   "Reference multiplication: input 0 and Input 1 types are mismatched");
2047*89c4ff92SAndroid Build Coastguard Worker 
2048*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2049*89c4ff92SAndroid Build Coastguard Worker                                   "Reference multiplication: input and output types are mismatched");
2050*89c4ff92SAndroid Build Coastguard Worker 
2051*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2052*89c4ff92SAndroid Build Coastguard Worker                                   "Reference multiplication: shapes are not suitable for implicit broadcast.");
2053*89c4ff92SAndroid Build Coastguard Worker 
2054*89c4ff92SAndroid Build Coastguard Worker     return supported;
2055*89c4ff92SAndroid Build Coastguard Worker }
2056*89c4ff92SAndroid Build Coastguard Worker 
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2057*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2058*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
2059*89c4ff92SAndroid Build Coastguard Worker                                                const NormalizationDescriptor& descriptor,
2060*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
2061*89c4ff92SAndroid Build Coastguard Worker {
2062*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2063*89c4ff92SAndroid Build Coastguard Worker 
2064*89c4ff92SAndroid Build Coastguard Worker     // Define supported types
2065*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes =
2066*89c4ff92SAndroid Build Coastguard Worker     {
2067*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2068*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2069*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2070*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2071*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2072*89c4ff92SAndroid Build Coastguard Worker     };
2073*89c4ff92SAndroid Build Coastguard Worker 
2074*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2075*89c4ff92SAndroid Build Coastguard Worker 
2076*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2077*89c4ff92SAndroid Build Coastguard Worker                                   "Reference normalization: input type not supported.");
2078*89c4ff92SAndroid Build Coastguard Worker 
2079*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2080*89c4ff92SAndroid Build Coastguard Worker                                   "Reference normalization: output type not supported.");
2081*89c4ff92SAndroid Build Coastguard Worker 
2082*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2083*89c4ff92SAndroid Build Coastguard Worker                                   "Reference normalization: input and output shapes have different "
2084*89c4ff92SAndroid Build Coastguard Worker                                   "num total elements.");
2085*89c4ff92SAndroid Build Coastguard Worker 
2086*89c4ff92SAndroid Build Coastguard Worker     return supported;
2087*89c4ff92SAndroid Build Coastguard Worker }
2088*89c4ff92SAndroid Build Coastguard Worker 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const2089*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2090*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> /*reasonIfUnsupported*/) const
2091*89c4ff92SAndroid Build Coastguard Worker {
2092*89c4ff92SAndroid Build Coastguard Worker     return true;
2093*89c4ff92SAndroid Build Coastguard Worker }
2094*89c4ff92SAndroid Build Coastguard Worker 
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2095*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2096*89c4ff92SAndroid Build Coastguard Worker                                      const TensorInfo& output,
2097*89c4ff92SAndroid Build Coastguard Worker                                      const PadDescriptor& descriptor,
2098*89c4ff92SAndroid Build Coastguard Worker                                      Optional<std::string&> reasonIfUnsupported) const
2099*89c4ff92SAndroid Build Coastguard Worker {
2100*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2101*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2102*89c4ff92SAndroid Build Coastguard Worker 
2103*89c4ff92SAndroid Build Coastguard Worker     // Define supported output and inputs types.
2104*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2105*89c4ff92SAndroid Build Coastguard Worker     {
2106*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2107*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2108*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2109*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2110*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2111*89c4ff92SAndroid Build Coastguard Worker     };
2112*89c4ff92SAndroid Build Coastguard Worker 
2113*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2114*89c4ff92SAndroid Build Coastguard Worker                                   "Reference pad: input is not a supported type.");
2115*89c4ff92SAndroid Build Coastguard Worker 
2116*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2117*89c4ff92SAndroid Build Coastguard Worker                                   "Reference pad: output is not a supported type.");
2118*89c4ff92SAndroid Build Coastguard Worker 
2119*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2120*89c4ff92SAndroid Build Coastguard Worker                                   "Reference pad: input and output types are mismatched.");
2121*89c4ff92SAndroid Build Coastguard Worker 
2122*89c4ff92SAndroid Build Coastguard Worker     return supported;
2123*89c4ff92SAndroid Build Coastguard Worker }
2124*89c4ff92SAndroid Build Coastguard Worker 
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2125*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2126*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
2127*89c4ff92SAndroid Build Coastguard Worker                                          const PermuteDescriptor& descriptor,
2128*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
2129*89c4ff92SAndroid Build Coastguard Worker {
2130*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2131*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2132*89c4ff92SAndroid Build Coastguard Worker 
2133*89c4ff92SAndroid Build Coastguard Worker     // Define supported output and inputs types.
2134*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes =
2135*89c4ff92SAndroid Build Coastguard Worker     {
2136*89c4ff92SAndroid Build Coastguard Worker         DataType::BFloat16,
2137*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2138*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2139*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2140*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2141*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2142*89c4ff92SAndroid Build Coastguard Worker     };
2143*89c4ff92SAndroid Build Coastguard Worker 
2144*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145*89c4ff92SAndroid Build Coastguard Worker                                   "Reference permute: input is not a supported type.");
2146*89c4ff92SAndroid Build Coastguard Worker 
2147*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148*89c4ff92SAndroid Build Coastguard Worker                                   "Reference permute: output is not a supported type.");
2149*89c4ff92SAndroid Build Coastguard Worker 
2150*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151*89c4ff92SAndroid Build Coastguard Worker                                   "Reference permute: input and output types are mismatched.");
2152*89c4ff92SAndroid Build Coastguard Worker 
2153*89c4ff92SAndroid Build Coastguard Worker     return supported;
2154*89c4ff92SAndroid Build Coastguard Worker }
2155*89c4ff92SAndroid Build Coastguard Worker 
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2156*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2157*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
2158*89c4ff92SAndroid Build Coastguard Worker                                            const Pooling2dDescriptor& descriptor,
2159*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
2160*89c4ff92SAndroid Build Coastguard Worker {
2161*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2162*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2163*89c4ff92SAndroid Build Coastguard Worker 
2164*89c4ff92SAndroid Build Coastguard Worker     // Define supported output and inputs types.
2165*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2166*89c4ff92SAndroid Build Coastguard Worker     {
2167*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2168*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2169*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2170*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2171*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2172*89c4ff92SAndroid Build Coastguard Worker     };
2173*89c4ff92SAndroid Build Coastguard Worker 
2174*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2175*89c4ff92SAndroid Build Coastguard Worker                                   "Reference poolind2d: input is not a supported type.");
2176*89c4ff92SAndroid Build Coastguard Worker 
2177*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2178*89c4ff92SAndroid Build Coastguard Worker                                   "Reference poolind2d: output is not a supported type.");
2179*89c4ff92SAndroid Build Coastguard Worker 
2180*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2181*89c4ff92SAndroid Build Coastguard Worker                                   "Reference poolind2d: input and output types are mismatched.");
2182*89c4ff92SAndroid Build Coastguard Worker 
2183*89c4ff92SAndroid Build Coastguard Worker     return supported;
2184*89c4ff92SAndroid Build Coastguard Worker }
2185*89c4ff92SAndroid Build Coastguard Worker 
IsPooling3dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling3dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2186*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2187*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
2188*89c4ff92SAndroid Build Coastguard Worker                                            const Pooling3dDescriptor& descriptor,
2189*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
2190*89c4ff92SAndroid Build Coastguard Worker {
2191*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2192*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2193*89c4ff92SAndroid Build Coastguard Worker 
2194*89c4ff92SAndroid Build Coastguard Worker     // Define supported output and inputs types.
2195*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2196*89c4ff92SAndroid Build Coastguard Worker     {
2197*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2198*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2199*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2200*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2201*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2202*89c4ff92SAndroid Build Coastguard Worker     };
2203*89c4ff92SAndroid Build Coastguard Worker 
2204*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2205*89c4ff92SAndroid Build Coastguard Worker                                   "Reference poolind3d: input is not a supported type.");
2206*89c4ff92SAndroid Build Coastguard Worker 
2207*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2208*89c4ff92SAndroid Build Coastguard Worker                                   "Reference poolind3d: output is not a supported type.");
2209*89c4ff92SAndroid Build Coastguard Worker 
2210*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2211*89c4ff92SAndroid Build Coastguard Worker                                   "Reference poolind3d: input and output types are mismatched.");
2212*89c4ff92SAndroid Build Coastguard Worker 
2213*89c4ff92SAndroid Build Coastguard Worker     return supported;
2214*89c4ff92SAndroid Build Coastguard Worker }
2215*89c4ff92SAndroid Build Coastguard Worker 
2216*89c4ff92SAndroid Build Coastguard Worker 
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2217*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2218*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& previousOutputIn,
2219*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& previousCellStateIn,
2220*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& outputStateOut,
2221*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& cellStateOut,
2222*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
2223*89c4ff92SAndroid Build Coastguard Worker                                        const QLstmDescriptor& descriptor,
2224*89c4ff92SAndroid Build Coastguard Worker                                        const LstmInputParamsInfo& paramsInfo,
2225*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
2226*89c4ff92SAndroid Build Coastguard Worker {
2227*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(input);
2228*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(previousOutputIn);
2229*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(previousCellStateIn);
2230*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(outputStateOut);
2231*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(cellStateOut);
2232*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(output);
2233*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2234*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(paramsInfo);
2235*89c4ff92SAndroid Build Coastguard Worker 
2236*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(reasonIfUnsupported);
2237*89c4ff92SAndroid Build Coastguard Worker 
2238*89c4ff92SAndroid Build Coastguard Worker     return true;
2239*89c4ff92SAndroid Build Coastguard Worker }
2240*89c4ff92SAndroid Build Coastguard Worker 
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2241*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2242*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
2243*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
2244*89c4ff92SAndroid Build Coastguard Worker {
2245*89c4ff92SAndroid Build Coastguard Worker    bool supported = true;
2246*89c4ff92SAndroid Build Coastguard Worker 
2247*89c4ff92SAndroid Build Coastguard Worker     // Define supported input types.
2248*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedInputTypes = {
2249*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2250*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2251*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2252*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2253*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
2254*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2255*89c4ff92SAndroid Build Coastguard Worker     };
2256*89c4ff92SAndroid Build Coastguard Worker 
2257*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2258*89c4ff92SAndroid Build Coastguard Worker                                   "Reference quantize: input type not supported.");
2259*89c4ff92SAndroid Build Coastguard Worker 
2260*89c4ff92SAndroid Build Coastguard Worker     // Define supported output types.
2261*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,4> supportedOutputTypes = {
2262*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2263*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2264*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
2265*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2266*89c4ff92SAndroid Build Coastguard Worker     };
2267*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2268*89c4ff92SAndroid Build Coastguard Worker                                   "Reference quantize: output type not supported.");
2269*89c4ff92SAndroid Build Coastguard Worker 
2270*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2271*89c4ff92SAndroid Build Coastguard Worker                                   "Reference quantize: input and output shapes have different num total elements.");
2272*89c4ff92SAndroid Build Coastguard Worker 
2273*89c4ff92SAndroid Build Coastguard Worker     return supported;
2274*89c4ff92SAndroid Build Coastguard Worker }
2275*89c4ff92SAndroid Build Coastguard Worker 
IsRankSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2276*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2277*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& output,
2278*89c4ff92SAndroid Build Coastguard Worker                                       Optional<std::string&> reasonIfUnsupported) const
2279*89c4ff92SAndroid Build Coastguard Worker {
2280*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(input);
2281*89c4ff92SAndroid Build Coastguard Worker     // Define supported output types.
2282*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,1> supportedOutputTypes =
2283*89c4ff92SAndroid Build Coastguard Worker     {
2284*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32,
2285*89c4ff92SAndroid Build Coastguard Worker     };
2286*89c4ff92SAndroid Build Coastguard Worker 
2287*89c4ff92SAndroid Build Coastguard Worker     return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2288*89c4ff92SAndroid Build Coastguard Worker            "Reference rank: input type not supported.");
2289*89c4ff92SAndroid Build Coastguard Worker }
2290*89c4ff92SAndroid Build Coastguard Worker 
IsReduceSupported(const TensorInfo & input,const TensorInfo & output,const ReduceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2291*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2292*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
2293*89c4ff92SAndroid Build Coastguard Worker                                         const ReduceDescriptor& descriptor,
2294*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
2295*89c4ff92SAndroid Build Coastguard Worker {
2296*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2297*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2298*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
2299*89c4ff92SAndroid Build Coastguard Worker     {
2300*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2301*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2302*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2303*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2304*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
2305*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
2306*89c4ff92SAndroid Build Coastguard Worker     };
2307*89c4ff92SAndroid Build Coastguard Worker 
2308*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2309*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Reduce: input type not supported");
2310*89c4ff92SAndroid Build Coastguard Worker 
2311*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2312*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Reduce: output type not supported");
2313*89c4ff92SAndroid Build Coastguard Worker 
2314*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2315*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Reduce: input and output types not matching");
2316*89c4ff92SAndroid Build Coastguard Worker 
2317*89c4ff92SAndroid Build Coastguard Worker     return supported;
2318*89c4ff92SAndroid Build Coastguard Worker }
2319*89c4ff92SAndroid Build Coastguard Worker 
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2320*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
2321*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
2322*89c4ff92SAndroid Build Coastguard Worker                                          const ReshapeDescriptor& descriptor,
2323*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
2324*89c4ff92SAndroid Build Coastguard Worker {
2325*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(output);
2326*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2327*89c4ff92SAndroid Build Coastguard Worker     // Define supported output types.
2328*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,8> supportedOutputTypes =
2329*89c4ff92SAndroid Build Coastguard Worker     {
2330*89c4ff92SAndroid Build Coastguard Worker         DataType::BFloat16,
2331*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2332*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2333*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32,
2334*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2335*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2336*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
2337*89c4ff92SAndroid Build Coastguard Worker         DataType::Boolean
2338*89c4ff92SAndroid Build Coastguard Worker     };
2339*89c4ff92SAndroid Build Coastguard Worker 
2340*89c4ff92SAndroid Build Coastguard Worker     return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2341*89c4ff92SAndroid Build Coastguard Worker         "Reference reshape: input type not supported.");
2342*89c4ff92SAndroid Build Coastguard Worker }
2343*89c4ff92SAndroid Build Coastguard Worker 
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2344*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2345*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
2346*89c4ff92SAndroid Build Coastguard Worker                                         const ResizeDescriptor& descriptor,
2347*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
2348*89c4ff92SAndroid Build Coastguard Worker {
2349*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2350*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2351*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2352*89c4ff92SAndroid Build Coastguard Worker     {
2353*89c4ff92SAndroid Build Coastguard Worker         DataType::BFloat16,
2354*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2355*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2356*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2357*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2358*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2359*89c4ff92SAndroid Build Coastguard Worker     };
2360*89c4ff92SAndroid Build Coastguard Worker 
2361*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2362*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Resize: input type not supported");
2363*89c4ff92SAndroid Build Coastguard Worker 
2364*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2365*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Resize: output type not supported");
2366*89c4ff92SAndroid Build Coastguard Worker 
2367*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2368*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Resize: input and output types not matching");
2369*89c4ff92SAndroid Build Coastguard Worker 
2370*89c4ff92SAndroid Build Coastguard Worker     return supported;
2371*89c4ff92SAndroid Build Coastguard Worker }
2372*89c4ff92SAndroid Build Coastguard Worker 
IsShapeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2373*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2374*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
2375*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
2376*89c4ff92SAndroid Build Coastguard Worker {
2377*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(input);
2378*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2379*89c4ff92SAndroid Build Coastguard Worker 
2380*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 1> supportedTypes =
2381*89c4ff92SAndroid Build Coastguard Worker     {
2382*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
2383*89c4ff92SAndroid Build Coastguard Worker     };
2384*89c4ff92SAndroid Build Coastguard Worker 
2385*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2386*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Shape: output type not supported");
2387*89c4ff92SAndroid Build Coastguard Worker 
2388*89c4ff92SAndroid Build Coastguard Worker     return supported;
2389*89c4ff92SAndroid Build Coastguard Worker }
2390*89c4ff92SAndroid Build Coastguard Worker 
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2391*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2392*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
2393*89c4ff92SAndroid Build Coastguard Worker                                        const SliceDescriptor& descriptor,
2394*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
2395*89c4ff92SAndroid Build Coastguard Worker {
2396*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2397*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2398*89c4ff92SAndroid Build Coastguard Worker 
2399*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 5> supportedTypes =
2400*89c4ff92SAndroid Build Coastguard Worker     {
2401*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2402*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2403*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2404*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2405*89c4ff92SAndroid Build Coastguard Worker     };
2406*89c4ff92SAndroid Build Coastguard Worker 
2407*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2408*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Slice: input type not supported");
2409*89c4ff92SAndroid Build Coastguard Worker 
2410*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2411*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Slice: output type not supported");
2412*89c4ff92SAndroid Build Coastguard Worker 
2413*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2414*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Slice: input and output types are mismatched");
2415*89c4ff92SAndroid Build Coastguard Worker 
2416*89c4ff92SAndroid Build Coastguard Worker     return supported;
2417*89c4ff92SAndroid Build Coastguard Worker }
2418*89c4ff92SAndroid Build Coastguard Worker 
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2419*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2420*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
2421*89c4ff92SAndroid Build Coastguard Worker                                          const SoftmaxDescriptor& descriptor,
2422*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
2423*89c4ff92SAndroid Build Coastguard Worker {
2424*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2425*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2426*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
2427*89c4ff92SAndroid Build Coastguard Worker     {
2428*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2429*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2430*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
2431*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2432*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2433*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2434*89c4ff92SAndroid Build Coastguard Worker     };
2435*89c4ff92SAndroid Build Coastguard Worker 
2436*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2437*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Softmax: output type not supported");
2438*89c4ff92SAndroid Build Coastguard Worker 
2439*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2440*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Softmax: input type not supported");
2441*89c4ff92SAndroid Build Coastguard Worker 
2442*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2443*89c4ff92SAndroid Build Coastguard Worker                                   "Reference Softmax: input type not supported");
2444*89c4ff92SAndroid Build Coastguard Worker 
2445*89c4ff92SAndroid Build Coastguard Worker     return supported;
2446*89c4ff92SAndroid Build Coastguard Worker }
2447*89c4ff92SAndroid Build Coastguard Worker 
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2448*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2449*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
2450*89c4ff92SAndroid Build Coastguard Worker                                                 const SpaceToBatchNdDescriptor& descriptor,
2451*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
2452*89c4ff92SAndroid Build Coastguard Worker {
2453*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2454*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2455*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2456*89c4ff92SAndroid Build Coastguard Worker     {
2457*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2458*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2459*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2460*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2461*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2462*89c4ff92SAndroid Build Coastguard Worker     };
2463*89c4ff92SAndroid Build Coastguard Worker 
2464*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2465*89c4ff92SAndroid Build Coastguard Worker                                   "Reference SpaceToBatchNd: input type not supported");
2466*89c4ff92SAndroid Build Coastguard Worker 
2467*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2468*89c4ff92SAndroid Build Coastguard Worker                                   "Reference SpaceToBatchNd: output type not supported");
2469*89c4ff92SAndroid Build Coastguard Worker 
2470*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2471*89c4ff92SAndroid Build Coastguard Worker                                   "Reference SpaceToBatchNd: input and output types are mismatched");
2472*89c4ff92SAndroid Build Coastguard Worker 
2473*89c4ff92SAndroid Build Coastguard Worker     return supported;
2474*89c4ff92SAndroid Build Coastguard Worker }
2475*89c4ff92SAndroid Build Coastguard Worker 
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2476*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
2477*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& output,
2478*89c4ff92SAndroid Build Coastguard Worker                                               const SpaceToDepthDescriptor& descriptor,
2479*89c4ff92SAndroid Build Coastguard Worker                                               Optional<std::string&> reasonIfUnsupported) const
2480*89c4ff92SAndroid Build Coastguard Worker {
2481*89c4ff92SAndroid Build Coastguard Worker 
2482*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2483*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2484*89c4ff92SAndroid Build Coastguard Worker 
2485*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2486*89c4ff92SAndroid Build Coastguard Worker     {
2487*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2488*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2489*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2490*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2491*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2492*89c4ff92SAndroid Build Coastguard Worker     };
2493*89c4ff92SAndroid Build Coastguard Worker 
2494*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2495*89c4ff92SAndroid Build Coastguard Worker         "Reference SpaceToDepth: input type not supported");
2496*89c4ff92SAndroid Build Coastguard Worker 
2497*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2498*89c4ff92SAndroid Build Coastguard Worker         "Reference SpaceToDepth: output type not supported");
2499*89c4ff92SAndroid Build Coastguard Worker 
2500*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2501*89c4ff92SAndroid Build Coastguard Worker         "Reference SpaceToDepth: input and output types are mismatched");
2502*89c4ff92SAndroid Build Coastguard Worker 
2503*89c4ff92SAndroid Build Coastguard Worker     return supported;
2504*89c4ff92SAndroid Build Coastguard Worker }
2505*89c4ff92SAndroid Build Coastguard Worker 
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2506*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2507*89c4ff92SAndroid Build Coastguard Worker                                           const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2508*89c4ff92SAndroid Build Coastguard Worker                                           const ViewsDescriptor& descriptor,
2509*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
2510*89c4ff92SAndroid Build Coastguard Worker {
2511*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2512*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2513*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,6> supportedTypes =
2514*89c4ff92SAndroid Build Coastguard Worker     {
2515*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2516*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2517*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2518*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2519*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2520*89c4ff92SAndroid Build Coastguard Worker     };
2521*89c4ff92SAndroid Build Coastguard Worker 
2522*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2523*89c4ff92SAndroid Build Coastguard Worker                                   "Reference splitter: output type not supported");
2524*89c4ff92SAndroid Build Coastguard Worker     for (const TensorInfo& output : outputs)
2525*89c4ff92SAndroid Build Coastguard Worker     {
2526*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2527*89c4ff92SAndroid Build Coastguard Worker                                       "Reference splitter: input type not supported");
2528*89c4ff92SAndroid Build Coastguard Worker 
2529*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2530*89c4ff92SAndroid Build Coastguard Worker                                       "Reference splitter: input and output types mismatched.");
2531*89c4ff92SAndroid Build Coastguard Worker     }
2532*89c4ff92SAndroid Build Coastguard Worker 
2533*89c4ff92SAndroid Build Coastguard Worker     return supported;
2534*89c4ff92SAndroid Build Coastguard Worker }
2535*89c4ff92SAndroid Build Coastguard Worker 
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2536*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2537*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
2538*89c4ff92SAndroid Build Coastguard Worker                                        const StackDescriptor& descriptor,
2539*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
2540*89c4ff92SAndroid Build Coastguard Worker {
2541*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2542*89c4ff92SAndroid Build Coastguard Worker 
2543*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2544*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
2545*89c4ff92SAndroid Build Coastguard Worker     {
2546*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2547*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2548*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2549*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2550*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
2551*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
2552*89c4ff92SAndroid Build Coastguard Worker     };
2553*89c4ff92SAndroid Build Coastguard Worker 
2554*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2555*89c4ff92SAndroid Build Coastguard Worker                                   "Reference stack: output type not supported");
2556*89c4ff92SAndroid Build Coastguard Worker     for (const TensorInfo* input : inputs)
2557*89c4ff92SAndroid Build Coastguard Worker     {
2558*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(input != nullptr);
2559*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2560*89c4ff92SAndroid Build Coastguard Worker             "Reference stack: input type not supported");
2561*89c4ff92SAndroid Build Coastguard Worker 
2562*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2563*89c4ff92SAndroid Build Coastguard Worker             "Reference stack: input and output types mismatched.");
2564*89c4ff92SAndroid Build Coastguard Worker     }
2565*89c4ff92SAndroid Build Coastguard Worker 
2566*89c4ff92SAndroid Build Coastguard Worker     return supported;
2567*89c4ff92SAndroid Build Coastguard Worker }
2568*89c4ff92SAndroid Build Coastguard Worker 
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2569*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2570*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& output,
2571*89c4ff92SAndroid Build Coastguard Worker                                               const StridedSliceDescriptor& descriptor,
2572*89c4ff92SAndroid Build Coastguard Worker                                               Optional<std::string&> reasonIfUnsupported) const
2573*89c4ff92SAndroid Build Coastguard Worker {
2574*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2575*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2576*89c4ff92SAndroid Build Coastguard Worker 
2577*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,5> supportedTypes =
2578*89c4ff92SAndroid Build Coastguard Worker     {
2579*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2580*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2581*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2582*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2583*89c4ff92SAndroid Build Coastguard Worker     };
2584*89c4ff92SAndroid Build Coastguard Worker 
2585*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2586*89c4ff92SAndroid Build Coastguard Worker                                   "Reference StridedSlice: input type not supported");
2587*89c4ff92SAndroid Build Coastguard Worker 
2588*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2589*89c4ff92SAndroid Build Coastguard Worker                                   "Reference StridedSlice: output type not supported");
2590*89c4ff92SAndroid Build Coastguard Worker 
2591*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2592*89c4ff92SAndroid Build Coastguard Worker                                   "Reference StridedSlice: input and output types are mismatched");
2593*89c4ff92SAndroid Build Coastguard Worker 
2594*89c4ff92SAndroid Build Coastguard Worker     return supported;
2595*89c4ff92SAndroid Build Coastguard Worker }
2596*89c4ff92SAndroid Build Coastguard Worker 
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2597*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2598*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& input1,
2599*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& output,
2600*89c4ff92SAndroid Build Coastguard Worker                                              Optional<std::string&> reasonIfUnsupported) const
2601*89c4ff92SAndroid Build Coastguard Worker {
2602*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2603*89c4ff92SAndroid Build Coastguard Worker 
2604*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes = {
2605*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2606*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2607*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2608*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2609*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16,
2610*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
2611*89c4ff92SAndroid Build Coastguard Worker     };
2612*89c4ff92SAndroid Build Coastguard Worker 
2613*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2614*89c4ff92SAndroid Build Coastguard Worker                                   "Reference subtraction: input 0 is not a supported type.");
2615*89c4ff92SAndroid Build Coastguard Worker 
2616*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2617*89c4ff92SAndroid Build Coastguard Worker                                   "Reference subtraction: input 1 is not a supported type.");
2618*89c4ff92SAndroid Build Coastguard Worker 
2619*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2620*89c4ff92SAndroid Build Coastguard Worker                                   "Reference subtraction: output is not a supported type.");
2621*89c4ff92SAndroid Build Coastguard Worker 
2622*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2623*89c4ff92SAndroid Build Coastguard Worker                                   "Reference subtraction: input 0 and Input 1 types are mismatched");
2624*89c4ff92SAndroid Build Coastguard Worker 
2625*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2626*89c4ff92SAndroid Build Coastguard Worker                                   "Reference subtraction: input and output types are mismatched");
2627*89c4ff92SAndroid Build Coastguard Worker 
2628*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2629*89c4ff92SAndroid Build Coastguard Worker                                   "Reference subtraction: shapes are not suitable for implicit broadcast.");
2630*89c4ff92SAndroid Build Coastguard Worker 
2631*89c4ff92SAndroid Build Coastguard Worker     return supported;
2632*89c4ff92SAndroid Build Coastguard Worker }
2633*89c4ff92SAndroid Build Coastguard Worker 
IsPreluSupported(const TensorInfo & input,const TensorInfo & alpha,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2634*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2635*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& alpha,
2636*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
2637*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
2638*89c4ff92SAndroid Build Coastguard Worker {
2639*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2640*89c4ff92SAndroid Build Coastguard Worker 
2641*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes
2642*89c4ff92SAndroid Build Coastguard Worker     {
2643*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2644*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2645*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2646*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2647*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2648*89c4ff92SAndroid Build Coastguard Worker     };
2649*89c4ff92SAndroid Build Coastguard Worker 
2650*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2651*89c4ff92SAndroid Build Coastguard Worker                                   "PReLU: input is not a supported type.");
2652*89c4ff92SAndroid Build Coastguard Worker 
2653*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2654*89c4ff92SAndroid Build Coastguard Worker                                   "PReLU: alpha is not a supported type.");
2655*89c4ff92SAndroid Build Coastguard Worker 
2656*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2657*89c4ff92SAndroid Build Coastguard Worker                                   "PReLU: output is not a supported type.");
2658*89c4ff92SAndroid Build Coastguard Worker 
2659*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2660*89c4ff92SAndroid Build Coastguard Worker                                   "PReLU: input, alpha and output types are mismatched");
2661*89c4ff92SAndroid Build Coastguard Worker 
2662*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2663*89c4ff92SAndroid Build Coastguard Worker                                   "PReLU: shapes are not suitable for implicit broadcast");
2664*89c4ff92SAndroid Build Coastguard Worker 
2665*89c4ff92SAndroid Build Coastguard Worker     return supported;
2666*89c4ff92SAndroid Build Coastguard Worker }
2667*89c4ff92SAndroid Build Coastguard Worker 
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const2668*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2669*89c4ff92SAndroid Build Coastguard Worker                                                         const TensorInfo& output,
2670*89c4ff92SAndroid Build Coastguard Worker                                                         const TransposeConvolution2dDescriptor& descriptor,
2671*89c4ff92SAndroid Build Coastguard Worker                                                         const TensorInfo& weights,
2672*89c4ff92SAndroid Build Coastguard Worker                                                         const Optional<TensorInfo>& biases,
2673*89c4ff92SAndroid Build Coastguard Worker                                                         Optional<std::string&> reasonIfUnsupported) const
2674*89c4ff92SAndroid Build Coastguard Worker {
2675*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2676*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2677*89c4ff92SAndroid Build Coastguard Worker 
2678*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType,7> supportedTypes =
2679*89c4ff92SAndroid Build Coastguard Worker     {
2680*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2681*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2682*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2683*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2684*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS8,
2685*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2686*89c4ff92SAndroid Build Coastguard Worker     };
2687*89c4ff92SAndroid Build Coastguard Worker 
2688*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2689*89c4ff92SAndroid Build Coastguard Worker                                   "Reference TransposeConvolution2d: input is not a supported type.");
2690*89c4ff92SAndroid Build Coastguard Worker 
2691*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2692*89c4ff92SAndroid Build Coastguard Worker                                   "Reference TransposeConvolution2d: output is not a supported type.");
2693*89c4ff92SAndroid Build Coastguard Worker 
2694*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2695*89c4ff92SAndroid Build Coastguard Worker                                   "Reference TransposeConvolution2d: input and output types mismatched.");
2696*89c4ff92SAndroid Build Coastguard Worker 
2697*89c4ff92SAndroid Build Coastguard Worker 
2698*89c4ff92SAndroid Build Coastguard Worker     const DataType inputType = input.GetDataType();
2699*89c4ff92SAndroid Build Coastguard Worker     if (IsQuantized8BitType(inputType))
2700*89c4ff92SAndroid Build Coastguard Worker     {
2701*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType, 3> supportedWeightTypes =
2702*89c4ff92SAndroid Build Coastguard Worker         {
2703*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmS8,
2704*89c4ff92SAndroid Build Coastguard Worker             DataType::QAsymmU8,
2705*89c4ff92SAndroid Build Coastguard Worker             DataType::QSymmS8
2706*89c4ff92SAndroid Build Coastguard Worker         };
2707*89c4ff92SAndroid Build Coastguard Worker 
2708*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2709*89c4ff92SAndroid Build Coastguard Worker                                       "Reference TransposeConvolution2d: weights type not supported for "
2710*89c4ff92SAndroid Build Coastguard Worker                                       "quantized input.");
2711*89c4ff92SAndroid Build Coastguard Worker     }
2712*89c4ff92SAndroid Build Coastguard Worker     else
2713*89c4ff92SAndroid Build Coastguard Worker     {
2714*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2715*89c4ff92SAndroid Build Coastguard Worker                                     "Reference TransposeConvolution2d: weights is not a supported type.");
2716*89c4ff92SAndroid Build Coastguard Worker 
2717*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2718*89c4ff92SAndroid Build Coastguard Worker                                     "Reference TransposeConvolution2d: input and weights types mismatched.");
2719*89c4ff92SAndroid Build Coastguard Worker     }
2720*89c4ff92SAndroid Build Coastguard Worker 
2721*89c4ff92SAndroid Build Coastguard Worker     if (biases.has_value())
2722*89c4ff92SAndroid Build Coastguard Worker     {
2723*89c4ff92SAndroid Build Coastguard Worker         std::array<DataType,4> biasesSupportedTypes =
2724*89c4ff92SAndroid Build Coastguard Worker         {
2725*89c4ff92SAndroid Build Coastguard Worker             DataType::Float32,
2726*89c4ff92SAndroid Build Coastguard Worker             DataType::Float16,
2727*89c4ff92SAndroid Build Coastguard Worker             DataType::Signed32
2728*89c4ff92SAndroid Build Coastguard Worker         };
2729*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2730*89c4ff92SAndroid Build Coastguard Worker                                       "Reference TransposeConvolution2d: biases is not a supported type.");
2731*89c4ff92SAndroid Build Coastguard Worker     }
2732*89c4ff92SAndroid Build Coastguard Worker 
2733*89c4ff92SAndroid Build Coastguard Worker     return supported;
2734*89c4ff92SAndroid Build Coastguard Worker }
2735*89c4ff92SAndroid Build Coastguard Worker 
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2736*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2737*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
2738*89c4ff92SAndroid Build Coastguard Worker                                            const TransposeDescriptor& descriptor,
2739*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
2740*89c4ff92SAndroid Build Coastguard Worker {
2741*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2742*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2743*89c4ff92SAndroid Build Coastguard Worker 
2744*89c4ff92SAndroid Build Coastguard Worker     // Define supported output and inputs types.
2745*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 6> supportedTypes =
2746*89c4ff92SAndroid Build Coastguard Worker     {
2747*89c4ff92SAndroid Build Coastguard Worker         DataType::BFloat16,
2748*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2749*89c4ff92SAndroid Build Coastguard Worker         DataType::Float16,
2750*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2751*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmU8,
2752*89c4ff92SAndroid Build Coastguard Worker         DataType::QSymmS16
2753*89c4ff92SAndroid Build Coastguard Worker     };
2754*89c4ff92SAndroid Build Coastguard Worker 
2755*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2756*89c4ff92SAndroid Build Coastguard Worker                                   "Reference transpose: input is not a supported type.");
2757*89c4ff92SAndroid Build Coastguard Worker 
2758*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2759*89c4ff92SAndroid Build Coastguard Worker                                   "Reference transpose: output is not a supported type.");
2760*89c4ff92SAndroid Build Coastguard Worker 
2761*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2762*89c4ff92SAndroid Build Coastguard Worker                                   "Reference transpose: input and output types are mismatched.");
2763*89c4ff92SAndroid Build Coastguard Worker 
2764*89c4ff92SAndroid Build Coastguard Worker     return supported;
2765*89c4ff92SAndroid Build Coastguard Worker }
2766*89c4ff92SAndroid Build Coastguard Worker 
IsUnidirectionalSequenceLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const UnidirectionalSequenceLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2767*89c4ff92SAndroid Build Coastguard Worker bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2768*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& input,
2769*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& outputStateIn,
2770*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& cellStateIn,
2771*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& outputStateOut,
2772*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& cellStateOut,
2773*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& output,
2774*89c4ff92SAndroid Build Coastguard Worker         const UnidirectionalSequenceLstmDescriptor& descriptor,
2775*89c4ff92SAndroid Build Coastguard Worker         const LstmInputParamsInfo& paramsInfo,
2776*89c4ff92SAndroid Build Coastguard Worker         Optional<std::string&> reasonIfUnsupported) const
2777*89c4ff92SAndroid Build Coastguard Worker {
2778*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
2779*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(paramsInfo);
2780*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(outputStateIn);
2781*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(cellStateIn);
2782*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(outputStateOut);
2783*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(cellStateOut);
2784*89c4ff92SAndroid Build Coastguard Worker     bool supported = true;
2785*89c4ff92SAndroid Build Coastguard Worker 
2786*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 2> supportedTypes =
2787*89c4ff92SAndroid Build Coastguard Worker     {
2788*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2789*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8
2790*89c4ff92SAndroid Build Coastguard Worker     };
2791*89c4ff92SAndroid Build Coastguard Worker 
2792*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 2> supportedWeightTypes =
2793*89c4ff92SAndroid Build Coastguard Worker     {
2794*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2795*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8
2796*89c4ff92SAndroid Build Coastguard Worker     };
2797*89c4ff92SAndroid Build Coastguard Worker 
2798*89c4ff92SAndroid Build Coastguard Worker     std::array<DataType, 3> supportedBiasTypes =
2799*89c4ff92SAndroid Build Coastguard Worker     {
2800*89c4ff92SAndroid Build Coastguard Worker         DataType::Float32,
2801*89c4ff92SAndroid Build Coastguard Worker         DataType::QAsymmS8,
2802*89c4ff92SAndroid Build Coastguard Worker         DataType::Signed32
2803*89c4ff92SAndroid Build Coastguard Worker     };
2804*89c4ff92SAndroid Build Coastguard Worker 
2805*89c4ff92SAndroid Build Coastguard Worker     // check inputs and outputs
2806*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2807*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2808*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2809*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: output is not a supported type.");
2810*89c4ff92SAndroid Build Coastguard Worker 
2811*89c4ff92SAndroid Build Coastguard Worker     // check layer parameters
2812*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2813*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
2814*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2815*89c4ff92SAndroid Build Coastguard Worker                                   "is not a supported type.");
2816*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2817*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
2818*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2819*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2820*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
2821*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2822*89c4ff92SAndroid Build Coastguard Worker                                   "is not a supported type.");
2823*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2824*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
2825*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2826*89c4ff92SAndroid Build Coastguard Worker                                   "is not a supported type.");
2827*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2828*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
2829*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2830*89c4ff92SAndroid Build Coastguard Worker                                   "is not a supported type.");
2831*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2832*89c4ff92SAndroid Build Coastguard Worker                                   reasonIfUnsupported,
2833*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2834*89c4ff92SAndroid Build Coastguard Worker                                   "is not a supported type.");
2835*89c4ff92SAndroid Build Coastguard Worker 
2836*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2837*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2838*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2839*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2840*89c4ff92SAndroid Build Coastguard Worker     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2841*89c4ff92SAndroid Build Coastguard Worker                                   "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
2842*89c4ff92SAndroid Build Coastguard Worker     if (!descriptor.m_CifgEnabled)
2843*89c4ff92SAndroid Build Coastguard Worker     {
2844*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2845*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2846*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2847*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2848*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2849*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2850*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2851*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2852*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2853*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
2854*89c4ff92SAndroid Build Coastguard Worker         if (descriptor.m_PeepholeEnabled)
2855*89c4ff92SAndroid Build Coastguard Worker         {
2856*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2857*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported,
2858*89c4ff92SAndroid Build Coastguard Worker                                           "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2859*89c4ff92SAndroid Build Coastguard Worker                                           "is not a supported type.");
2860*89c4ff92SAndroid Build Coastguard Worker         }
2861*89c4ff92SAndroid Build Coastguard Worker     }
2862*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_PeepholeEnabled)
2863*89c4ff92SAndroid Build Coastguard Worker     {
2864*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2865*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2866*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2867*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2868*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2869*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2870*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2871*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2872*89c4ff92SAndroid Build Coastguard Worker     }
2873*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_ProjectionEnabled)
2874*89c4ff92SAndroid Build Coastguard Worker     {
2875*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2876*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2877*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2878*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2879*89c4ff92SAndroid Build Coastguard Worker         if (paramsInfo.m_ProjectionBias != nullptr)
2880*89c4ff92SAndroid Build Coastguard Worker         {
2881*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2882*89c4ff92SAndroid Build Coastguard Worker                                           "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2883*89c4ff92SAndroid Build Coastguard Worker                                           "are mismatched");
2884*89c4ff92SAndroid Build Coastguard Worker         }
2885*89c4ff92SAndroid Build Coastguard Worker     }
2886*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_LayerNormEnabled)
2887*89c4ff92SAndroid Build Coastguard Worker     {
2888*89c4ff92SAndroid Build Coastguard Worker         if (!descriptor.m_CifgEnabled)
2889*89c4ff92SAndroid Build Coastguard Worker         {
2890*89c4ff92SAndroid Build Coastguard Worker             supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2891*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported,
2892*89c4ff92SAndroid Build Coastguard Worker                                           "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2893*89c4ff92SAndroid Build Coastguard Worker                                           "is not a supported type.");
2894*89c4ff92SAndroid Build Coastguard Worker         }
2895*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2896*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2897*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2898*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2899*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2900*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2901*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2902*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2903*89c4ff92SAndroid Build Coastguard Worker         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2904*89c4ff92SAndroid Build Coastguard Worker                                       reasonIfUnsupported,
2905*89c4ff92SAndroid Build Coastguard Worker                                       "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2906*89c4ff92SAndroid Build Coastguard Worker                                       "is not a supported type.");
2907*89c4ff92SAndroid Build Coastguard Worker     }
2908*89c4ff92SAndroid Build Coastguard Worker 
2909*89c4ff92SAndroid Build Coastguard Worker     return supported;
2910*89c4ff92SAndroid Build Coastguard Worker }
2911*89c4ff92SAndroid Build Coastguard Worker 
2912*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
2913