1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadInfo.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
16*89c4ff92SAndroid Build Coastguard Worker #include <iomanip>
17*89c4ff92SAndroid Build Coastguard Worker #include <string>
18*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker using namespace armnnUtils;
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker namespace armnn
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
GetBiasDataType(DataType inputDataType)28*89c4ff92SAndroid Build Coastguard Worker DataType GetBiasDataType(DataType inputDataType)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker switch (inputDataType)
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker case DataType::Float16:
33*89c4ff92SAndroid Build Coastguard Worker return DataType::Float16;
34*89c4ff92SAndroid Build Coastguard Worker case DataType::BFloat16:
35*89c4ff92SAndroid Build Coastguard Worker case DataType::Float32:
36*89c4ff92SAndroid Build Coastguard Worker return DataType::Float32;
37*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmS8:
38*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmU8:
39*89c4ff92SAndroid Build Coastguard Worker case DataType::QSymmS8:
40*89c4ff92SAndroid Build Coastguard Worker case DataType::QSymmS16:
41*89c4ff92SAndroid Build Coastguard Worker return DataType::Signed32;
42*89c4ff92SAndroid Build Coastguard Worker default:
43*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Invalid input data type");
44*89c4ff92SAndroid Build Coastguard Worker return DataType::Float32;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker namespace
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
52*89c4ff92SAndroid Build Coastguard Worker //android ndk does not support std::to_string function.
/// Local replacement for std::to_string: formats @p value by inserting it into
/// a default-constructed std::ostringstream and returns the resulting text.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    return formatted.str();
}
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidatePointer(const void * ptr,std::string const & descName,std::string const & paramName)62*89c4ff92SAndroid Build Coastguard Worker void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker if (!ptr)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67*89c4ff92SAndroid Build Coastguard Worker paramName + " parameter must be set.");
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorShapesMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)72*89c4ff92SAndroid Build Coastguard Worker void ValidateTensorShapesMatch(const TensorInfo& first,
73*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& second,
74*89c4ff92SAndroid Build Coastguard Worker std::string const& descName,
75*89c4ff92SAndroid Build Coastguard Worker std::string const& firstName,
76*89c4ff92SAndroid Build Coastguard Worker std::string const& secondName)
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker if (first.GetShape() != second.GetShape())
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": "
81*89c4ff92SAndroid Build Coastguard Worker + firstName + " & " + secondName + " must have identical shapes");
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateNumInputs(const WorkloadInfo & workloadInfo,std::string const & descName,const unsigned int expectedSize)86*89c4ff92SAndroid Build Coastguard Worker void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
89*89c4ff92SAndroid Build Coastguard Worker {
90*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName +
91*89c4ff92SAndroid Build Coastguard Worker ": Requires exactly " + to_string(expectedSize) + "input(s). " +
92*89c4ff92SAndroid Build Coastguard Worker to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93*89c4ff92SAndroid Build Coastguard Worker }
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateNumOutputs(const WorkloadInfo & workloadInfo,std::string const & descName,const unsigned int expectedSize)97*89c4ff92SAndroid Build Coastguard Worker void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName +
102*89c4ff92SAndroid Build Coastguard Worker ": Requires exactly " + to_string(expectedSize) + " output(s). " +
103*89c4ff92SAndroid Build Coastguard Worker to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorNumElements(const TensorInfo & tensor,std::string const & descName,unsigned int numElements,std::string const & tensorName)110*89c4ff92SAndroid Build Coastguard Worker void ValidateTensorNumElements(const TensorInfo& tensor,
111*89c4ff92SAndroid Build Coastguard Worker std::string const& descName,
112*89c4ff92SAndroid Build Coastguard Worker unsigned int numElements,
113*89c4ff92SAndroid Build Coastguard Worker std::string const& tensorName)
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker if (tensor.GetNumElements() != numElements)
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
118*89c4ff92SAndroid Build Coastguard Worker to_string(tensor.GetNumElements()) + " elements for " +
119*89c4ff92SAndroid Build Coastguard Worker tensorName + " tensor.");
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker }
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorDataType(const TensorInfo & tensor,DataType dataType,const std::string & descName,std::string const & tensorName)124*89c4ff92SAndroid Build Coastguard Worker void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
125*89c4ff92SAndroid Build Coastguard Worker const std::string& descName, std::string const& tensorName)
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker if (tensor.GetDataType() != dataType)
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
130*89c4ff92SAndroid Build Coastguard Worker GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker
ValidPerAxisQuantizedDataType(const TensorInfo & tensor,const std::string & descName,const std::string & tensorName)134*89c4ff92SAndroid Build Coastguard Worker void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker if (tensor.GetDataType() != DataType::QSymmS8)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName +
139*89c4ff92SAndroid Build Coastguard Worker ": Expected data type which supports per-axis quantization scheme but got " +
140*89c4ff92SAndroid Build Coastguard Worker GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker }
143*89c4ff92SAndroid Build Coastguard Worker
144*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorQuantizationSpace(const TensorInfo & first,const TensorInfo & second,const std::string & descName,std::string const & firstName,std::string const & secondName)145*89c4ff92SAndroid Build Coastguard Worker void ValidateTensorQuantizationSpace(const TensorInfo& first,
146*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& second,
147*89c4ff92SAndroid Build Coastguard Worker const std::string& descName,
148*89c4ff92SAndroid Build Coastguard Worker std::string const& firstName,
149*89c4ff92SAndroid Build Coastguard Worker std::string const& secondName)
150*89c4ff92SAndroid Build Coastguard Worker {
151*89c4ff92SAndroid Build Coastguard Worker if (!first.IsQuantized() ||
152*89c4ff92SAndroid Build Coastguard Worker !second.IsQuantized())
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker // Not a quantized type, ignore the validation
155*89c4ff92SAndroid Build Coastguard Worker return;
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker DataType firstDataType = first.GetDataType();
159*89c4ff92SAndroid Build Coastguard Worker DataType secondDataType = second.GetDataType();
160*89c4ff92SAndroid Build Coastguard Worker
161*89c4ff92SAndroid Build Coastguard Worker if (firstDataType != secondDataType)
162*89c4ff92SAndroid Build Coastguard Worker {
163*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
164*89c4ff92SAndroid Build Coastguard Worker " must be of the same quantized type, " +
165*89c4ff92SAndroid Build Coastguard Worker firstName + " is " + GetDataTypeName(firstDataType) + ", " +
166*89c4ff92SAndroid Build Coastguard Worker secondName + " is " + GetDataTypeName(secondDataType));
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker if (!first.IsTypeSpaceMatch(second))
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172*89c4ff92SAndroid Build Coastguard Worker " must have the same quantization space, " +
173*89c4ff92SAndroid Build Coastguard Worker firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
174*89c4ff92SAndroid Build Coastguard Worker " and scale " + to_string(first.GetQuantizationScale()) + ", " +
175*89c4ff92SAndroid Build Coastguard Worker secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
176*89c4ff92SAndroid Build Coastguard Worker " and scale " + to_string(second.GetQuantizationScale()));
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateBiasTensorQuantization(const TensorInfo & biasTensor,const TensorInfo & inputTensorInfo,const TensorInfo & weightsTensorInfo,const std::string & descName)181*89c4ff92SAndroid Build Coastguard Worker void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
182*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo,
183*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightsTensorInfo,
184*89c4ff92SAndroid Build Coastguard Worker const std::string& descName)
185*89c4ff92SAndroid Build Coastguard Worker {
186*89c4ff92SAndroid Build Coastguard Worker // Helper lambda function to validate a single bias quantization scale value
187*89c4ff92SAndroid Build Coastguard Worker auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
188*89c4ff92SAndroid Build Coastguard Worker {
189*89c4ff92SAndroid Build Coastguard Worker constexpr float tolerance = 0.0001f;
190*89c4ff92SAndroid Build Coastguard Worker if (std::abs(biasScale - expectedScale) > tolerance)
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker // Print the float values with extra precision to see very small differences
193*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
194*89c4ff92SAndroid Build Coastguard Worker " for bias quantization scale (product of input and weight scales), but got " <<
195*89c4ff92SAndroid Build Coastguard Worker biasScale << ". Using scale provided.";
196*89c4ff92SAndroid Build Coastguard Worker }
197*89c4ff92SAndroid Build Coastguard Worker };
198*89c4ff92SAndroid Build Coastguard Worker
199*89c4ff92SAndroid Build Coastguard Worker if (biasTensor.GetQuantizationOffset() != 0)
200*89c4ff92SAndroid Build Coastguard Worker {
201*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
202*89c4ff92SAndroid Build Coastguard Worker to_string(biasTensor.GetQuantizationOffset()));
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker
205*89c4ff92SAndroid Build Coastguard Worker if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
206*89c4ff92SAndroid Build Coastguard Worker {
207*89c4ff92SAndroid Build Coastguard Worker // Validate per-axis quantization scales
208*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
209*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker if (weightScales.size() != biasScales.size())
212*89c4ff92SAndroid Build Coastguard Worker {
213*89c4ff92SAndroid Build Coastguard Worker std::stringstream msg;
214*89c4ff92SAndroid Build Coastguard Worker msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
215*89c4ff92SAndroid Build Coastguard Worker << "but got different values. This is currently unsupported: weights=" << weightScales.size()
216*89c4ff92SAndroid Build Coastguard Worker << ", biases=" << biasScales.size();
217*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0ul; i < biasScales.size(); ++i)
221*89c4ff92SAndroid Build Coastguard Worker {
222*89c4ff92SAndroid Build Coastguard Worker const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
223*89c4ff92SAndroid Build Coastguard Worker VerifyBiasQuantizationScale(biasScales[i], expectedScale);
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker else
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker // Validate per-tensor quantization scale
229*89c4ff92SAndroid Build Coastguard Worker const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
230*89c4ff92SAndroid Build Coastguard Worker VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
231*89c4ff92SAndroid Build Coastguard Worker }
232*89c4ff92SAndroid Build Coastguard Worker }
233*89c4ff92SAndroid Build Coastguard Worker
234*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensors(const std::vector<ITensorHandle * > & vec,unsigned int numExpected,const std::string & descName,const std::string & varName)235*89c4ff92SAndroid Build Coastguard Worker void ValidateTensors(const std::vector<ITensorHandle*>& vec,
236*89c4ff92SAndroid Build Coastguard Worker unsigned int numExpected,
237*89c4ff92SAndroid Build Coastguard Worker const std::string& descName,
238*89c4ff92SAndroid Build Coastguard Worker const std::string& varName)
239*89c4ff92SAndroid Build Coastguard Worker {
240*89c4ff92SAndroid Build Coastguard Worker if (vec.empty() && numExpected > 0)
241*89c4ff92SAndroid Build Coastguard Worker {
242*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
243*89c4ff92SAndroid Build Coastguard Worker }
244*89c4ff92SAndroid Build Coastguard Worker
245*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numExpected; ++i)
246*89c4ff92SAndroid Build Coastguard Worker {
247*89c4ff92SAndroid Build Coastguard Worker if (!vec[i])
248*89c4ff92SAndroid Build Coastguard Worker {
249*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
250*89c4ff92SAndroid Build Coastguard Worker }
251*89c4ff92SAndroid Build Coastguard Worker }
252*89c4ff92SAndroid Build Coastguard Worker }
253*89c4ff92SAndroid Build Coastguard Worker
254*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateBroadcastTensorShapesMatch(const TensorInfo & first,const TensorInfo & second,const TensorInfo & output,std::string const & descName,std::string const & firstName,std::string const & secondName)255*89c4ff92SAndroid Build Coastguard Worker void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
256*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& second,
257*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
258*89c4ff92SAndroid Build Coastguard Worker std::string const& descName,
259*89c4ff92SAndroid Build Coastguard Worker std::string const& firstName,
260*89c4ff92SAndroid Build Coastguard Worker std::string const& secondName)
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
263*89c4ff92SAndroid Build Coastguard Worker // broadcasted.
264*89c4ff92SAndroid Build Coastguard Worker if (first.GetNumDimensions() != second.GetNumDimensions())
265*89c4ff92SAndroid Build Coastguard Worker {
266*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Tensors "
267*89c4ff92SAndroid Build Coastguard Worker + firstName + " & " + secondName
268*89c4ff92SAndroid Build Coastguard Worker + " must have the same number of dimensions in order to be broadcasted");
269*89c4ff92SAndroid Build Coastguard Worker }
270*89c4ff92SAndroid Build Coastguard Worker uint32_t numDims = first.GetNumDimensions();
271*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> outputDims(numDims, 0u);
272*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < numDims; i++)
273*89c4ff92SAndroid Build Coastguard Worker {
274*89c4ff92SAndroid Build Coastguard Worker const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
275*89c4ff92SAndroid Build Coastguard Worker const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
276*89c4ff92SAndroid Build Coastguard Worker if (dimsNotEqual && dimsNotOne)
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
281*89c4ff92SAndroid Build Coastguard Worker }
282*89c4ff92SAndroid Build Coastguard Worker TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
283*89c4ff92SAndroid Build Coastguard Worker if (broadcastShape != output.GetShape())
284*89c4ff92SAndroid Build Coastguard Worker {
285*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
286*89c4ff92SAndroid Build Coastguard Worker + firstName + " & " + secondName
287*89c4ff92SAndroid Build Coastguard Worker + " does not match the output shape");
288*89c4ff92SAndroid Build Coastguard Worker }
289*89c4ff92SAndroid Build Coastguard Worker }
290*89c4ff92SAndroid Build Coastguard Worker
291*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateDataTypes(const TensorInfo & info,const std::vector<armnn::DataType> & supportedTypes,std::string const & descName)292*89c4ff92SAndroid Build Coastguard Worker void ValidateDataTypes(const TensorInfo& info,
293*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::DataType>& supportedTypes,
294*89c4ff92SAndroid Build Coastguard Worker std::string const& descName)
295*89c4ff92SAndroid Build Coastguard Worker {
296*89c4ff92SAndroid Build Coastguard Worker auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
297*89c4ff92SAndroid Build Coastguard Worker if (iterator == supportedTypes.end())
298*89c4ff92SAndroid Build Coastguard Worker {
299*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
300*89c4ff92SAndroid Build Coastguard Worker }
301*89c4ff92SAndroid Build Coastguard Worker }
302*89c4ff92SAndroid Build Coastguard Worker
303*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorDataTypesMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)304*89c4ff92SAndroid Build Coastguard Worker void ValidateTensorDataTypesMatch(const TensorInfo& first,
305*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& second,
306*89c4ff92SAndroid Build Coastguard Worker std::string const& descName,
307*89c4ff92SAndroid Build Coastguard Worker std::string const& firstName,
308*89c4ff92SAndroid Build Coastguard Worker std::string const& secondName)
309*89c4ff92SAndroid Build Coastguard Worker {
310*89c4ff92SAndroid Build Coastguard Worker if (first.GetDataType() != second.GetDataType())
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
313*89c4ff92SAndroid Build Coastguard Worker " must have identical data types.");
314*89c4ff92SAndroid Build Coastguard Worker }
315*89c4ff92SAndroid Build Coastguard Worker }
316*89c4ff92SAndroid Build Coastguard Worker
317*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorNumElementsMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)318*89c4ff92SAndroid Build Coastguard Worker void ValidateTensorNumElementsMatch(const TensorInfo& first,
319*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& second,
320*89c4ff92SAndroid Build Coastguard Worker std::string const& descName,
321*89c4ff92SAndroid Build Coastguard Worker std::string const& firstName,
322*89c4ff92SAndroid Build Coastguard Worker std::string const& secondName)
323*89c4ff92SAndroid Build Coastguard Worker {
324*89c4ff92SAndroid Build Coastguard Worker if (first.GetNumElements() != second.GetNumElements())
325*89c4ff92SAndroid Build Coastguard Worker {
326*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
327*89c4ff92SAndroid Build Coastguard Worker " must have the same number of elements.");
328*89c4ff92SAndroid Build Coastguard Worker }
329*89c4ff92SAndroid Build Coastguard Worker }
330*89c4ff92SAndroid Build Coastguard Worker
ValidateWeightDataType(const TensorInfo & inputInfo,const TensorInfo & weightInfo,const std::string & descName)331*89c4ff92SAndroid Build Coastguard Worker void ValidateWeightDataType(const TensorInfo& inputInfo,
332*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightInfo,
333*89c4ff92SAndroid Build Coastguard Worker const std::string& descName)
334*89c4ff92SAndroid Build Coastguard Worker {
335*89c4ff92SAndroid Build Coastguard Worker const DataType inputType = inputInfo.GetDataType();
336*89c4ff92SAndroid Build Coastguard Worker if (IsQuantized8BitType(inputType))
337*89c4ff92SAndroid Build Coastguard Worker {
338*89c4ff92SAndroid Build Coastguard Worker const std::vector<DataType> validTypes =
339*89c4ff92SAndroid Build Coastguard Worker {
340*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
341*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
342*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
343*89c4ff92SAndroid Build Coastguard Worker };
344*89c4ff92SAndroid Build Coastguard Worker
345*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(weightInfo, validTypes, descName);
346*89c4ff92SAndroid Build Coastguard Worker }
347*89c4ff92SAndroid Build Coastguard Worker else
348*89c4ff92SAndroid Build Coastguard Worker {
349*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
350*89c4ff92SAndroid Build Coastguard Worker }
351*89c4ff92SAndroid Build Coastguard Worker }
352*89c4ff92SAndroid Build Coastguard Worker
ValidatePerAxisQuantizationDimension(const TensorInfo & tensorInfo,const std::string & descName,const std::string & tensorName)353*89c4ff92SAndroid Build Coastguard Worker void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
354*89c4ff92SAndroid Build Coastguard Worker const std::string& descName,
355*89c4ff92SAndroid Build Coastguard Worker const std::string& tensorName)
356*89c4ff92SAndroid Build Coastguard Worker {
357*89c4ff92SAndroid Build Coastguard Worker const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
358*89c4ff92SAndroid Build Coastguard Worker if (!quantizationDim.has_value())
359*89c4ff92SAndroid Build Coastguard Worker {
360*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
361*89c4ff92SAndroid Build Coastguard Worker "not set on tensor {1}.", descName, tensorName));
362*89c4ff92SAndroid Build Coastguard Worker }
363*89c4ff92SAndroid Build Coastguard Worker }
364*89c4ff92SAndroid Build Coastguard Worker
ValidatePerAxisQuantizationOffset(const TensorInfo & tensorInfo,const std::string & descName,const std::string & tensorName)365*89c4ff92SAndroid Build Coastguard Worker void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
366*89c4ff92SAndroid Build Coastguard Worker const std::string& descName,
367*89c4ff92SAndroid Build Coastguard Worker const std::string& tensorName)
368*89c4ff92SAndroid Build Coastguard Worker {
369*89c4ff92SAndroid Build Coastguard Worker int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
370*89c4ff92SAndroid Build Coastguard Worker if (quantizationOffset != 0)
371*89c4ff92SAndroid Build Coastguard Worker {
372*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
373*89c4ff92SAndroid Build Coastguard Worker "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
374*89c4ff92SAndroid Build Coastguard Worker descName, tensorName, quantizationOffset));
375*89c4ff92SAndroid Build Coastguard Worker }
376*89c4ff92SAndroid Build Coastguard Worker }
377*89c4ff92SAndroid Build Coastguard Worker
ValidatePerAxisQuantization(const TensorInfo & inputInfo,const TensorInfo & outputInfo,const TensorInfo & weightInfo,const Optional<TensorInfo> & optionalBiasInfo,const std::string & descName)378*89c4ff92SAndroid Build Coastguard Worker void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
379*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo,
380*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightInfo,
381*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& optionalBiasInfo,
382*89c4ff92SAndroid Build Coastguard Worker const std::string& descName)
383*89c4ff92SAndroid Build Coastguard Worker {
384*89c4ff92SAndroid Build Coastguard Worker if (weightInfo.HasPerAxisQuantization())
385*89c4ff92SAndroid Build Coastguard Worker {
386*89c4ff92SAndroid Build Coastguard Worker const DataType inputDataType = inputInfo.GetDataType();
387*89c4ff92SAndroid Build Coastguard Worker const DataType outputDataType = outputInfo.GetDataType();
388*89c4ff92SAndroid Build Coastguard Worker
389*89c4ff92SAndroid Build Coastguard Worker const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
390*89c4ff92SAndroid Build Coastguard Worker
391*89c4ff92SAndroid Build Coastguard Worker if (!canHavePerAxisQuantization)
392*89c4ff92SAndroid Build Coastguard Worker {
393*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
394*89c4ff92SAndroid Build Coastguard Worker "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
395*89c4ff92SAndroid Build Coastguard Worker "per-axis quantization.", descName, "weight"));
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker
398*89c4ff92SAndroid Build Coastguard Worker
399*89c4ff92SAndroid Build Coastguard Worker ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
400*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
401*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
402*89c4ff92SAndroid Build Coastguard Worker
403*89c4ff92SAndroid Build Coastguard Worker if (optionalBiasInfo.has_value())
404*89c4ff92SAndroid Build Coastguard Worker {
405*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biasInfo = optionalBiasInfo.value();
406*89c4ff92SAndroid Build Coastguard Worker if (!biasInfo.HasPerAxisQuantization())
407*89c4ff92SAndroid Build Coastguard Worker {
408*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
409*89c4ff92SAndroid Build Coastguard Worker "{}: Per-axis quantization parameters not set on bias tensor, "
410*89c4ff92SAndroid Build Coastguard Worker "despite being set on weight tensor.", descName));
411*89c4ff92SAndroid Build Coastguard Worker }
412*89c4ff92SAndroid Build Coastguard Worker
413*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
414*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
415*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
416*89c4ff92SAndroid Build Coastguard Worker }
417*89c4ff92SAndroid Build Coastguard Worker }
418*89c4ff92SAndroid Build Coastguard Worker }
419*89c4ff92SAndroid Build Coastguard Worker
420*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
421*89c4ff92SAndroid Build Coastguard Worker
422*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorNumDimensions(const TensorInfo & tensor,std::string const & descName,unsigned int numDimensions,std::string const & tensorName) const423*89c4ff92SAndroid Build Coastguard Worker void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
424*89c4ff92SAndroid Build Coastguard Worker std::string const& descName,
425*89c4ff92SAndroid Build Coastguard Worker unsigned int numDimensions,
426*89c4ff92SAndroid Build Coastguard Worker std::string const& tensorName) const
427*89c4ff92SAndroid Build Coastguard Worker {
428*89c4ff92SAndroid Build Coastguard Worker // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
429*89c4ff92SAndroid Build Coastguard Worker // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
430*89c4ff92SAndroid Build Coastguard Worker // numDimensions.
431*89c4ff92SAndroid Build Coastguard Worker if (m_AllowExpandedDims)
432*89c4ff92SAndroid Build Coastguard Worker {
433*89c4ff92SAndroid Build Coastguard Worker unsigned int squeezedDims = 0;
434*89c4ff92SAndroid Build Coastguard Worker
435*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
436*89c4ff92SAndroid Build Coastguard Worker {
437*89c4ff92SAndroid Build Coastguard Worker if (tensor.GetShape()[i] != 1)
438*89c4ff92SAndroid Build Coastguard Worker {
439*89c4ff92SAndroid Build Coastguard Worker ++squeezedDims;
440*89c4ff92SAndroid Build Coastguard Worker }
441*89c4ff92SAndroid Build Coastguard Worker }
442*89c4ff92SAndroid Build Coastguard Worker if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
443*89c4ff92SAndroid Build Coastguard Worker {
444*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
445*89c4ff92SAndroid Build Coastguard Worker to_string(tensor.GetNumDimensions()) + " dimensions for " +
446*89c4ff92SAndroid Build Coastguard Worker tensorName + " tensor.");
447*89c4ff92SAndroid Build Coastguard Worker }
448*89c4ff92SAndroid Build Coastguard Worker }
449*89c4ff92SAndroid Build Coastguard Worker else
450*89c4ff92SAndroid Build Coastguard Worker {
451*89c4ff92SAndroid Build Coastguard Worker if (tensor.GetNumDimensions() != numDimensions)
452*89c4ff92SAndroid Build Coastguard Worker {
453*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
454*89c4ff92SAndroid Build Coastguard Worker to_string(tensor.GetNumDimensions()) + " dimensions for " +
455*89c4ff92SAndroid Build Coastguard Worker tensorName + " tensor.");
456*89c4ff92SAndroid Build Coastguard Worker }
457*89c4ff92SAndroid Build Coastguard Worker }
458*89c4ff92SAndroid Build Coastguard Worker }
459*89c4ff92SAndroid Build Coastguard Worker
460*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateTensorNumDimNumElem(const TensorInfo & tensorInfo,unsigned int numDimension,unsigned int numElements,std::string const & tensorName) const461*89c4ff92SAndroid Build Coastguard Worker void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
462*89c4ff92SAndroid Build Coastguard Worker unsigned int numDimension,
463*89c4ff92SAndroid Build Coastguard Worker unsigned int numElements,
464*89c4ff92SAndroid Build Coastguard Worker std::string const& tensorName) const
465*89c4ff92SAndroid Build Coastguard Worker {
466*89c4ff92SAndroid Build Coastguard Worker const std::string functionName{"ValidateTensorNumDimNumElem"};
467*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
468*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
469*89c4ff92SAndroid Build Coastguard Worker }
470*89c4ff92SAndroid Build Coastguard Worker
471*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
ValidateInputsOutputs(const std::string & descName,unsigned int numExpectedIn,unsigned int numExpectedOut) const472*89c4ff92SAndroid Build Coastguard Worker void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
473*89c4ff92SAndroid Build Coastguard Worker unsigned int numExpectedIn, unsigned int numExpectedOut) const
474*89c4ff92SAndroid Build Coastguard Worker {
475*89c4ff92SAndroid Build Coastguard Worker ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
476*89c4ff92SAndroid Build Coastguard Worker ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
477*89c4ff92SAndroid Build Coastguard Worker }
478*89c4ff92SAndroid Build Coastguard Worker
479*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const480*89c4ff92SAndroid Build Coastguard Worker void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
481*89c4ff92SAndroid Build Coastguard Worker {
482*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"MapQueueDescriptor"};
483*89c4ff92SAndroid Build Coastguard Worker
484*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
485*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 0);
486*89c4ff92SAndroid Build Coastguard Worker
487*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_Inputs.size(); ++i)
488*89c4ff92SAndroid Build Coastguard Worker {
489*89c4ff92SAndroid Build Coastguard Worker if (!m_Inputs[i])
490*89c4ff92SAndroid Build Coastguard Worker {
491*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
492*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
493*89c4ff92SAndroid Build Coastguard Worker }
494*89c4ff92SAndroid Build Coastguard Worker }
495*89c4ff92SAndroid Build Coastguard Worker }
496*89c4ff92SAndroid Build Coastguard Worker
497*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const498*89c4ff92SAndroid Build Coastguard Worker void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
499*89c4ff92SAndroid Build Coastguard Worker {
500*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"UnmapQueueDescriptor"};
501*89c4ff92SAndroid Build Coastguard Worker
502*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
503*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 0);
504*89c4ff92SAndroid Build Coastguard Worker
505*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_Inputs.size(); ++i)
506*89c4ff92SAndroid Build Coastguard Worker {
507*89c4ff92SAndroid Build Coastguard Worker if (!m_Inputs[i])
508*89c4ff92SAndroid Build Coastguard Worker {
509*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
510*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
511*89c4ff92SAndroid Build Coastguard Worker }
512*89c4ff92SAndroid Build Coastguard Worker }
513*89c4ff92SAndroid Build Coastguard Worker }
514*89c4ff92SAndroid Build Coastguard Worker
515*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const516*89c4ff92SAndroid Build Coastguard Worker void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
517*89c4ff92SAndroid Build Coastguard Worker {
518*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"MemCopyQueueDescriptor"};
519*89c4ff92SAndroid Build Coastguard Worker
520*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
521*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName , 1);
522*89c4ff92SAndroid Build Coastguard Worker
523*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
524*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
525*89c4ff92SAndroid Build Coastguard Worker
526*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
527*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
528*89c4ff92SAndroid Build Coastguard Worker
529*89c4ff92SAndroid Build Coastguard Worker if (m_Inputs.size() != m_Outputs.size())
530*89c4ff92SAndroid Build Coastguard Worker {
531*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
532*89c4ff92SAndroid Build Coastguard Worker "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
533*89c4ff92SAndroid Build Coastguard Worker descriptorName, m_Inputs.size(), m_Outputs.size()));
534*89c4ff92SAndroid Build Coastguard Worker }
535*89c4ff92SAndroid Build Coastguard Worker
536*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_Inputs.size(); ++i)
537*89c4ff92SAndroid Build Coastguard Worker {
538*89c4ff92SAndroid Build Coastguard Worker if (!m_Inputs[i])
539*89c4ff92SAndroid Build Coastguard Worker {
540*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
541*89c4ff92SAndroid Build Coastguard Worker "{0}: Invalid NULL input {1}.", descriptorName, i));
542*89c4ff92SAndroid Build Coastguard Worker }
543*89c4ff92SAndroid Build Coastguard Worker
544*89c4ff92SAndroid Build Coastguard Worker if (!m_Outputs[i])
545*89c4ff92SAndroid Build Coastguard Worker {
546*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
547*89c4ff92SAndroid Build Coastguard Worker }
548*89c4ff92SAndroid Build Coastguard Worker }
549*89c4ff92SAndroid Build Coastguard Worker }
550*89c4ff92SAndroid Build Coastguard Worker
551*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const552*89c4ff92SAndroid Build Coastguard Worker void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
553*89c4ff92SAndroid Build Coastguard Worker {
554*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
555*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
556*89c4ff92SAndroid Build Coastguard Worker
557*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() != 1)
558*89c4ff92SAndroid Build Coastguard Worker {
559*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
560*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_InputTensorInfos.size()));
561*89c4ff92SAndroid Build Coastguard Worker
562*89c4ff92SAndroid Build Coastguard Worker }
563*89c4ff92SAndroid Build Coastguard Worker
564*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
565*89c4ff92SAndroid Build Coastguard Worker {
566*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
567*89c4ff92SAndroid Build Coastguard Worker "Number of input infos ({0}) does not match the number of output infos ({1})",
568*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
569*89c4ff92SAndroid Build Coastguard Worker }
570*89c4ff92SAndroid Build Coastguard Worker
571*89c4ff92SAndroid Build Coastguard Worker for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
572*89c4ff92SAndroid Build Coastguard Worker {
573*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
574*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_OutputTensorInfos[i].GetNumElements())
575*89c4ff92SAndroid Build Coastguard Worker {
576*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
577*89c4ff92SAndroid Build Coastguard Worker "Number of elements for tensor input and output {} does not match", i ));
578*89c4ff92SAndroid Build Coastguard Worker }
579*89c4ff92SAndroid Build Coastguard Worker }
580*89c4ff92SAndroid Build Coastguard Worker
581*89c4ff92SAndroid Build Coastguard Worker if (m_Inputs.size() != 1)
582*89c4ff92SAndroid Build Coastguard Worker {
583*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
584*89c4ff92SAndroid Build Coastguard Worker }
585*89c4ff92SAndroid Build Coastguard Worker
586*89c4ff92SAndroid Build Coastguard Worker if (m_Inputs.size() != m_Outputs.size())
587*89c4ff92SAndroid Build Coastguard Worker {
588*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
589*89c4ff92SAndroid Build Coastguard Worker "Number of inputs ({0}) does not match the number of outputs ({1})",
590*89c4ff92SAndroid Build Coastguard Worker m_Inputs.size(), m_Outputs.size()));
591*89c4ff92SAndroid Build Coastguard Worker }
592*89c4ff92SAndroid Build Coastguard Worker
593*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_Inputs.size(); ++i)
594*89c4ff92SAndroid Build Coastguard Worker {
595*89c4ff92SAndroid Build Coastguard Worker if (!m_Inputs[i])
596*89c4ff92SAndroid Build Coastguard Worker {
597*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
598*89c4ff92SAndroid Build Coastguard Worker }
599*89c4ff92SAndroid Build Coastguard Worker
600*89c4ff92SAndroid Build Coastguard Worker if (!m_Outputs[i])
601*89c4ff92SAndroid Build Coastguard Worker {
602*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
603*89c4ff92SAndroid Build Coastguard Worker }
604*89c4ff92SAndroid Build Coastguard Worker }
605*89c4ff92SAndroid Build Coastguard Worker }
606*89c4ff92SAndroid Build Coastguard Worker
607*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const608*89c4ff92SAndroid Build Coastguard Worker void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
609*89c4ff92SAndroid Build Coastguard Worker {
610*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
611*89c4ff92SAndroid Build Coastguard Worker
612*89c4ff92SAndroid Build Coastguard Worker if (m_Inputs.size() != 1)
613*89c4ff92SAndroid Build Coastguard Worker {
614*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
615*89c4ff92SAndroid Build Coastguard Worker }
616*89c4ff92SAndroid Build Coastguard Worker
617*89c4ff92SAndroid Build Coastguard Worker if (m_Outputs.size() != 0)
618*89c4ff92SAndroid Build Coastguard Worker {
619*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
620*89c4ff92SAndroid Build Coastguard Worker }
621*89c4ff92SAndroid Build Coastguard Worker
622*89c4ff92SAndroid Build Coastguard Worker if (!m_Inputs[0])
623*89c4ff92SAndroid Build Coastguard Worker {
624*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Invalid null input 0"));
625*89c4ff92SAndroid Build Coastguard Worker }
626*89c4ff92SAndroid Build Coastguard Worker }
627*89c4ff92SAndroid Build Coastguard Worker
628*89c4ff92SAndroid Build Coastguard Worker //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const629*89c4ff92SAndroid Build Coastguard Worker void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
630*89c4ff92SAndroid Build Coastguard Worker {
631*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ActivationQueueDescriptor"};
632*89c4ff92SAndroid Build Coastguard Worker
633*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
634*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
635*89c4ff92SAndroid Build Coastguard Worker
636*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
637*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
638*89c4ff92SAndroid Build Coastguard Worker
639*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
640*89c4ff92SAndroid Build Coastguard Worker {
641*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
642*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
643*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
644*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
645*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
646*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
647*89c4ff92SAndroid Build Coastguard Worker };
648*89c4ff92SAndroid Build Coastguard Worker
649*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
650*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
651*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
652*89c4ff92SAndroid Build Coastguard Worker }
653*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const654*89c4ff92SAndroid Build Coastguard Worker void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
655*89c4ff92SAndroid Build Coastguard Worker {
656*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
657*89c4ff92SAndroid Build Coastguard Worker
658*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
659*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
660*89c4ff92SAndroid Build Coastguard Worker
661*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
662*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
663*89c4ff92SAndroid Build Coastguard Worker
664*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
665*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.GetDataType() != DataType::Signed64)
666*89c4ff92SAndroid Build Coastguard Worker {
667*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
668*89c4ff92SAndroid Build Coastguard Worker }
669*89c4ff92SAndroid Build Coastguard Worker
670*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedInputTypes =
671*89c4ff92SAndroid Build Coastguard Worker {
672*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
673*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
674*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
675*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
676*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
677*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
678*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
679*89c4ff92SAndroid Build Coastguard Worker DataType::Signed64
680*89c4ff92SAndroid Build Coastguard Worker };
681*89c4ff92SAndroid Build Coastguard Worker
682*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
683*89c4ff92SAndroid Build Coastguard Worker
684*89c4ff92SAndroid Build Coastguard Worker auto inputShape = inputTensorInfo.GetShape();
685*89c4ff92SAndroid Build Coastguard Worker auto outputShape = outputTensorInfo.GetShape();
686*89c4ff92SAndroid Build Coastguard Worker
687*89c4ff92SAndroid Build Coastguard Worker auto inputNumDimensions = inputShape.GetNumDimensions();
688*89c4ff92SAndroid Build Coastguard Worker auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
689*89c4ff92SAndroid Build Coastguard Worker
690*89c4ff92SAndroid Build Coastguard Worker const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
691*89c4ff92SAndroid Build Coastguard Worker
692*89c4ff92SAndroid Build Coastguard Worker // 1D input shape results in scalar output shape
693*89c4ff92SAndroid Build Coastguard Worker if (inputShape.GetNumDimensions() == 1)
694*89c4ff92SAndroid Build Coastguard Worker {
695*89c4ff92SAndroid Build Coastguard Worker if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
696*89c4ff92SAndroid Build Coastguard Worker {
697*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + outputShapeError);
698*89c4ff92SAndroid Build Coastguard Worker }
699*89c4ff92SAndroid Build Coastguard Worker }
700*89c4ff92SAndroid Build Coastguard Worker else
701*89c4ff92SAndroid Build Coastguard Worker {
702*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < unsignedAxis; ++i)
703*89c4ff92SAndroid Build Coastguard Worker {
704*89c4ff92SAndroid Build Coastguard Worker if (outputShape[i] != inputShape[i])
705*89c4ff92SAndroid Build Coastguard Worker {
706*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + outputShapeError);
707*89c4ff92SAndroid Build Coastguard Worker }
708*89c4ff92SAndroid Build Coastguard Worker }
709*89c4ff92SAndroid Build Coastguard Worker
710*89c4ff92SAndroid Build Coastguard Worker for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
711*89c4ff92SAndroid Build Coastguard Worker {
712*89c4ff92SAndroid Build Coastguard Worker if (outputShape[i - 1] != inputShape[i])
713*89c4ff92SAndroid Build Coastguard Worker {
714*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + outputShapeError);
715*89c4ff92SAndroid Build Coastguard Worker }
716*89c4ff92SAndroid Build Coastguard Worker }
717*89c4ff92SAndroid Build Coastguard Worker }
718*89c4ff92SAndroid Build Coastguard Worker }
719*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const720*89c4ff92SAndroid Build Coastguard Worker void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
721*89c4ff92SAndroid Build Coastguard Worker {
722*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"CastQueueDescriptor"};
723*89c4ff92SAndroid Build Coastguard Worker
724*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
725*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
726*89c4ff92SAndroid Build Coastguard Worker
727*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
728*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
729*89c4ff92SAndroid Build Coastguard Worker
730*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
731*89c4ff92SAndroid Build Coastguard Worker {
732*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
733*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
734*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
735*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
736*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
737*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
738*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
739*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
740*89c4ff92SAndroid Build Coastguard Worker DataType::Signed64
741*89c4ff92SAndroid Build Coastguard Worker };
742*89c4ff92SAndroid Build Coastguard Worker
743*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
744*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745*89c4ff92SAndroid Build Coastguard Worker }
746*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const747*89c4ff92SAndroid Build Coastguard Worker void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
748*89c4ff92SAndroid Build Coastguard Worker {
749*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"SoftmaxQueueDescriptor"};
750*89c4ff92SAndroid Build Coastguard Worker
751*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
752*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
753*89c4ff92SAndroid Build Coastguard Worker
754*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
755*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
756*89c4ff92SAndroid Build Coastguard Worker
757*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
758*89c4ff92SAndroid Build Coastguard Worker {
759*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
760*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
761*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
762*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
763*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
764*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
765*89c4ff92SAndroid Build Coastguard Worker };
766*89c4ff92SAndroid Build Coastguard Worker
767*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
768*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
769*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
770*89c4ff92SAndroid Build Coastguard Worker }
771*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const772*89c4ff92SAndroid Build Coastguard Worker void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
773*89c4ff92SAndroid Build Coastguard Worker {
774*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"SplitterQueueDescriptor"};
775*89c4ff92SAndroid Build Coastguard Worker
776*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
777*89c4ff92SAndroid Build Coastguard Worker
778*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
779*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
780*89c4ff92SAndroid Build Coastguard Worker {
781*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
782*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
783*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
784*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean,
785*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
786*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
787*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
788*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
789*89c4ff92SAndroid Build Coastguard Worker };
790*89c4ff92SAndroid Build Coastguard Worker
791*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
792*89c4ff92SAndroid Build Coastguard Worker for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
793*89c4ff92SAndroid Build Coastguard Worker {
794*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
795*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
796*89c4ff92SAndroid Build Coastguard Worker
797*89c4ff92SAndroid Build Coastguard Worker const std::string outputName = "output_" + std::to_string(i);
798*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
799*89c4ff92SAndroid Build Coastguard Worker }
800*89c4ff92SAndroid Build Coastguard Worker
801*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() <= 0)
802*89c4ff92SAndroid Build Coastguard Worker {
803*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
804*89c4ff92SAndroid Build Coastguard Worker }
805*89c4ff92SAndroid Build Coastguard Worker
806*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
807*89c4ff92SAndroid Build Coastguard Worker {
808*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
809*89c4ff92SAndroid Build Coastguard Worker descriptorName + ": Number of split windows "
810*89c4ff92SAndroid Build Coastguard Worker "has to match number of workloadInfo.m_OutputTensorInfos. "
811*89c4ff92SAndroid Build Coastguard Worker "Number of windows: " +
812*89c4ff92SAndroid Build Coastguard Worker to_string(m_ViewOrigins.size()) +
813*89c4ff92SAndroid Build Coastguard Worker ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
814*89c4ff92SAndroid Build Coastguard Worker }
815*89c4ff92SAndroid Build Coastguard Worker
816*89c4ff92SAndroid Build Coastguard Worker //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
817*89c4ff92SAndroid Build Coastguard Worker std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
818*89c4ff92SAndroid Build Coastguard Worker for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
819*89c4ff92SAndroid Build Coastguard Worker {
820*89c4ff92SAndroid Build Coastguard Worker //Checks that the dimensionality of input is same as the split windows.
821*89c4ff92SAndroid Build Coastguard Worker ViewOrigin const& e = m_ViewOrigins[w];
822*89c4ff92SAndroid Build Coastguard Worker if (e.m_Origin.size() != inputDims)
823*89c4ff92SAndroid Build Coastguard Worker {
824*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Window origin have to "
825*89c4ff92SAndroid Build Coastguard Worker "have the same dimensionality as the input tensor. "
826*89c4ff92SAndroid Build Coastguard Worker "Window origin (index: " +
827*89c4ff92SAndroid Build Coastguard Worker to_string(w) + ") has " + to_string(e.m_Origin.size()) +
828*89c4ff92SAndroid Build Coastguard Worker " dimensions, the input "
829*89c4ff92SAndroid Build Coastguard Worker "tensor has " +
830*89c4ff92SAndroid Build Coastguard Worker to_string(inputDims) + " dimensions.");
831*89c4ff92SAndroid Build Coastguard Worker }
832*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
833*89c4ff92SAndroid Build Coastguard Worker {
834*89c4ff92SAndroid Build Coastguard Worker if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
835*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_InputTensorInfos[0].GetShape()[i])
836*89c4ff92SAndroid Build Coastguard Worker {
837*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
838*89c4ff92SAndroid Build Coastguard Worker "be smaller or equal than the size of the input in that coord.");
839*89c4ff92SAndroid Build Coastguard Worker }
840*89c4ff92SAndroid Build Coastguard Worker }
841*89c4ff92SAndroid Build Coastguard Worker }
842*89c4ff92SAndroid Build Coastguard Worker }
843*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const844*89c4ff92SAndroid Build Coastguard Worker void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
845*89c4ff92SAndroid Build Coastguard Worker {
846*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ConcatQueueDescriptor"};
847*89c4ff92SAndroid Build Coastguard Worker
848*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
849*89c4ff92SAndroid Build Coastguard Worker
850*89c4ff92SAndroid Build Coastguard Worker if (m_Inputs.size() <= 0)
851*89c4ff92SAndroid Build Coastguard Worker {
852*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
853*89c4ff92SAndroid Build Coastguard Worker }
854*89c4ff92SAndroid Build Coastguard Worker if (m_Outputs.size() <= 0)
855*89c4ff92SAndroid Build Coastguard Worker {
856*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
857*89c4ff92SAndroid Build Coastguard Worker }
858*89c4ff92SAndroid Build Coastguard Worker
859*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() <= 0)
860*89c4ff92SAndroid Build Coastguard Worker {
861*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
862*89c4ff92SAndroid Build Coastguard Worker }
863*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() <= 0)
864*89c4ff92SAndroid Build Coastguard Worker {
865*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
866*89c4ff92SAndroid Build Coastguard Worker }
867*89c4ff92SAndroid Build Coastguard Worker
868*89c4ff92SAndroid Build Coastguard Worker if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
869*89c4ff92SAndroid Build Coastguard Worker {
870*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
871*89c4ff92SAndroid Build Coastguard Worker }
872*89c4ff92SAndroid Build Coastguard Worker
873*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
874*89c4ff92SAndroid Build Coastguard Worker {
875*89c4ff92SAndroid Build Coastguard Worker return;
876*89c4ff92SAndroid Build Coastguard Worker }
877*89c4ff92SAndroid Build Coastguard Worker
878*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
879*89c4ff92SAndroid Build Coastguard Worker {
880*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
881*89c4ff92SAndroid Build Coastguard Worker descriptorName + ": Number of split windows "
882*89c4ff92SAndroid Build Coastguard Worker "has to match number of workloadInfo.m_InputTensorInfos. "
883*89c4ff92SAndroid Build Coastguard Worker "Number of windows: " +
884*89c4ff92SAndroid Build Coastguard Worker to_string(m_ViewOrigins.size()) +
885*89c4ff92SAndroid Build Coastguard Worker ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
886*89c4ff92SAndroid Build Coastguard Worker }
887*89c4ff92SAndroid Build Coastguard Worker
888*89c4ff92SAndroid Build Coastguard Worker //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
889*89c4ff92SAndroid Build Coastguard Worker std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
890*89c4ff92SAndroid Build Coastguard Worker for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
891*89c4ff92SAndroid Build Coastguard Worker {
892*89c4ff92SAndroid Build Coastguard Worker //Checks that the dimensionality of output is same as the split windows.
893*89c4ff92SAndroid Build Coastguard Worker ViewOrigin const& e = m_ViewOrigins[w];
894*89c4ff92SAndroid Build Coastguard Worker if (e.m_Origin.size() != outputDims)
895*89c4ff92SAndroid Build Coastguard Worker {
896*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Window origin have to "
897*89c4ff92SAndroid Build Coastguard Worker "have the same dimensionality as the output tensor. "
898*89c4ff92SAndroid Build Coastguard Worker "Window origin (index: " +
899*89c4ff92SAndroid Build Coastguard Worker to_string(w) + ") has " + to_string(e.m_Origin.size()) +
900*89c4ff92SAndroid Build Coastguard Worker " dimensions, the output "
901*89c4ff92SAndroid Build Coastguard Worker "tensor has " +
902*89c4ff92SAndroid Build Coastguard Worker to_string(outputDims) + " dimensions.");
903*89c4ff92SAndroid Build Coastguard Worker }
904*89c4ff92SAndroid Build Coastguard Worker //Checks that the merge windows are within the output tensor.
905*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
906*89c4ff92SAndroid Build Coastguard Worker {
907*89c4ff92SAndroid Build Coastguard Worker if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
908*89c4ff92SAndroid Build Coastguard Worker > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
909*89c4ff92SAndroid Build Coastguard Worker {
910*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
911*89c4ff92SAndroid Build Coastguard Worker "be smaller or equal than the size of the output in that coord.");
912*89c4ff92SAndroid Build Coastguard Worker }
913*89c4ff92SAndroid Build Coastguard Worker }
914*89c4ff92SAndroid Build Coastguard Worker }
915*89c4ff92SAndroid Build Coastguard Worker
916*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
917*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
918*89c4ff92SAndroid Build Coastguard Worker {
919*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
920*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
921*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
922*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean,
923*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
924*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
925*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
926*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
927*89c4ff92SAndroid Build Coastguard Worker };
928*89c4ff92SAndroid Build Coastguard Worker
929*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
930*89c4ff92SAndroid Build Coastguard Worker for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
931*89c4ff92SAndroid Build Coastguard Worker {
932*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
933*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
934*89c4ff92SAndroid Build Coastguard Worker
935*89c4ff92SAndroid Build Coastguard Worker const std::string inputName = "input_" + std::to_string(i);
936*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
937*89c4ff92SAndroid Build Coastguard Worker }
938*89c4ff92SAndroid Build Coastguard Worker }
939*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const940*89c4ff92SAndroid Build Coastguard Worker void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
941*89c4ff92SAndroid Build Coastguard Worker {
942*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"StackQueueDescriptor"};
943*89c4ff92SAndroid Build Coastguard Worker
944*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
945*89c4ff92SAndroid Build Coastguard Worker
946*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
947*89c4ff92SAndroid Build Coastguard Worker {
948*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
949*89c4ff92SAndroid Build Coastguard Worker }
950*89c4ff92SAndroid Build Coastguard Worker
951*89c4ff92SAndroid Build Coastguard Worker // All inputs must have the same shape, which is defined in parameters
952*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = m_Parameters.m_InputShape;
953*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
954*89c4ff92SAndroid Build Coastguard Worker {
955*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
956*89c4ff92SAndroid Build Coastguard Worker {
957*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
958*89c4ff92SAndroid Build Coastguard Worker }
959*89c4ff92SAndroid Build Coastguard Worker }
960*89c4ff92SAndroid Build Coastguard Worker
961*89c4ff92SAndroid Build Coastguard Worker if (inputShape.GetNumDimensions() > 4)
962*89c4ff92SAndroid Build Coastguard Worker {
963*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
964*89c4ff92SAndroid Build Coastguard Worker }
965*89c4ff92SAndroid Build Coastguard Worker
966*89c4ff92SAndroid Build Coastguard Worker // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
967*89c4ff92SAndroid Build Coastguard Worker // since the output tensor has an additional dimension.
968*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
969*89c4ff92SAndroid Build Coastguard Worker {
970*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
971*89c4ff92SAndroid Build Coastguard Worker "than the number of input dimensions.");
972*89c4ff92SAndroid Build Coastguard Worker }
973*89c4ff92SAndroid Build Coastguard Worker
974*89c4ff92SAndroid Build Coastguard Worker // Output shape must be as inferred from the input shape
975*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
976*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
977*89c4ff92SAndroid Build Coastguard Worker {
978*89c4ff92SAndroid Build Coastguard Worker if (outputShape[i] != inputShape[i])
979*89c4ff92SAndroid Build Coastguard Worker {
980*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor must "
981*89c4ff92SAndroid Build Coastguard Worker "match shape inferred from input tensor.");
982*89c4ff92SAndroid Build Coastguard Worker }
983*89c4ff92SAndroid Build Coastguard Worker }
984*89c4ff92SAndroid Build Coastguard Worker
985*89c4ff92SAndroid Build Coastguard Worker if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
986*89c4ff92SAndroid Build Coastguard Worker {
987*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor must "
988*89c4ff92SAndroid Build Coastguard Worker "match shape inferred from input tensor.");
989*89c4ff92SAndroid Build Coastguard Worker }
990*89c4ff92SAndroid Build Coastguard Worker
991*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
992*89c4ff92SAndroid Build Coastguard Worker {
993*89c4ff92SAndroid Build Coastguard Worker if (outputShape[i] != inputShape[i-1])
994*89c4ff92SAndroid Build Coastguard Worker {
995*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor must "
996*89c4ff92SAndroid Build Coastguard Worker "match shape inferred from input tensor.");
997*89c4ff92SAndroid Build Coastguard Worker }
998*89c4ff92SAndroid Build Coastguard Worker }
999*89c4ff92SAndroid Build Coastguard Worker
1000*89c4ff92SAndroid Build Coastguard Worker if (outputShape.GetNumDimensions() > 5)
1001*89c4ff92SAndroid Build Coastguard Worker {
1002*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
1003*89c4ff92SAndroid Build Coastguard Worker }
1004*89c4ff92SAndroid Build Coastguard Worker
1005*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1006*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1007*89c4ff92SAndroid Build Coastguard Worker {
1008*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1009*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1010*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1011*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean,
1012*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
1013*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1014*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1015*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1016*89c4ff92SAndroid Build Coastguard Worker };
1017*89c4ff92SAndroid Build Coastguard Worker
1018*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1019*89c4ff92SAndroid Build Coastguard Worker
1020*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
1021*89c4ff92SAndroid Build Coastguard Worker {
1022*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1023*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_InputTensorInfos[i],
1024*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1025*89c4ff92SAndroid Build Coastguard Worker "input_0",
1026*89c4ff92SAndroid Build Coastguard Worker "input_" + std::to_string(i));
1027*89c4ff92SAndroid Build Coastguard Worker }
1028*89c4ff92SAndroid Build Coastguard Worker
1029*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1030*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_OutputTensorInfos[0],
1031*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1032*89c4ff92SAndroid Build Coastguard Worker "input_0",
1033*89c4ff92SAndroid Build Coastguard Worker "output");
1034*89c4ff92SAndroid Build Coastguard Worker }
1035*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1036*89c4ff92SAndroid Build Coastguard Worker void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1037*89c4ff92SAndroid Build Coastguard Worker {
1038*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"FillQueueDescriptor"};
1039*89c4ff92SAndroid Build Coastguard Worker
1040*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1041*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1042*89c4ff92SAndroid Build Coastguard Worker
1043*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1044*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1045*89c4ff92SAndroid Build Coastguard Worker
1046*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1047*89c4ff92SAndroid Build Coastguard Worker
1048*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1049*89c4ff92SAndroid Build Coastguard Worker {
1050*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1051*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1052*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1053*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1054*89c4ff92SAndroid Build Coastguard Worker };
1055*89c4ff92SAndroid Build Coastguard Worker
1056*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1057*89c4ff92SAndroid Build Coastguard Worker }
1058*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1059*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1060*89c4ff92SAndroid Build Coastguard Worker {
1061*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"FullyConnectedQueueDescriptor"};
1062*89c4ff92SAndroid Build Coastguard Worker
1063*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = 2;
1064*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1065*89c4ff92SAndroid Build Coastguard Worker {
1066*89c4ff92SAndroid Build Coastguard Worker numInputs = 3;
1067*89c4ff92SAndroid Build Coastguard Worker }
1068*89c4ff92SAndroid Build Coastguard Worker
1069*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1070*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1071*89c4ff92SAndroid Build Coastguard Worker
1072*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1073*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1074*89c4ff92SAndroid Build Coastguard Worker
1075*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1076*89c4ff92SAndroid Build Coastguard Worker
1077*89c4ff92SAndroid Build Coastguard Worker if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
1078*89c4ff92SAndroid Build Coastguard Worker {
1079*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
1080*89c4ff92SAndroid Build Coastguard Worker }
1081*89c4ff92SAndroid Build Coastguard Worker
1082*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1083*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
1084*89c4ff92SAndroid Build Coastguard Worker
1085*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1086*89c4ff92SAndroid Build Coastguard Worker {
1087*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1088*89c4ff92SAndroid Build Coastguard Worker // Validates type and quantization values.
1089*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1090*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1091*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1092*89c4ff92SAndroid Build Coastguard Worker }
1093*89c4ff92SAndroid Build Coastguard Worker
1094*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1095*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1096*89c4ff92SAndroid Build Coastguard Worker {
1097*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1098*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1099*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1100*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1101*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1102*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1103*89c4ff92SAndroid Build Coastguard Worker };
1104*89c4ff92SAndroid Build Coastguard Worker
1105*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1106*89c4ff92SAndroid Build Coastguard Worker
1107*89c4ff92SAndroid Build Coastguard Worker // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1108*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1109*89c4ff92SAndroid Build Coastguard Worker {
1110*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1111*89c4ff92SAndroid Build Coastguard Worker {
1112*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1113*89c4ff92SAndroid Build Coastguard Worker "for BFloat16 input.");
1114*89c4ff92SAndroid Build Coastguard Worker }
1115*89c4ff92SAndroid Build Coastguard Worker }
1116*89c4ff92SAndroid Build Coastguard Worker else
1117*89c4ff92SAndroid Build Coastguard Worker {
1118*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1119*89c4ff92SAndroid Build Coastguard Worker }
1120*89c4ff92SAndroid Build Coastguard Worker }
1121*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1122*89c4ff92SAndroid Build Coastguard Worker void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1123*89c4ff92SAndroid Build Coastguard Worker {
1124*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"NormalizationQueueDescriptor"};
1125*89c4ff92SAndroid Build Coastguard Worker
1126*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1127*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1128*89c4ff92SAndroid Build Coastguard Worker
1129*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1130*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1131*89c4ff92SAndroid Build Coastguard Worker
1132*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1133*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1134*89c4ff92SAndroid Build Coastguard Worker {
1135*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1136*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1137*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1138*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1139*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1140*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1141*89c4ff92SAndroid Build Coastguard Worker };
1142*89c4ff92SAndroid Build Coastguard Worker
1143*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1144*89c4ff92SAndroid Build Coastguard Worker
1145*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1146*89c4ff92SAndroid Build Coastguard Worker
1147*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1148*89c4ff92SAndroid Build Coastguard Worker }
1149*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1150*89c4ff92SAndroid Build Coastguard Worker void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1151*89c4ff92SAndroid Build Coastguard Worker {
1152*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"AdditionQueueDescriptor"};
1153*89c4ff92SAndroid Build Coastguard Worker
1154*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
1155*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1156*89c4ff92SAndroid Build Coastguard Worker
1157*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1158*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1159*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1160*89c4ff92SAndroid Build Coastguard Worker
1161*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1162*89c4ff92SAndroid Build Coastguard Worker {
1163*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1164*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1165*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1166*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1167*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1168*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1169*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1170*89c4ff92SAndroid Build Coastguard Worker };
1171*89c4ff92SAndroid Build Coastguard Worker
1172*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1173*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1174*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1175*89c4ff92SAndroid Build Coastguard Worker
1176*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1177*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
1178*89c4ff92SAndroid Build Coastguard Worker
1179*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1180*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
1181*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
1182*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1183*89c4ff92SAndroid Build Coastguard Worker "input_0",
1184*89c4ff92SAndroid Build Coastguard Worker "input_1");
1185*89c4ff92SAndroid Build Coastguard Worker }
1186*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1187*89c4ff92SAndroid Build Coastguard Worker void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1188*89c4ff92SAndroid Build Coastguard Worker {
1189*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"MultiplicationQueueDescriptor"};
1190*89c4ff92SAndroid Build Coastguard Worker
1191*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
1192*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1193*89c4ff92SAndroid Build Coastguard Worker
1194*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1195*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1196*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1197*89c4ff92SAndroid Build Coastguard Worker
1198*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1199*89c4ff92SAndroid Build Coastguard Worker {
1200*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1201*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1202*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1203*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1204*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1205*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1206*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1207*89c4ff92SAndroid Build Coastguard Worker };
1208*89c4ff92SAndroid Build Coastguard Worker
1209*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1210*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1211*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1212*89c4ff92SAndroid Build Coastguard Worker
1213*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1214*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
1215*89c4ff92SAndroid Build Coastguard Worker
1216*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1217*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
1218*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
1219*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1220*89c4ff92SAndroid Build Coastguard Worker "input_0",
1221*89c4ff92SAndroid Build Coastguard Worker "input_1");
1222*89c4ff92SAndroid Build Coastguard Worker }
1223*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1224*89c4ff92SAndroid Build Coastguard Worker void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1225*89c4ff92SAndroid Build Coastguard Worker {
1226*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
1227*89c4ff92SAndroid Build Coastguard Worker
1228*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1229*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1230*89c4ff92SAndroid Build Coastguard Worker
1231*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1232*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1233*89c4ff92SAndroid Build Coastguard Worker
1234*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1235*89c4ff92SAndroid Build Coastguard Worker {
1236*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1237*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1238*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1239*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1240*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1241*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1242*89c4ff92SAndroid Build Coastguard Worker };
1243*89c4ff92SAndroid Build Coastguard Worker
1244*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1245*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1246*89c4ff92SAndroid Build Coastguard Worker
1247*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1248*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1249*89c4ff92SAndroid Build Coastguard Worker
1250*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_Mean, descriptorName, "mean");
1251*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_Variance, descriptorName, "variance");
1252*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_Beta, descriptorName, "beta");
1253*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_Gamma, descriptorName, "gamma");
1254*89c4ff92SAndroid Build Coastguard Worker
1255*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& mean = m_Mean->GetTensorInfo();
1256*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& variance = m_Variance->GetTensorInfo();
1257*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& beta = m_Beta->GetTensorInfo();
1258*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& gamma = m_Gamma->GetTensorInfo();
1259*89c4ff92SAndroid Build Coastguard Worker
1260*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1261*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1262*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1263*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
1264*89c4ff92SAndroid Build Coastguard Worker
1265*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1266*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1267*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
1268*89c4ff92SAndroid Build Coastguard Worker }
1269*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1270*89c4ff92SAndroid Build Coastguard Worker void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1271*89c4ff92SAndroid Build Coastguard Worker {
1272*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"Convolution2dQueueDescriptor"};
1273*89c4ff92SAndroid Build Coastguard Worker
1274*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = 2;
1275*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1276*89c4ff92SAndroid Build Coastguard Worker {
1277*89c4ff92SAndroid Build Coastguard Worker numInputs = 3;
1278*89c4ff92SAndroid Build Coastguard Worker }
1279*89c4ff92SAndroid Build Coastguard Worker
1280*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1281*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1282*89c4ff92SAndroid Build Coastguard Worker
1283*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1284*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1285*89c4ff92SAndroid Build Coastguard Worker
1286*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1287*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1288*89c4ff92SAndroid Build Coastguard Worker
1289*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1290*89c4ff92SAndroid Build Coastguard Worker
1291*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1292*89c4ff92SAndroid Build Coastguard Worker
1293*89c4ff92SAndroid Build Coastguard Worker ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1294*89c4ff92SAndroid Build Coastguard Worker
1295*89c4ff92SAndroid Build Coastguard Worker Optional<TensorInfo> optionalBiasTensorInfo;
1296*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1297*89c4ff92SAndroid Build Coastguard Worker {
1298*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1299*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1300*89c4ff92SAndroid Build Coastguard Worker
1301*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1302*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1303*89c4ff92SAndroid Build Coastguard Worker }
1304*89c4ff92SAndroid Build Coastguard Worker
1305*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1306*89c4ff92SAndroid Build Coastguard Worker {
1307*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1308*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1309*89c4ff92SAndroid Build Coastguard Worker "cannot be either negative or 0.",
1310*89c4ff92SAndroid Build Coastguard Worker descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1311*89c4ff92SAndroid Build Coastguard Worker }
1312*89c4ff92SAndroid Build Coastguard Worker
1313*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantization(inputTensorInfo,
1314*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
1315*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo,
1316*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo,
1317*89c4ff92SAndroid Build Coastguard Worker descriptorName);
1318*89c4ff92SAndroid Build Coastguard Worker
1319*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1320*89c4ff92SAndroid Build Coastguard Worker {
1321*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1322*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1323*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1324*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1325*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1326*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1327*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
1328*89c4ff92SAndroid Build Coastguard Worker };
1329*89c4ff92SAndroid Build Coastguard Worker
1330*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1331*89c4ff92SAndroid Build Coastguard Worker
1332*89c4ff92SAndroid Build Coastguard Worker // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1333*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1334*89c4ff92SAndroid Build Coastguard Worker {
1335*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1336*89c4ff92SAndroid Build Coastguard Worker {
1337*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1338*89c4ff92SAndroid Build Coastguard Worker "for BFloat16 input.");
1339*89c4ff92SAndroid Build Coastguard Worker }
1340*89c4ff92SAndroid Build Coastguard Worker }
1341*89c4ff92SAndroid Build Coastguard Worker else
1342*89c4ff92SAndroid Build Coastguard Worker {
1343*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1344*89c4ff92SAndroid Build Coastguard Worker }
1345*89c4ff92SAndroid Build Coastguard Worker }
1346*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1347*89c4ff92SAndroid Build Coastguard Worker void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1348*89c4ff92SAndroid Build Coastguard Worker {
1349*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"Convolution3dQueueDescriptor"};
1350*89c4ff92SAndroid Build Coastguard Worker
1351*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = 2;
1352*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1353*89c4ff92SAndroid Build Coastguard Worker {
1354*89c4ff92SAndroid Build Coastguard Worker numInputs = 3;
1355*89c4ff92SAndroid Build Coastguard Worker }
1356*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1357*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1358*89c4ff92SAndroid Build Coastguard Worker
1359*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1360*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1361*89c4ff92SAndroid Build Coastguard Worker
1362*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1363*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1364*89c4ff92SAndroid Build Coastguard Worker
1365*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1366*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1367*89c4ff92SAndroid Build Coastguard Worker
1368*89c4ff92SAndroid Build Coastguard Worker ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1369*89c4ff92SAndroid Build Coastguard Worker
1370*89c4ff92SAndroid Build Coastguard Worker Optional<TensorInfo> optionalBiasTensorInfo;
1371*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1372*89c4ff92SAndroid Build Coastguard Worker {
1373*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1374*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1375*89c4ff92SAndroid Build Coastguard Worker
1376*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1377*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1378*89c4ff92SAndroid Build Coastguard Worker }
1379*89c4ff92SAndroid Build Coastguard Worker
1380*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1381*89c4ff92SAndroid Build Coastguard Worker {
1382*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1383*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1384*89c4ff92SAndroid Build Coastguard Worker "cannot be either negative or 0.",
1385*89c4ff92SAndroid Build Coastguard Worker descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1386*89c4ff92SAndroid Build Coastguard Worker }
1387*89c4ff92SAndroid Build Coastguard Worker
1388*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantization(inputTensorInfo,
1389*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
1390*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo,
1391*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo,
1392*89c4ff92SAndroid Build Coastguard Worker descriptorName);
1393*89c4ff92SAndroid Build Coastguard Worker
1394*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1395*89c4ff92SAndroid Build Coastguard Worker {
1396*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1397*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1398*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1399*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1400*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1401*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1402*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
1403*89c4ff92SAndroid Build Coastguard Worker };
1404*89c4ff92SAndroid Build Coastguard Worker
1405*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1406*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1407*89c4ff92SAndroid Build Coastguard Worker }
1408*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1409*89c4ff92SAndroid Build Coastguard Worker void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1410*89c4ff92SAndroid Build Coastguard Worker {
1411*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1412*89c4ff92SAndroid Build Coastguard Worker
1413*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = 2;
1414*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1415*89c4ff92SAndroid Build Coastguard Worker {
1416*89c4ff92SAndroid Build Coastguard Worker numInputs = 3;
1417*89c4ff92SAndroid Build Coastguard Worker }
1418*89c4ff92SAndroid Build Coastguard Worker
1419*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1420*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1421*89c4ff92SAndroid Build Coastguard Worker
1422*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1423*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1424*89c4ff92SAndroid Build Coastguard Worker
1425*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1426*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1427*89c4ff92SAndroid Build Coastguard Worker
1428*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1429*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1430*89c4ff92SAndroid Build Coastguard Worker
1431*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1432*89c4ff92SAndroid Build Coastguard Worker {
1433*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1434*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1435*89c4ff92SAndroid Build Coastguard Worker "cannot be smaller than 1.",
1436*89c4ff92SAndroid Build Coastguard Worker descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
1437*89c4ff92SAndroid Build Coastguard Worker }
1438*89c4ff92SAndroid Build Coastguard Worker
1439*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1440*89c4ff92SAndroid Build Coastguard Worker {
1441*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1442*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1443*89c4ff92SAndroid Build Coastguard Worker "cannot be either negative or 0.",
1444*89c4ff92SAndroid Build Coastguard Worker descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1445*89c4ff92SAndroid Build Coastguard Worker }
1446*89c4ff92SAndroid Build Coastguard Worker
1447*89c4ff92SAndroid Build Coastguard Worker if (weightTensorInfo.GetShape()[0] != 1)
1448*89c4ff92SAndroid Build Coastguard Worker {
1449*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
1450*89c4ff92SAndroid Build Coastguard Worker "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1451*89c4ff92SAndroid Build Coastguard Worker "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1452*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1453*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[0],
1454*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[1],
1455*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[2],
1456*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[3]));
1457*89c4ff92SAndroid Build Coastguard Worker }
1458*89c4ff92SAndroid Build Coastguard Worker
1459*89c4ff92SAndroid Build Coastguard Worker const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1460*89c4ff92SAndroid Build Coastguard Worker const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1461*89c4ff92SAndroid Build Coastguard Worker const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1462*89c4ff92SAndroid Build Coastguard Worker const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1463*89c4ff92SAndroid Build Coastguard Worker
1464*89c4ff92SAndroid Build Coastguard Worker // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1465*89c4ff92SAndroid Build Coastguard Worker bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1466*89c4ff92SAndroid Build Coastguard Worker bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1467*89c4ff92SAndroid Build Coastguard Worker
1468*89c4ff92SAndroid Build Coastguard Worker if (!(validRefFormat || validAclFormat))
1469*89c4ff92SAndroid Build Coastguard Worker {
1470*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format(
1471*89c4ff92SAndroid Build Coastguard Worker "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1472*89c4ff92SAndroid Build Coastguard Worker "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1473*89c4ff92SAndroid Build Coastguard Worker "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1474*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1475*89c4ff92SAndroid Build Coastguard Worker numOutputChannels,
1476*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[0],
1477*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[1],
1478*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[2],
1479*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo.GetShape()[3]));
1480*89c4ff92SAndroid Build Coastguard Worker }
1481*89c4ff92SAndroid Build Coastguard Worker
1482*89c4ff92SAndroid Build Coastguard Worker ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1483*89c4ff92SAndroid Build Coastguard Worker
1484*89c4ff92SAndroid Build Coastguard Worker Optional<TensorInfo> optionalBiasTensorInfo;
1485*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
1486*89c4ff92SAndroid Build Coastguard Worker {
1487*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1488*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1489*89c4ff92SAndroid Build Coastguard Worker
1490*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1491*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1492*89c4ff92SAndroid Build Coastguard Worker }
1493*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantization(inputTensorInfo,
1494*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
1495*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo,
1496*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo,
1497*89c4ff92SAndroid Build Coastguard Worker descriptorName);
1498*89c4ff92SAndroid Build Coastguard Worker
1499*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1500*89c4ff92SAndroid Build Coastguard Worker {
1501*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1502*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1503*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1504*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1505*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1506*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1507*89c4ff92SAndroid Build Coastguard Worker };
1508*89c4ff92SAndroid Build Coastguard Worker
1509*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1510*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1511*89c4ff92SAndroid Build Coastguard Worker }
1512*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1513*89c4ff92SAndroid Build Coastguard Worker void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1514*89c4ff92SAndroid Build Coastguard Worker {
1515*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"PermuteQueueDescriptor"};
1516*89c4ff92SAndroid Build Coastguard Worker
1517*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1518*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1519*89c4ff92SAndroid Build Coastguard Worker
1520*89c4ff92SAndroid Build Coastguard Worker const PermutationVector& mapping = m_Parameters.m_DimMappings;
1521*89c4ff92SAndroid Build Coastguard Worker
1522*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1523*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1524*89c4ff92SAndroid Build Coastguard Worker
1525*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1526*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
1527*89c4ff92SAndroid Build Coastguard Worker
1528*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
1529*89c4ff92SAndroid Build Coastguard Worker {
1530*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
1531*89c4ff92SAndroid Build Coastguard Worker {
1532*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1533*89c4ff92SAndroid Build Coastguard Worker " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1534*89c4ff92SAndroid Build Coastguard Worker "must match dst dimension " + to_string(mapping[i]) +
1535*89c4ff92SAndroid Build Coastguard Worker " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
1536*89c4ff92SAndroid Build Coastguard Worker }
1537*89c4ff92SAndroid Build Coastguard Worker }
1538*89c4ff92SAndroid Build Coastguard Worker
1539*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1540*89c4ff92SAndroid Build Coastguard Worker }
1541*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1542*89c4ff92SAndroid Build Coastguard Worker void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1543*89c4ff92SAndroid Build Coastguard Worker {
1544*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"Pooling2dQueueDescriptor"};
1545*89c4ff92SAndroid Build Coastguard Worker
1546*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1547*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1548*89c4ff92SAndroid Build Coastguard Worker
1549*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1550*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1551*89c4ff92SAndroid Build Coastguard Worker
1552*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1553*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1554*89c4ff92SAndroid Build Coastguard Worker
1555*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1556*89c4ff92SAndroid Build Coastguard Worker {
1557*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1558*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1559*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1560*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1561*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1562*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1563*89c4ff92SAndroid Build Coastguard Worker };
1564*89c4ff92SAndroid Build Coastguard Worker
1565*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1566*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1567*89c4ff92SAndroid Build Coastguard Worker }
1568*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1569*89c4ff92SAndroid Build Coastguard Worker void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1570*89c4ff92SAndroid Build Coastguard Worker {
1571*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"Pooling3dQueueDescriptor"};
1572*89c4ff92SAndroid Build Coastguard Worker
1573*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1574*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1575*89c4ff92SAndroid Build Coastguard Worker
1576*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1577*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1578*89c4ff92SAndroid Build Coastguard Worker
1579*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1580*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1581*89c4ff92SAndroid Build Coastguard Worker
1582*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1583*89c4ff92SAndroid Build Coastguard Worker {
1584*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1585*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1586*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1587*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1588*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1589*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1590*89c4ff92SAndroid Build Coastguard Worker };
1591*89c4ff92SAndroid Build Coastguard Worker
1592*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1593*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1594*89c4ff92SAndroid Build Coastguard Worker }
1595*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1596*89c4ff92SAndroid Build Coastguard Worker void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1597*89c4ff92SAndroid Build Coastguard Worker {
1598*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ResizeQueueDescriptor"};
1599*89c4ff92SAndroid Build Coastguard Worker
1600*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1601*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1602*89c4ff92SAndroid Build Coastguard Worker
1603*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1604*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1605*89c4ff92SAndroid Build Coastguard Worker
1606*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1607*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1608*89c4ff92SAndroid Build Coastguard Worker
1609*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1610*89c4ff92SAndroid Build Coastguard Worker {
1611*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1612*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1613*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1614*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1615*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1616*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1617*89c4ff92SAndroid Build Coastguard Worker };
1618*89c4ff92SAndroid Build Coastguard Worker
1619*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1620*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1621*89c4ff92SAndroid Build Coastguard Worker
1622*89c4ff92SAndroid Build Coastguard Worker // Resize only changes width and height: batch and channel count must match.
1623*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1624*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
1625*89c4ff92SAndroid Build Coastguard Worker if (inputBatchSize != outputBatchSize)
1626*89c4ff92SAndroid Build Coastguard Worker {
1627*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1628*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1629*89c4ff92SAndroid Build Coastguard Worker descriptorName, inputBatchSize, outputBatchSize));
1630*89c4ff92SAndroid Build Coastguard Worker }
1631*89c4ff92SAndroid Build Coastguard Worker
1632*89c4ff92SAndroid Build Coastguard Worker DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1633*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1634*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1635*89c4ff92SAndroid Build Coastguard Worker if (inputChannelCount != outputChannelCount)
1636*89c4ff92SAndroid Build Coastguard Worker {
1637*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1638*89c4ff92SAndroid Build Coastguard Worker fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1639*89c4ff92SAndroid Build Coastguard Worker descriptorName, inputChannelCount, outputChannelCount));
1640*89c4ff92SAndroid Build Coastguard Worker }
1641*89c4ff92SAndroid Build Coastguard Worker }
1642*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1643*89c4ff92SAndroid Build Coastguard Worker void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1644*89c4ff92SAndroid Build Coastguard Worker {
1645*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
1646*89c4ff92SAndroid Build Coastguard Worker
1647*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1648*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1649*89c4ff92SAndroid Build Coastguard Worker
1650*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1651*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1652*89c4ff92SAndroid Build Coastguard Worker
1653*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1654*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1655*89c4ff92SAndroid Build Coastguard Worker
1656*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1657*89c4ff92SAndroid Build Coastguard Worker
1658*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Min > m_Parameters.m_Max)
1659*89c4ff92SAndroid Build Coastguard Worker {
1660*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
1661*89c4ff92SAndroid Build Coastguard Worker }
1662*89c4ff92SAndroid Build Coastguard Worker }
1663*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1664*89c4ff92SAndroid Build Coastguard Worker void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1665*89c4ff92SAndroid Build Coastguard Worker {
1666*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1667*89c4ff92SAndroid Build Coastguard Worker
1668*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1669*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1670*89c4ff92SAndroid Build Coastguard Worker
1671*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1672*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1673*89c4ff92SAndroid Build Coastguard Worker
1674*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetNumDimensions() > 4)
1675*89c4ff92SAndroid Build Coastguard Worker {
1676*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1677*89c4ff92SAndroid Build Coastguard Worker }
1678*89c4ff92SAndroid Build Coastguard Worker
1679*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1680*89c4ff92SAndroid Build Coastguard Worker
1681*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1682*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1683*89c4ff92SAndroid Build Coastguard Worker {
1684*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1685*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1686*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
1687*89c4ff92SAndroid Build Coastguard Worker };
1688*89c4ff92SAndroid Build Coastguard Worker
1689*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1690*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1691*89c4ff92SAndroid Build Coastguard Worker }
1692*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1693*89c4ff92SAndroid Build Coastguard Worker void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1694*89c4ff92SAndroid Build Coastguard Worker {
1695*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"L2NormalizationQueueDescriptor"};
1696*89c4ff92SAndroid Build Coastguard Worker
1697*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1698*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1699*89c4ff92SAndroid Build Coastguard Worker
1700*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1701*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1702*89c4ff92SAndroid Build Coastguard Worker
1703*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetNumDimensions() > 4)
1704*89c4ff92SAndroid Build Coastguard Worker {
1705*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1706*89c4ff92SAndroid Build Coastguard Worker }
1707*89c4ff92SAndroid Build Coastguard Worker
1708*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1709*89c4ff92SAndroid Build Coastguard Worker
1710*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1711*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1712*89c4ff92SAndroid Build Coastguard Worker {
1713*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1714*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1715*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1716*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1717*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1718*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1719*89c4ff92SAndroid Build Coastguard Worker };
1720*89c4ff92SAndroid Build Coastguard Worker
1721*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1722*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1723*89c4ff92SAndroid Build Coastguard Worker }
1724*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1725*89c4ff92SAndroid Build Coastguard Worker void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1726*89c4ff92SAndroid Build Coastguard Worker {
1727*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1728*89c4ff92SAndroid Build Coastguard Worker
1729*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1730*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1731*89c4ff92SAndroid Build Coastguard Worker
1732*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1733*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1734*89c4ff92SAndroid Build Coastguard Worker
1735*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1736*89c4ff92SAndroid Build Coastguard Worker
1737*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1738*89c4ff92SAndroid Build Coastguard Worker {
1739*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1740*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1741*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1742*89c4ff92SAndroid Build Coastguard Worker };
1743*89c4ff92SAndroid Build Coastguard Worker
1744*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1745*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1746*89c4ff92SAndroid Build Coastguard Worker }
1747*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1748*89c4ff92SAndroid Build Coastguard Worker void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1749*89c4ff92SAndroid Build Coastguard Worker {
1750*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ConstantQueueDescriptor"};
1751*89c4ff92SAndroid Build Coastguard Worker
1752*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 0);
1753*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1754*89c4ff92SAndroid Build Coastguard Worker
1755*89c4ff92SAndroid Build Coastguard Worker if (!m_LayerOutput)
1756*89c4ff92SAndroid Build Coastguard Worker {
1757*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": No const input specified.");
1758*89c4ff92SAndroid Build Coastguard Worker }
1759*89c4ff92SAndroid Build Coastguard Worker
1760*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1761*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
1762*89c4ff92SAndroid Build Coastguard Worker
1763*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1764*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1765*89c4ff92SAndroid Build Coastguard Worker {
1766*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1767*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1768*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1769*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1770*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1771*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
1772*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1773*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
1774*89c4ff92SAndroid Build Coastguard Worker };
1775*89c4ff92SAndroid Build Coastguard Worker
1776*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1777*89c4ff92SAndroid Build Coastguard Worker }
1778*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1779*89c4ff92SAndroid Build Coastguard Worker void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1780*89c4ff92SAndroid Build Coastguard Worker {
1781*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ReshapeQueueDescriptor"};
1782*89c4ff92SAndroid Build Coastguard Worker
1783*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1784*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1785*89c4ff92SAndroid Build Coastguard Worker
1786*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1787*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1788*89c4ff92SAndroid Build Coastguard Worker
1789*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1790*89c4ff92SAndroid Build Coastguard Worker
1791*89c4ff92SAndroid Build Coastguard Worker // Check the supported data types
1792*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1793*89c4ff92SAndroid Build Coastguard Worker {
1794*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1795*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1796*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1797*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1798*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1799*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
1800*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
1801*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean
1802*89c4ff92SAndroid Build Coastguard Worker };
1803*89c4ff92SAndroid Build Coastguard Worker
1804*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1805*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1806*89c4ff92SAndroid Build Coastguard Worker }
1807*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1808*89c4ff92SAndroid Build Coastguard Worker void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1809*89c4ff92SAndroid Build Coastguard Worker {
1810*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
1811*89c4ff92SAndroid Build Coastguard Worker
1812*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1813*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1814*89c4ff92SAndroid Build Coastguard Worker
1815*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1816*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1817*89c4ff92SAndroid Build Coastguard Worker
1818*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1819*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1820*89c4ff92SAndroid Build Coastguard Worker
1821*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BlockShape.size() != 2)
1822*89c4ff92SAndroid Build Coastguard Worker {
1823*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
1824*89c4ff92SAndroid Build Coastguard Worker }
1825*89c4ff92SAndroid Build Coastguard Worker
1826*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1827*89c4ff92SAndroid Build Coastguard Worker {
1828*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1829*89c4ff92SAndroid Build Coastguard Worker "dimensions as Block Shape.");
1830*89c4ff92SAndroid Build Coastguard Worker }
1831*89c4ff92SAndroid Build Coastguard Worker
1832*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = inputTensorInfo.GetShape();
1833*89c4ff92SAndroid Build Coastguard Worker
1834*89c4ff92SAndroid Build Coastguard Worker std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
1835*89c4ff92SAndroid Build Coastguard Worker std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
1836*89c4ff92SAndroid Build Coastguard Worker
1837*89c4ff92SAndroid Build Coastguard Worker DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1838*89c4ff92SAndroid Build Coastguard Worker
1839*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1840*89c4ff92SAndroid Build Coastguard Worker widthPad.first + widthPad.second;
1841*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1842*89c4ff92SAndroid Build Coastguard Worker heightPad.first + heightPad.second;
1843*89c4ff92SAndroid Build Coastguard Worker
1844*89c4ff92SAndroid Build Coastguard Worker const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1845*89c4ff92SAndroid Build Coastguard Worker inputShape[dimensionIndices.GetChannelsIndex()];
1846*89c4ff92SAndroid Build Coastguard Worker const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1847*89c4ff92SAndroid Build Coastguard Worker
1848*89c4ff92SAndroid Build Coastguard Worker if (numOutputElements != numInputElements)
1849*89c4ff92SAndroid Build Coastguard Worker {
1850*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1851*89c4ff92SAndroid Build Coastguard Worker to_string(numInputElements) + " after padding but output tensor has " +
1852*89c4ff92SAndroid Build Coastguard Worker to_string(numOutputElements) + " elements.");
1853*89c4ff92SAndroid Build Coastguard Worker }
1854*89c4ff92SAndroid Build Coastguard Worker
1855*89c4ff92SAndroid Build Coastguard Worker if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
1856*89c4ff92SAndroid Build Coastguard Worker {
1857*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1858*89c4ff92SAndroid Build Coastguard Worker "divisible by Block Shape in all spatial dimensions");
1859*89c4ff92SAndroid Build Coastguard Worker }
1860*89c4ff92SAndroid Build Coastguard Worker
1861*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1862*89c4ff92SAndroid Build Coastguard Worker {
1863*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1864*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1865*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1866*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1867*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1868*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1869*89c4ff92SAndroid Build Coastguard Worker };
1870*89c4ff92SAndroid Build Coastguard Worker
1871*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1872*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1873*89c4ff92SAndroid Build Coastguard Worker }
1874*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1875*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1876*89c4ff92SAndroid Build Coastguard Worker {
1877*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
1878*89c4ff92SAndroid Build Coastguard Worker
1879*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1880*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1881*89c4ff92SAndroid Build Coastguard Worker
1882*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1883*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1884*89c4ff92SAndroid Build Coastguard Worker
1885*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1886*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1887*89c4ff92SAndroid Build Coastguard Worker
1888*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1889*89c4ff92SAndroid Build Coastguard Worker {
1890*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1891*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1892*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1893*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
1894*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
1895*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1896*89c4ff92SAndroid Build Coastguard Worker };
1897*89c4ff92SAndroid Build Coastguard Worker
1898*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1899*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1900*89c4ff92SAndroid Build Coastguard Worker
1901*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1902*89c4ff92SAndroid Build Coastguard Worker
1903*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BlockSize == 0)
1904*89c4ff92SAndroid Build Coastguard Worker {
1905*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1906*89c4ff92SAndroid Build Coastguard Worker }
1907*89c4ff92SAndroid Build Coastguard Worker
1908*89c4ff92SAndroid Build Coastguard Worker DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1909*89c4ff92SAndroid Build Coastguard Worker const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1910*89c4ff92SAndroid Build Coastguard Worker const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1911*89c4ff92SAndroid Build Coastguard Worker const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
1912*89c4ff92SAndroid Build Coastguard Worker
1913*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = inputTensorInfo.GetShape();
1914*89c4ff92SAndroid Build Coastguard Worker if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
1915*89c4ff92SAndroid Build Coastguard Worker {
1916*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1917*89c4ff92SAndroid Build Coastguard Worker "by block size in all spatial dimensions");
1918*89c4ff92SAndroid Build Coastguard Worker }
1919*89c4ff92SAndroid Build Coastguard Worker
1920*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = outputTensorInfo.GetShape();
1921*89c4ff92SAndroid Build Coastguard Worker if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1922*89c4ff92SAndroid Build Coastguard Worker {
1923*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1924*89c4ff92SAndroid Build Coastguard Worker "must be divisible by the square of block size." );
1925*89c4ff92SAndroid Build Coastguard Worker }
1926*89c4ff92SAndroid Build Coastguard Worker }
1927*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1928*89c4ff92SAndroid Build Coastguard Worker void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1929*89c4ff92SAndroid Build Coastguard Worker {
1930*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"FloorQueueDescriptor"};
1931*89c4ff92SAndroid Build Coastguard Worker
1932*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
1933*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
1934*89c4ff92SAndroid Build Coastguard Worker
1935*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1936*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1937*89c4ff92SAndroid Build Coastguard Worker
1938*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1939*89c4ff92SAndroid Build Coastguard Worker {
1940*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1941*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1942*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1943*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1944*89c4ff92SAndroid Build Coastguard Worker };
1945*89c4ff92SAndroid Build Coastguard Worker
1946*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1947*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1948*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1949*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1950*89c4ff92SAndroid Build Coastguard Worker }
1951*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const1952*89c4ff92SAndroid Build Coastguard Worker void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1953*89c4ff92SAndroid Build Coastguard Worker {
1954*89c4ff92SAndroid Build Coastguard Worker // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1955*89c4ff92SAndroid Build Coastguard Worker
1956*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"LstmQueueDescriptor"};
1957*89c4ff92SAndroid Build Coastguard Worker
1958*89c4ff92SAndroid Build Coastguard Worker // check dimensions of all inputs and outputs
1959*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() != 3)
1960*89c4ff92SAndroid Build Coastguard Worker {
1961*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1962*89c4ff92SAndroid Build Coastguard Worker }
1963*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() != 4)
1964*89c4ff92SAndroid Build Coastguard Worker {
1965*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1966*89c4ff92SAndroid Build Coastguard Worker }
1967*89c4ff92SAndroid Build Coastguard Worker
1968*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
1969*89c4ff92SAndroid Build Coastguard Worker {
1970*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
1971*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
1972*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
1973*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
1974*89c4ff92SAndroid Build Coastguard Worker };
1975*89c4ff92SAndroid Build Coastguard Worker
1976*89c4ff92SAndroid Build Coastguard Worker // check for supported type of one input and match them with all the other input and output
1977*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1978*89c4ff92SAndroid Build Coastguard Worker
1979*89c4ff92SAndroid Build Coastguard Worker // type matches all other inputs
1980*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
1981*89c4ff92SAndroid Build Coastguard Worker {
1982*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1983*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_InputTensorInfos[i],
1984*89c4ff92SAndroid Build Coastguard Worker descriptorName,
1985*89c4ff92SAndroid Build Coastguard Worker "input_0",
1986*89c4ff92SAndroid Build Coastguard Worker "input_" + std::to_string(i));
1987*89c4ff92SAndroid Build Coastguard Worker }
1988*89c4ff92SAndroid Build Coastguard Worker // type matches all other outputs
1989*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
1990*89c4ff92SAndroid Build Coastguard Worker {
1991*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1992*89c4ff92SAndroid Build Coastguard Worker workloadInfo.m_OutputTensorInfos[i],
1993*89c4ff92SAndroid Build Coastguard Worker "LstmQueueDescriptor",
1994*89c4ff92SAndroid Build Coastguard Worker "input_0",
1995*89c4ff92SAndroid Build Coastguard Worker "output_" + std::to_string(i));
1996*89c4ff92SAndroid Build Coastguard Worker }
1997*89c4ff92SAndroid Build Coastguard Worker
1998*89c4ff92SAndroid Build Coastguard Worker // Making sure clipping parameters have valid values.
1999*89c4ff92SAndroid Build Coastguard Worker // == 0 means no clipping
2000*89c4ff92SAndroid Build Coastguard Worker // > 0 means clipping
2001*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_ClippingThresCell < 0.0f)
2002*89c4ff92SAndroid Build Coastguard Worker {
2003*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2004*89c4ff92SAndroid Build Coastguard Worker }
2005*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_ClippingThresProj < 0.0f)
2006*89c4ff92SAndroid Build Coastguard Worker {
2007*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2008*89c4ff92SAndroid Build Coastguard Worker }
2009*89c4ff92SAndroid Build Coastguard Worker
2010*89c4ff92SAndroid Build Coastguard Worker // Inferring batch size, number of outputs and number of cells from the inputs.
2011*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2012*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2013*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2014*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2015*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2016*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2017*89c4ff92SAndroid Build Coastguard Worker
2018*89c4ff92SAndroid Build Coastguard Worker // input tensor
2019*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2020*89c4ff92SAndroid Build Coastguard Worker descriptorName + " input_0");
2021*89c4ff92SAndroid Build Coastguard Worker // outputStateInTensor
2022*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2023*89c4ff92SAndroid Build Coastguard Worker descriptorName + " input_1");
2024*89c4ff92SAndroid Build Coastguard Worker // cellStateInTensor
2025*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2026*89c4ff92SAndroid Build Coastguard Worker descriptorName + " input_2");
2027*89c4ff92SAndroid Build Coastguard Worker // scratchBufferTensor
2028*89c4ff92SAndroid Build Coastguard Worker unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
2029*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2030*89c4ff92SAndroid Build Coastguard Worker descriptorName + " output_0");
2031*89c4ff92SAndroid Build Coastguard Worker // outputStateOutTensor
2032*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2033*89c4ff92SAndroid Build Coastguard Worker descriptorName + " output_1");
2034*89c4ff92SAndroid Build Coastguard Worker // cellStateOutTensor
2035*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2036*89c4ff92SAndroid Build Coastguard Worker descriptorName + " output_2");
2037*89c4ff92SAndroid Build Coastguard Worker // outputTensor
2038*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2039*89c4ff92SAndroid Build Coastguard Worker descriptorName + " output_3");
2040*89c4ff92SAndroid Build Coastguard Worker
2041*89c4ff92SAndroid Build Coastguard Worker // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2042*89c4ff92SAndroid Build Coastguard Worker if ( m_InputToInputWeights )
2043*89c4ff92SAndroid Build Coastguard Worker {
2044*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2045*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_input), "InputLayerNormWeights");
2046*89c4ff92SAndroid Build Coastguard Worker }
2047*89c4ff92SAndroid Build Coastguard Worker
2048*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2049*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2050*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_input), "InputToForgetWeights");
2051*89c4ff92SAndroid Build Coastguard Worker
2052*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2053*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2054*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_input), "InputToCellWeights");
2055*89c4ff92SAndroid Build Coastguard Worker
2056*89c4ff92SAndroid Build Coastguard Worker if ( m_RecurrentToInputWeights )
2057*89c4ff92SAndroid Build Coastguard Worker {
2058*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2059*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "RecurrentToInputWeights");
2060*89c4ff92SAndroid Build Coastguard Worker }
2061*89c4ff92SAndroid Build Coastguard Worker
2062*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2063*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2064*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "RecurrentToForgetWeights");
2065*89c4ff92SAndroid Build Coastguard Worker
2066*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2067*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2068*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "RecurrentToCellWeights");
2069*89c4ff92SAndroid Build Coastguard Worker
2070*89c4ff92SAndroid Build Coastguard Worker // Make sure the input-gate's parameters are either both present (regular
2071*89c4ff92SAndroid Build Coastguard Worker // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2072*89c4ff92SAndroid Build Coastguard Worker bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2073*89c4ff92SAndroid Build Coastguard Worker !m_Parameters.m_CifgEnabled) ||
2074*89c4ff92SAndroid Build Coastguard Worker (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2075*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_CifgEnabled));
2076*89c4ff92SAndroid Build Coastguard Worker if (!cifg_weights_all_or_none)
2077*89c4ff92SAndroid Build Coastguard Worker {
2078*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2079*89c4ff92SAndroid Build Coastguard Worker "RecurrentToInputWeights must either both be present (regular LSTM) "
2080*89c4ff92SAndroid Build Coastguard Worker "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2081*89c4ff92SAndroid Build Coastguard Worker "accordingly.");
2082*89c4ff92SAndroid Build Coastguard Worker }
2083*89c4ff92SAndroid Build Coastguard Worker
2084*89c4ff92SAndroid Build Coastguard Worker if ( m_CellToInputWeights )
2085*89c4ff92SAndroid Build Coastguard Worker {
2086*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2087*89c4ff92SAndroid Build Coastguard Worker n_cell, "CellToInputWeights");
2088*89c4ff92SAndroid Build Coastguard Worker }
2089*89c4ff92SAndroid Build Coastguard Worker if ( m_CellToForgetWeights )
2090*89c4ff92SAndroid Build Coastguard Worker {
2091*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2092*89c4ff92SAndroid Build Coastguard Worker n_cell, "CellToForgetWeights");
2093*89c4ff92SAndroid Build Coastguard Worker }
2094*89c4ff92SAndroid Build Coastguard Worker if ( m_CellToOutputWeights )
2095*89c4ff92SAndroid Build Coastguard Worker {
2096*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2097*89c4ff92SAndroid Build Coastguard Worker n_cell, "CellToOutputWeights");
2098*89c4ff92SAndroid Build Coastguard Worker }
2099*89c4ff92SAndroid Build Coastguard Worker
2100*89c4ff92SAndroid Build Coastguard Worker // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2101*89c4ff92SAndroid Build Coastguard Worker bool peephole_weights_all_or_none =
2102*89c4ff92SAndroid Build Coastguard Worker (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2103*89c4ff92SAndroid Build Coastguard Worker && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2104*89c4ff92SAndroid Build Coastguard Worker || ( !m_CellToInputWeights && !m_CellToForgetWeights
2105*89c4ff92SAndroid Build Coastguard Worker && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2106*89c4ff92SAndroid Build Coastguard Worker if (!peephole_weights_all_or_none)
2107*89c4ff92SAndroid Build Coastguard Worker {
2108*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
2109*89c4ff92SAndroid Build Coastguard Worker }
2110*89c4ff92SAndroid Build Coastguard Worker
2111*89c4ff92SAndroid Build Coastguard Worker // Make sure the input gate bias is present only when not a CIFG-LSTM.
2112*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_CifgEnabled)
2113*89c4ff92SAndroid Build Coastguard Worker {
2114*89c4ff92SAndroid Build Coastguard Worker if (m_InputGateBias)
2115*89c4ff92SAndroid Build Coastguard Worker {
2116*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
2117*89c4ff92SAndroid Build Coastguard Worker }
2118*89c4ff92SAndroid Build Coastguard Worker }
2119*89c4ff92SAndroid Build Coastguard Worker else
2120*89c4ff92SAndroid Build Coastguard Worker {
2121*89c4ff92SAndroid Build Coastguard Worker if (!m_InputGateBias)
2122*89c4ff92SAndroid Build Coastguard Worker {
2123*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2124*89c4ff92SAndroid Build Coastguard Worker "must be present.");
2125*89c4ff92SAndroid Build Coastguard Worker }
2126*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2127*89c4ff92SAndroid Build Coastguard Worker n_cell, "InputGateBias");
2128*89c4ff92SAndroid Build Coastguard Worker }
2129*89c4ff92SAndroid Build Coastguard Worker
2130*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2131*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2132*89c4ff92SAndroid Build Coastguard Worker
2133*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2134*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2135*89c4ff92SAndroid Build Coastguard Worker
2136*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2137*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2138*89c4ff92SAndroid Build Coastguard Worker
2139*89c4ff92SAndroid Build Coastguard Worker if (m_ProjectionWeights)
2140*89c4ff92SAndroid Build Coastguard Worker {
2141*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2142*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "ProjectionWeights");
2143*89c4ff92SAndroid Build Coastguard Worker }
2144*89c4ff92SAndroid Build Coastguard Worker if (m_ProjectionBias)
2145*89c4ff92SAndroid Build Coastguard Worker {
2146*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2147*89c4ff92SAndroid Build Coastguard Worker }
2148*89c4ff92SAndroid Build Coastguard Worker
2149*89c4ff92SAndroid Build Coastguard Worker // Making sure the projection tensors are consistent:
2150*89c4ff92SAndroid Build Coastguard Worker // 1) If projection weight is not present, then projection bias should not be
2151*89c4ff92SAndroid Build Coastguard Worker // present.
2152*89c4ff92SAndroid Build Coastguard Worker // 2) If projection weight is present, then projection bias is optional.
2153*89c4ff92SAndroid Build Coastguard Worker bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2154*89c4ff92SAndroid Build Coastguard Worker !m_Parameters.m_ProjectionEnabled)
2155*89c4ff92SAndroid Build Coastguard Worker || (m_ProjectionWeights && !m_ProjectionBias &&
2156*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_ProjectionEnabled)
2157*89c4ff92SAndroid Build Coastguard Worker || (m_ProjectionWeights && m_ProjectionBias &&
2158*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_ProjectionEnabled));
2159*89c4ff92SAndroid Build Coastguard Worker if (!projecton_tensors_consistent)
2160*89c4ff92SAndroid Build Coastguard Worker {
2161*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
2162*89c4ff92SAndroid Build Coastguard Worker }
2163*89c4ff92SAndroid Build Coastguard Worker
2164*89c4ff92SAndroid Build Coastguard Worker // The four layer normalization weights either all have values or none of them have values. Additionally, if
2165*89c4ff92SAndroid Build Coastguard Worker // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2166*89c4ff92SAndroid Build Coastguard Worker // either all have values or none of them have values. Layer normalization is used when the values of all the
2167*89c4ff92SAndroid Build Coastguard Worker // layer normalization weights are present
2168*89c4ff92SAndroid Build Coastguard Worker if (m_InputLayerNormWeights)
2169*89c4ff92SAndroid Build Coastguard Worker {
2170*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2171*89c4ff92SAndroid Build Coastguard Worker }
2172*89c4ff92SAndroid Build Coastguard Worker if (m_ForgetLayerNormWeights)
2173*89c4ff92SAndroid Build Coastguard Worker {
2174*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2175*89c4ff92SAndroid Build Coastguard Worker }
2176*89c4ff92SAndroid Build Coastguard Worker if (m_CellLayerNormWeights)
2177*89c4ff92SAndroid Build Coastguard Worker {
2178*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2179*89c4ff92SAndroid Build Coastguard Worker }
2180*89c4ff92SAndroid Build Coastguard Worker if (m_OutputLayerNormWeights)
2181*89c4ff92SAndroid Build Coastguard Worker {
2182*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2183*89c4ff92SAndroid Build Coastguard Worker }
2184*89c4ff92SAndroid Build Coastguard Worker
2185*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_LayerNormEnabled)
2186*89c4ff92SAndroid Build Coastguard Worker {
2187*89c4ff92SAndroid Build Coastguard Worker if (!m_Parameters.m_CifgEnabled)
2188*89c4ff92SAndroid Build Coastguard Worker {
2189*89c4ff92SAndroid Build Coastguard Worker if (!m_InputLayerNormWeights)
2190*89c4ff92SAndroid Build Coastguard Worker {
2191*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2192*89c4ff92SAndroid Build Coastguard Worker "disabled but InputLayerNormWeights are not present");
2193*89c4ff92SAndroid Build Coastguard Worker }
2194*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2195*89c4ff92SAndroid Build Coastguard Worker 1, n_cell, "InputLayerNormWeights");
2196*89c4ff92SAndroid Build Coastguard Worker }
2197*89c4ff92SAndroid Build Coastguard Worker else if (m_InputLayerNormWeights)
2198*89c4ff92SAndroid Build Coastguard Worker {
2199*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2200*89c4ff92SAndroid Build Coastguard Worker "enabled");
2201*89c4ff92SAndroid Build Coastguard Worker }
2202*89c4ff92SAndroid Build Coastguard Worker
2203*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2204*89c4ff92SAndroid Build Coastguard Worker "ForgetLayerNormWeights");
2205*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2206*89c4ff92SAndroid Build Coastguard Worker
2207*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2208*89c4ff92SAndroid Build Coastguard Worker "OutputLayerNormWeights");
2209*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2210*89c4ff92SAndroid Build Coastguard Worker
2211*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2212*89c4ff92SAndroid Build Coastguard Worker "CellLayerNormWeights");
2213*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2214*89c4ff92SAndroid Build Coastguard Worker }
2215*89c4ff92SAndroid Build Coastguard Worker else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2216*89c4ff92SAndroid Build Coastguard Worker {
2217*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2218*89c4ff92SAndroid Build Coastguard Worker "normalisation weights are present.");
2219*89c4ff92SAndroid Build Coastguard Worker }
2220*89c4ff92SAndroid Build Coastguard Worker }
2221*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2222*89c4ff92SAndroid Build Coastguard Worker void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2223*89c4ff92SAndroid Build Coastguard Worker {
2224*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
2225*89c4ff92SAndroid Build Coastguard Worker
2226*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2227*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2228*89c4ff92SAndroid Build Coastguard Worker
2229*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2230*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2231*89c4ff92SAndroid Build Coastguard Worker
2232*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetDataType() != DataType::Float32)
2233*89c4ff92SAndroid Build Coastguard Worker {
2234*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2235*89c4ff92SAndroid Build Coastguard Worker }
2236*89c4ff92SAndroid Build Coastguard Worker
2237*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Float16)
2238*89c4ff92SAndroid Build Coastguard Worker {
2239*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
2240*89c4ff92SAndroid Build Coastguard Worker }
2241*89c4ff92SAndroid Build Coastguard Worker
2242*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2243*89c4ff92SAndroid Build Coastguard Worker }
2244*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2245*89c4ff92SAndroid Build Coastguard Worker void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2246*89c4ff92SAndroid Build Coastguard Worker {
2247*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
2248*89c4ff92SAndroid Build Coastguard Worker
2249*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2250*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2251*89c4ff92SAndroid Build Coastguard Worker
2252*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2253*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2254*89c4ff92SAndroid Build Coastguard Worker
2255*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetDataType() != DataType::Float16)
2256*89c4ff92SAndroid Build Coastguard Worker {
2257*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
2258*89c4ff92SAndroid Build Coastguard Worker }
2259*89c4ff92SAndroid Build Coastguard Worker
2260*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Float32)
2261*89c4ff92SAndroid Build Coastguard Worker {
2262*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2263*89c4ff92SAndroid Build Coastguard Worker }
2264*89c4ff92SAndroid Build Coastguard Worker
2265*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2266*89c4ff92SAndroid Build Coastguard Worker }
2267*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2268*89c4ff92SAndroid Build Coastguard Worker void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2269*89c4ff92SAndroid Build Coastguard Worker {
2270*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"DivisionQueueDescriptor"};
2271*89c4ff92SAndroid Build Coastguard Worker
2272*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2273*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2274*89c4ff92SAndroid Build Coastguard Worker
2275*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2276*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2277*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2278*89c4ff92SAndroid Build Coastguard Worker
2279*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2280*89c4ff92SAndroid Build Coastguard Worker {
2281*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2282*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2283*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2284*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2285*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2286*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2287*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2288*89c4ff92SAndroid Build Coastguard Worker };
2289*89c4ff92SAndroid Build Coastguard Worker
2290*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2291*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2292*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2293*89c4ff92SAndroid Build Coastguard Worker
2294*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2295*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
2296*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2297*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2298*89c4ff92SAndroid Build Coastguard Worker "input_0",
2299*89c4ff92SAndroid Build Coastguard Worker "input_1");
2300*89c4ff92SAndroid Build Coastguard Worker }
2301*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2302*89c4ff92SAndroid Build Coastguard Worker void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303*89c4ff92SAndroid Build Coastguard Worker {
2304*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"SubtractionQueueDescriptor"};
2305*89c4ff92SAndroid Build Coastguard Worker
2306*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2307*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308*89c4ff92SAndroid Build Coastguard Worker
2309*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2310*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2311*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312*89c4ff92SAndroid Build Coastguard Worker
2313*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2314*89c4ff92SAndroid Build Coastguard Worker {
2315*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2316*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2317*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2318*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2319*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2320*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2321*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
2322*89c4ff92SAndroid Build Coastguard Worker };
2323*89c4ff92SAndroid Build Coastguard Worker
2324*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2325*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2326*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2327*89c4ff92SAndroid Build Coastguard Worker
2328*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2329*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
2330*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2331*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2332*89c4ff92SAndroid Build Coastguard Worker "input_0",
2333*89c4ff92SAndroid Build Coastguard Worker "input_1");
2334*89c4ff92SAndroid Build Coastguard Worker }
2335*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2336*89c4ff92SAndroid Build Coastguard Worker void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2337*89c4ff92SAndroid Build Coastguard Worker {
2338*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"MaximumQueueDescriptor"};
2339*89c4ff92SAndroid Build Coastguard Worker
2340*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2341*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2342*89c4ff92SAndroid Build Coastguard Worker
2343*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2344*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2345*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2346*89c4ff92SAndroid Build Coastguard Worker
2347*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2348*89c4ff92SAndroid Build Coastguard Worker {
2349*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2350*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2351*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2352*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2353*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2354*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2355*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2356*89c4ff92SAndroid Build Coastguard Worker };
2357*89c4ff92SAndroid Build Coastguard Worker
2358*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2359*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2360*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2361*89c4ff92SAndroid Build Coastguard Worker
2362*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2363*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
2364*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2365*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2366*89c4ff92SAndroid Build Coastguard Worker "input_0",
2367*89c4ff92SAndroid Build Coastguard Worker "input_1");
2368*89c4ff92SAndroid Build Coastguard Worker }
2369*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2370*89c4ff92SAndroid Build Coastguard Worker void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2371*89c4ff92SAndroid Build Coastguard Worker {
2372*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"MeanQueueDescriptor"};
2373*89c4ff92SAndroid Build Coastguard Worker
2374*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2375*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2376*89c4ff92SAndroid Build Coastguard Worker
2377*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2378*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2379*89c4ff92SAndroid Build Coastguard Worker
2380*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2381*89c4ff92SAndroid Build Coastguard Worker {
2382*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2383*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2384*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2385*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2386*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2387*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2388*89c4ff92SAndroid Build Coastguard Worker };
2389*89c4ff92SAndroid Build Coastguard Worker
2390*89c4ff92SAndroid Build Coastguard Worker // First check if input tensor data type is supported, then
2391*89c4ff92SAndroid Build Coastguard Worker // check if this data type matches the output tensor data type
2392*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2393*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2394*89c4ff92SAndroid Build Coastguard Worker
2395*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_KeepDims)
2396*89c4ff92SAndroid Build Coastguard Worker {
2397*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2398*89c4ff92SAndroid Build Coastguard Worker }
2399*89c4ff92SAndroid Build Coastguard Worker else if (m_Parameters.m_Axis.empty())
2400*89c4ff92SAndroid Build Coastguard Worker {
2401*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
2402*89c4ff92SAndroid Build Coastguard Worker }
2403*89c4ff92SAndroid Build Coastguard Worker else
2404*89c4ff92SAndroid Build Coastguard Worker {
2405*89c4ff92SAndroid Build Coastguard Worker unsigned int outputDim =
2406*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2407*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo,
2408*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2409*89c4ff92SAndroid Build Coastguard Worker outputDim > 0 ? outputDim : 1,
2410*89c4ff92SAndroid Build Coastguard Worker "output");
2411*89c4ff92SAndroid Build Coastguard Worker }
2412*89c4ff92SAndroid Build Coastguard Worker }
2413*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2414*89c4ff92SAndroid Build Coastguard Worker void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2415*89c4ff92SAndroid Build Coastguard Worker {
2416*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"PadQueueDescriptor"};
2417*89c4ff92SAndroid Build Coastguard Worker
2418*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2419*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2420*89c4ff92SAndroid Build Coastguard Worker
2421*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2422*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2423*89c4ff92SAndroid Build Coastguard Worker
2424*89c4ff92SAndroid Build Coastguard Worker // input and output should have the same number of dimensions
2425*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2426*89c4ff92SAndroid Build Coastguard Worker
2427*89c4ff92SAndroid Build Coastguard Worker // there should be entry in the pad list for each dimension in the input tensor
2428*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2429*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2430*89c4ff92SAndroid Build Coastguard Worker "as there are dimensions in the input tensor that is " +
2431*89c4ff92SAndroid Build Coastguard Worker std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2432*89c4ff92SAndroid Build Coastguard Worker " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
2433*89c4ff92SAndroid Build Coastguard Worker }
2434*89c4ff92SAndroid Build Coastguard Worker }
2435*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2436*89c4ff92SAndroid Build Coastguard Worker void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2437*89c4ff92SAndroid Build Coastguard Worker {
2438*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"QuantizeQueueDescriptor"};
2439*89c4ff92SAndroid Build Coastguard Worker
2440*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2441*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2442*89c4ff92SAndroid Build Coastguard Worker
2443*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2444*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2445*89c4ff92SAndroid Build Coastguard Worker
2446*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2447*89c4ff92SAndroid Build Coastguard Worker {
2448*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2449*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2450*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2451*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2452*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2453*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2454*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2455*89c4ff92SAndroid Build Coastguard Worker };
2456*89c4ff92SAndroid Build Coastguard Worker
2457*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2458*89c4ff92SAndroid Build Coastguard Worker
2459*89c4ff92SAndroid Build Coastguard Worker if (!IsQuantizedType(outputTensorInfo.GetDataType()))
2460*89c4ff92SAndroid Build Coastguard Worker {
2461*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
2462*89c4ff92SAndroid Build Coastguard Worker }
2463*89c4ff92SAndroid Build Coastguard Worker }
2464*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2465*89c4ff92SAndroid Build Coastguard Worker void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2466*89c4ff92SAndroid Build Coastguard Worker {
2467*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
2468*89c4ff92SAndroid Build Coastguard Worker
2469*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2470*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2471*89c4ff92SAndroid Build Coastguard Worker
2472*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2473*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2474*89c4ff92SAndroid Build Coastguard Worker
2475*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2476*89c4ff92SAndroid Build Coastguard Worker {
2477*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2478*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2479*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2480*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2481*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2482*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2483*89c4ff92SAndroid Build Coastguard Worker };
2484*89c4ff92SAndroid Build Coastguard Worker
2485*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2486*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2487*89c4ff92SAndroid Build Coastguard Worker }
2488*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2489*89c4ff92SAndroid Build Coastguard Worker void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2490*89c4ff92SAndroid Build Coastguard Worker {
2491*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"StridedSliceQueueDescriptor"};
2492*89c4ff92SAndroid Build Coastguard Worker
2493*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2494*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2495*89c4ff92SAndroid Build Coastguard Worker
2496*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2497*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2498*89c4ff92SAndroid Build Coastguard Worker
2499*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2500*89c4ff92SAndroid Build Coastguard Worker {
2501*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2502*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2503*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2504*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2505*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2506*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2507*89c4ff92SAndroid Build Coastguard Worker };
2508*89c4ff92SAndroid Build Coastguard Worker
2509*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2510*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2511*89c4ff92SAndroid Build Coastguard Worker
2512*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2513*89c4ff92SAndroid Build Coastguard Worker
2514*89c4ff92SAndroid Build Coastguard Worker const uint32_t rank = inputTensorInfo.GetNumDimensions();
2515*89c4ff92SAndroid Build Coastguard Worker if (rank > 4)
2516*89c4ff92SAndroid Build Coastguard Worker {
2517*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2518*89c4ff92SAndroid Build Coastguard Worker }
2519*89c4ff92SAndroid Build Coastguard Worker
2520*89c4ff92SAndroid Build Coastguard Worker // Begin, End & Stride length must be of rank(input0)
2521*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Begin.size() != rank)
2522*89c4ff92SAndroid Build Coastguard Worker {
2523*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
2524*89c4ff92SAndroid Build Coastguard Worker }
2525*89c4ff92SAndroid Build Coastguard Worker
2526*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_End.size() != rank)
2527*89c4ff92SAndroid Build Coastguard Worker {
2528*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
2529*89c4ff92SAndroid Build Coastguard Worker }
2530*89c4ff92SAndroid Build Coastguard Worker
2531*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Stride.size() != rank)
2532*89c4ff92SAndroid Build Coastguard Worker {
2533*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
2534*89c4ff92SAndroid Build Coastguard Worker }
2535*89c4ff92SAndroid Build Coastguard Worker
2536*89c4ff92SAndroid Build Coastguard Worker // Stride entries must be non-zero
2537*89c4ff92SAndroid Build Coastguard Worker for (auto& stride : m_Parameters.m_Stride)
2538*89c4ff92SAndroid Build Coastguard Worker {
2539*89c4ff92SAndroid Build Coastguard Worker if (stride == 0)
2540*89c4ff92SAndroid Build Coastguard Worker {
2541*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
2542*89c4ff92SAndroid Build Coastguard Worker }
2543*89c4ff92SAndroid Build Coastguard Worker }
2544*89c4ff92SAndroid Build Coastguard Worker }
2545*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2546*89c4ff92SAndroid Build Coastguard Worker void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2547*89c4ff92SAndroid Build Coastguard Worker {
2548*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"MinimumQueueDescriptor"};
2549*89c4ff92SAndroid Build Coastguard Worker
2550*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2551*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2552*89c4ff92SAndroid Build Coastguard Worker
2553*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2554*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2555*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2556*89c4ff92SAndroid Build Coastguard Worker
2557*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2558*89c4ff92SAndroid Build Coastguard Worker {
2559*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2560*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2561*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2562*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2563*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2564*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2565*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2566*89c4ff92SAndroid Build Coastguard Worker };
2567*89c4ff92SAndroid Build Coastguard Worker
2568*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2569*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2570*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2571*89c4ff92SAndroid Build Coastguard Worker
2572*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2573*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
2574*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2575*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2576*89c4ff92SAndroid Build Coastguard Worker "input_0",
2577*89c4ff92SAndroid Build Coastguard Worker "input_1");
2578*89c4ff92SAndroid Build Coastguard Worker }
2579*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2580*89c4ff92SAndroid Build Coastguard Worker void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2581*89c4ff92SAndroid Build Coastguard Worker {
2582*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"DebugQueueDescriptor"};
2583*89c4ff92SAndroid Build Coastguard Worker
2584*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2585*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2586*89c4ff92SAndroid Build Coastguard Worker }
2587*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2588*89c4ff92SAndroid Build Coastguard Worker void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2589*89c4ff92SAndroid Build Coastguard Worker {
2590*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"EqualQueueDescriptor"};
2591*89c4ff92SAndroid Build Coastguard Worker
2592*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2593*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2594*89c4ff92SAndroid Build Coastguard Worker
2595*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2596*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2597*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2598*89c4ff92SAndroid Build Coastguard Worker
2599*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2600*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
2601*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2602*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2603*89c4ff92SAndroid Build Coastguard Worker "input_0",
2604*89c4ff92SAndroid Build Coastguard Worker "input_1");
2605*89c4ff92SAndroid Build Coastguard Worker
2606*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Boolean)
2607*89c4ff92SAndroid Build Coastguard Worker {
2608*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2609*89c4ff92SAndroid Build Coastguard Worker }
2610*89c4ff92SAndroid Build Coastguard Worker }
2611*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2612*89c4ff92SAndroid Build Coastguard Worker void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2613*89c4ff92SAndroid Build Coastguard Worker {
2614*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"GreaterQueueDescriptor"};
2615*89c4ff92SAndroid Build Coastguard Worker
2616*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2617*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2618*89c4ff92SAndroid Build Coastguard Worker
2619*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2620*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2621*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2622*89c4ff92SAndroid Build Coastguard Worker
2623*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2624*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
2625*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2626*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2627*89c4ff92SAndroid Build Coastguard Worker "input_0",
2628*89c4ff92SAndroid Build Coastguard Worker "input_1");
2629*89c4ff92SAndroid Build Coastguard Worker
2630*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Boolean)
2631*89c4ff92SAndroid Build Coastguard Worker {
2632*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2633*89c4ff92SAndroid Build Coastguard Worker }
2634*89c4ff92SAndroid Build Coastguard Worker }
2635*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2636*89c4ff92SAndroid Build Coastguard Worker void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2637*89c4ff92SAndroid Build Coastguard Worker {
2638*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"RsqrtQueueDescriptor"};
2639*89c4ff92SAndroid Build Coastguard Worker
2640*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2641*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2642*89c4ff92SAndroid Build Coastguard Worker
2643*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2644*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2645*89c4ff92SAndroid Build Coastguard Worker
2646*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2647*89c4ff92SAndroid Build Coastguard Worker
2648*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2649*89c4ff92SAndroid Build Coastguard Worker {
2650*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2651*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2652*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2653*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2654*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2655*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2656*89c4ff92SAndroid Build Coastguard Worker };
2657*89c4ff92SAndroid Build Coastguard Worker
2658*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2659*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2660*89c4ff92SAndroid Build Coastguard Worker }
2661*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2662*89c4ff92SAndroid Build Coastguard Worker void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2663*89c4ff92SAndroid Build Coastguard Worker {
2664*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"GatherNdQueueDescriptor"};
2665*89c4ff92SAndroid Build Coastguard Worker
2666*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2667*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2668*89c4ff92SAndroid Build Coastguard Worker
2669*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2670*89c4ff92SAndroid Build Coastguard Worker if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2671*89c4ff92SAndroid Build Coastguard Worker {
2672*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2673*89c4ff92SAndroid Build Coastguard Worker }
2674*89c4ff92SAndroid Build Coastguard Worker
2675*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2676*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2677*89c4ff92SAndroid Build Coastguard Worker
2678*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2679*89c4ff92SAndroid Build Coastguard Worker {
2680*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2681*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2682*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2683*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2684*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2685*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2686*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
2687*89c4ff92SAndroid Build Coastguard Worker };
2688*89c4ff92SAndroid Build Coastguard Worker
2689*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2690*89c4ff92SAndroid Build Coastguard Worker
2691*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2692*89c4ff92SAndroid Build Coastguard Worker
2693*89c4ff92SAndroid Build Coastguard Worker unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2694*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2695*89c4ff92SAndroid Build Coastguard Worker }
2696*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2697*89c4ff92SAndroid Build Coastguard Worker void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2698*89c4ff92SAndroid Build Coastguard Worker {
2699*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"GatherQueueDescriptor"};
2700*89c4ff92SAndroid Build Coastguard Worker
2701*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2702*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2703*89c4ff92SAndroid Build Coastguard Worker
2704*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2705*89c4ff92SAndroid Build Coastguard Worker if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2706*89c4ff92SAndroid Build Coastguard Worker {
2707*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2708*89c4ff92SAndroid Build Coastguard Worker }
2709*89c4ff92SAndroid Build Coastguard Worker
2710*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2711*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2712*89c4ff92SAndroid Build Coastguard Worker
2713*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2714*89c4ff92SAndroid Build Coastguard Worker {
2715*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2716*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2717*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2718*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2719*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2720*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2721*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32,
2722*89c4ff92SAndroid Build Coastguard Worker };
2723*89c4ff92SAndroid Build Coastguard Worker
2724*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2725*89c4ff92SAndroid Build Coastguard Worker
2726*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2727*89c4ff92SAndroid Build Coastguard Worker
2728*89c4ff92SAndroid Build Coastguard Worker unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2729*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2730*89c4ff92SAndroid Build Coastguard Worker }
2731*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2732*89c4ff92SAndroid Build Coastguard Worker void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2733*89c4ff92SAndroid Build Coastguard Worker {
2734*89c4ff92SAndroid Build Coastguard Worker const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2735*89c4ff92SAndroid Build Coastguard Worker
2736*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2737*89c4ff92SAndroid Build Coastguard Worker
2738*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() != 4)
2739*89c4ff92SAndroid Build Coastguard Worker {
2740*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
2741*89c4ff92SAndroid Build Coastguard Worker to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2742*89c4ff92SAndroid Build Coastguard Worker }
2743*89c4ff92SAndroid Build Coastguard Worker
2744*89c4ff92SAndroid Build Coastguard Worker if (m_Anchors == nullptr)
2745*89c4ff92SAndroid Build Coastguard Worker {
2746*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
2747*89c4ff92SAndroid Build Coastguard Worker }
2748*89c4ff92SAndroid Build Coastguard Worker
2749*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
2750*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2751*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2752*89c4ff92SAndroid Build Coastguard Worker
2753*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
2754*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
2755*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2756*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
2757*89c4ff92SAndroid Build Coastguard Worker
2758*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2759*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2760*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
2761*89c4ff92SAndroid Build Coastguard Worker
2762*89c4ff92SAndroid Build Coastguard Worker const std::vector<DataType> supportedInputTypes =
2763*89c4ff92SAndroid Build Coastguard Worker {
2764*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2765*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2766*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2767*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2768*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2769*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2770*89c4ff92SAndroid Build Coastguard Worker };
2771*89c4ff92SAndroid Build Coastguard Worker
2772*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2773*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2774*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2775*89c4ff92SAndroid Build Coastguard Worker
2776*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2777*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2778*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2779*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2780*89c4ff92SAndroid Build Coastguard Worker
2781*89c4ff92SAndroid Build Coastguard Worker // NOTE: Output is always Float32 regardless of input type
2782*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2783*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2784*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2785*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
2786*89c4ff92SAndroid Build Coastguard Worker
2787*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2788*89c4ff92SAndroid Build Coastguard Worker {
2789*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
2790*89c4ff92SAndroid Build Coastguard Worker "must be positive and less than or equal to 1.");
2791*89c4ff92SAndroid Build Coastguard Worker }
2792*89c4ff92SAndroid Build Coastguard Worker
2793*89c4ff92SAndroid Build Coastguard Worker if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2794*89c4ff92SAndroid Build Coastguard Worker {
2795*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Number of classes with background "
2796*89c4ff92SAndroid Build Coastguard Worker "should be equal to number of classes + 1.");
2797*89c4ff92SAndroid Build Coastguard Worker }
2798*89c4ff92SAndroid Build Coastguard Worker }
2799*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2800*89c4ff92SAndroid Build Coastguard Worker void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2801*89c4ff92SAndroid Build Coastguard Worker {
2802*89c4ff92SAndroid Build Coastguard Worker const std::string& descriptorName{"DequantizeQueueDescriptor"};
2803*89c4ff92SAndroid Build Coastguard Worker
2804*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2805*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2806*89c4ff92SAndroid Build Coastguard Worker
2807*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2808*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2809*89c4ff92SAndroid Build Coastguard Worker
2810*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> inputSupportedTypes =
2811*89c4ff92SAndroid Build Coastguard Worker {
2812*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2813*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2814*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2815*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2816*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
2817*89c4ff92SAndroid Build Coastguard Worker };
2818*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
2819*89c4ff92SAndroid Build Coastguard Worker
2820*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> outputSupportedTypes =
2821*89c4ff92SAndroid Build Coastguard Worker {
2822*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2823*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2824*89c4ff92SAndroid Build Coastguard Worker DataType::Float16
2825*89c4ff92SAndroid Build Coastguard Worker };
2826*89c4ff92SAndroid Build Coastguard Worker
2827*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
2828*89c4ff92SAndroid Build Coastguard Worker }
2829*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2830*89c4ff92SAndroid Build Coastguard Worker void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2831*89c4ff92SAndroid Build Coastguard Worker {
2832*89c4ff92SAndroid Build Coastguard Worker const std::string& descriptorName{"MergeQueueDescriptor"};
2833*89c4ff92SAndroid Build Coastguard Worker
2834*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2835*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2836*89c4ff92SAndroid Build Coastguard Worker
2837*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2838*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2839*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2840*89c4ff92SAndroid Build Coastguard Worker
2841*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2842*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2843*89c4ff92SAndroid Build Coastguard Worker
2844*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2845*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2846*89c4ff92SAndroid Build Coastguard Worker }
2847*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2848*89c4ff92SAndroid Build Coastguard Worker void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2849*89c4ff92SAndroid Build Coastguard Worker {
2850*89c4ff92SAndroid Build Coastguard Worker const std::string& descriptorName{"ShapeQueueDescriptor"};
2851*89c4ff92SAndroid Build Coastguard Worker
2852*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2853*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2854*89c4ff92SAndroid Build Coastguard Worker
2855*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2856*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2857*89c4ff92SAndroid Build Coastguard Worker
2858*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2859*89c4ff92SAndroid Build Coastguard Worker {
2860*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2861*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2862*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2863*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2864*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2865*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2866*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
2867*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
2868*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
2869*89c4ff92SAndroid Build Coastguard Worker };
2870*89c4ff92SAndroid Build Coastguard Worker
2871*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2872*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2873*89c4ff92SAndroid Build Coastguard Worker }
2874*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2875*89c4ff92SAndroid Build Coastguard Worker void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2876*89c4ff92SAndroid Build Coastguard Worker {
2877*89c4ff92SAndroid Build Coastguard Worker const std::string& descriptorName{"SwitchQueueDescriptor"};
2878*89c4ff92SAndroid Build Coastguard Worker
2879*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2880*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 2);
2881*89c4ff92SAndroid Build Coastguard Worker
2882*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2883*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2884*89c4ff92SAndroid Build Coastguard Worker
2885*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2886*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2887*89c4ff92SAndroid Build Coastguard Worker
2888*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2889*89c4ff92SAndroid Build Coastguard Worker {
2890*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2891*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2892*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2893*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2894*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2895*89c4ff92SAndroid Build Coastguard Worker };
2896*89c4ff92SAndroid Build Coastguard Worker
2897*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2898*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2899*89c4ff92SAndroid Build Coastguard Worker
2900*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2901*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
2902*89c4ff92SAndroid Build Coastguard Worker
2903*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo0,
2904*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo0,
2905*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2906*89c4ff92SAndroid Build Coastguard Worker "input_0",
2907*89c4ff92SAndroid Build Coastguard Worker "output_0");
2908*89c4ff92SAndroid Build Coastguard Worker
2909*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo0,
2910*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo1,
2911*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2912*89c4ff92SAndroid Build Coastguard Worker "input_0",
2913*89c4ff92SAndroid Build Coastguard Worker "output_1");
2914*89c4ff92SAndroid Build Coastguard Worker }
2915*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo &) const2916*89c4ff92SAndroid Build Coastguard Worker void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
2917*89c4ff92SAndroid Build Coastguard Worker {
2918*89c4ff92SAndroid Build Coastguard Worker // This is internally generated so it should not need validation.
2919*89c4ff92SAndroid Build Coastguard Worker }
2920*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2921*89c4ff92SAndroid Build Coastguard Worker void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2922*89c4ff92SAndroid Build Coastguard Worker {
2923*89c4ff92SAndroid Build Coastguard Worker const std::string& descriptorName{"PreluQueueDescriptor"};
2924*89c4ff92SAndroid Build Coastguard Worker
2925*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
2926*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2927*89c4ff92SAndroid Build Coastguard Worker
2928*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2929*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2930*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2931*89c4ff92SAndroid Build Coastguard Worker
2932*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes
2933*89c4ff92SAndroid Build Coastguard Worker {
2934*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2935*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
2936*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
2937*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
2938*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
2939*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
2940*89c4ff92SAndroid Build Coastguard Worker };
2941*89c4ff92SAndroid Build Coastguard Worker
2942*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2943*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
2944*89c4ff92SAndroid Build Coastguard Worker
2945*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2946*89c4ff92SAndroid Build Coastguard Worker
2947*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2948*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
2949*89c4ff92SAndroid Build Coastguard Worker
2950*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2951*89c4ff92SAndroid Build Coastguard Worker alphaTensorInfo,
2952*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2953*89c4ff92SAndroid Build Coastguard Worker descriptorName,
2954*89c4ff92SAndroid Build Coastguard Worker "input",
2955*89c4ff92SAndroid Build Coastguard Worker "alpha");
2956*89c4ff92SAndroid Build Coastguard Worker }
2957*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const2958*89c4ff92SAndroid Build Coastguard Worker void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2959*89c4ff92SAndroid Build Coastguard Worker {
2960*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2961*89c4ff92SAndroid Build Coastguard Worker
2962*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
2963*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
2964*89c4ff92SAndroid Build Coastguard Worker
2965*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2966*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2967*89c4ff92SAndroid Build Coastguard Worker
2968*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2969*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2970*89c4ff92SAndroid Build Coastguard Worker
2971*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_Weight, descriptorName, "weight");
2972*89c4ff92SAndroid Build Coastguard Worker
2973*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2974*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2975*89c4ff92SAndroid Build Coastguard Worker
2976*89c4ff92SAndroid Build Coastguard Worker ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2977*89c4ff92SAndroid Build Coastguard Worker
2978*89c4ff92SAndroid Build Coastguard Worker Optional<TensorInfo> optionalBiasTensorInfo;
2979*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BiasEnabled)
2980*89c4ff92SAndroid Build Coastguard Worker {
2981*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_Bias, descriptorName, "bias");
2982*89c4ff92SAndroid Build Coastguard Worker
2983*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2984*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
2985*89c4ff92SAndroid Build Coastguard Worker
2986*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
2987*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
2988*89c4ff92SAndroid Build Coastguard Worker }
2989*89c4ff92SAndroid Build Coastguard Worker
2990*89c4ff92SAndroid Build Coastguard Worker ValidatePerAxisQuantization(inputTensorInfo,
2991*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
2992*89c4ff92SAndroid Build Coastguard Worker weightTensorInfo,
2993*89c4ff92SAndroid Build Coastguard Worker optionalBiasTensorInfo,
2994*89c4ff92SAndroid Build Coastguard Worker descriptorName);
2995*89c4ff92SAndroid Build Coastguard Worker
2996*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
2997*89c4ff92SAndroid Build Coastguard Worker {
2998*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
2999*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3000*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3001*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3002*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3003*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
3004*89c4ff92SAndroid Build Coastguard Worker };
3005*89c4ff92SAndroid Build Coastguard Worker
3006*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3007*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3008*89c4ff92SAndroid Build Coastguard Worker }
3009*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3010*89c4ff92SAndroid Build Coastguard Worker void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3011*89c4ff92SAndroid Build Coastguard Worker {
3012*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"TransposeQueueDescriptor"};
3013*89c4ff92SAndroid Build Coastguard Worker
3014*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3015*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3016*89c4ff92SAndroid Build Coastguard Worker
3017*89c4ff92SAndroid Build Coastguard Worker const PermutationVector& mapping = m_Parameters.m_DimMappings;
3018*89c4ff92SAndroid Build Coastguard Worker
3019*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3020*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3021*89c4ff92SAndroid Build Coastguard Worker
3022*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3023*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3024*89c4ff92SAndroid Build Coastguard Worker
3025*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3026*89c4ff92SAndroid Build Coastguard Worker {
3027*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3028*89c4ff92SAndroid Build Coastguard Worker {
3029*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3030*89c4ff92SAndroid Build Coastguard Worker " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3031*89c4ff92SAndroid Build Coastguard Worker "must match dst dimension " + to_string(i) +
3032*89c4ff92SAndroid Build Coastguard Worker " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3033*89c4ff92SAndroid Build Coastguard Worker }
3034*89c4ff92SAndroid Build Coastguard Worker }
3035*89c4ff92SAndroid Build Coastguard Worker
3036*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3037*89c4ff92SAndroid Build Coastguard Worker }
3038*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3039*89c4ff92SAndroid Build Coastguard Worker void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3040*89c4ff92SAndroid Build Coastguard Worker {
3041*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"TransposeQueueDescriptor"};
3042*89c4ff92SAndroid Build Coastguard Worker
3043*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3044*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3045*89c4ff92SAndroid Build Coastguard Worker
3046*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3047*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3048*89c4ff92SAndroid Build Coastguard Worker
3049*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3050*89c4ff92SAndroid Build Coastguard Worker }
3051*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3052*89c4ff92SAndroid Build Coastguard Worker void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3053*89c4ff92SAndroid Build Coastguard Worker {
3054*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"QLstmQueueDescriptor"};
3055*89c4ff92SAndroid Build Coastguard Worker
3056*89c4ff92SAndroid Build Coastguard Worker // Validate number of inputs/outputs
3057*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 3);
3058*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 3);
3059*89c4ff92SAndroid Build Coastguard Worker
3060*89c4ff92SAndroid Build Coastguard Worker // Input/output tensor info
3061*89c4ff92SAndroid Build Coastguard Worker auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3062*89c4ff92SAndroid Build Coastguard Worker auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3063*89c4ff92SAndroid Build Coastguard Worker auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3064*89c4ff92SAndroid Build Coastguard Worker
3065*89c4ff92SAndroid Build Coastguard Worker auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3066*89c4ff92SAndroid Build Coastguard Worker auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3067*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3068*89c4ff92SAndroid Build Coastguard Worker
3069*89c4ff92SAndroid Build Coastguard Worker // Supported types for various tensors in QLSTM
3070*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> inputOutputSupportedTypes =
3071*89c4ff92SAndroid Build Coastguard Worker {
3072*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8
3073*89c4ff92SAndroid Build Coastguard Worker };
3074*89c4ff92SAndroid Build Coastguard Worker
3075*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> cellStateSupportedTypes =
3076*89c4ff92SAndroid Build Coastguard Worker {
3077*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
3078*89c4ff92SAndroid Build Coastguard Worker };
3079*89c4ff92SAndroid Build Coastguard Worker
3080*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> weightsSupportedTypes =
3081*89c4ff92SAndroid Build Coastguard Worker {
3082*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8
3083*89c4ff92SAndroid Build Coastguard Worker };
3084*89c4ff92SAndroid Build Coastguard Worker
3085*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3086*89c4ff92SAndroid Build Coastguard Worker {
3087*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
3088*89c4ff92SAndroid Build Coastguard Worker };
3089*89c4ff92SAndroid Build Coastguard Worker
3090*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> biasSupportedTypes =
3091*89c4ff92SAndroid Build Coastguard Worker {
3092*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3093*89c4ff92SAndroid Build Coastguard Worker };
3094*89c4ff92SAndroid Build Coastguard Worker
3095*89c4ff92SAndroid Build Coastguard Worker // Validate types of input/output tensors
3096*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3097*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3098*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3099*89c4ff92SAndroid Build Coastguard Worker
3100*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3101*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3102*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3103*89c4ff92SAndroid Build Coastguard Worker
3104*89c4ff92SAndroid Build Coastguard Worker // Validate matching types of input/output tensors
3105*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3106*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3107*89c4ff92SAndroid Build Coastguard Worker "outputStateIn", "outputStateOut");
3108*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3109*89c4ff92SAndroid Build Coastguard Worker
3110*89c4ff92SAndroid Build Coastguard Worker // Infer number of batches, number of units, input size and output size from tensor dimensions
3111*89c4ff92SAndroid Build Coastguard Worker const uint32_t numBatches = inputInfo.GetShape()[0];
3112*89c4ff92SAndroid Build Coastguard Worker const uint32_t inputSize = inputInfo.GetShape()[1];
3113*89c4ff92SAndroid Build Coastguard Worker const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3114*89c4ff92SAndroid Build Coastguard Worker const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3115*89c4ff92SAndroid Build Coastguard Worker
3116*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements for input/output tensors
3117*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3118*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3119*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3120*89c4ff92SAndroid Build Coastguard Worker
3121*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3122*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3123*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3124*89c4ff92SAndroid Build Coastguard Worker
3125*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements for MANDATORY weight tensors
3126*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3127*89c4ff92SAndroid Build Coastguard Worker auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3128*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3129*89c4ff92SAndroid Build Coastguard Worker
3130*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3131*89c4ff92SAndroid Build Coastguard Worker auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3132*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3133*89c4ff92SAndroid Build Coastguard Worker
3134*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3135*89c4ff92SAndroid Build Coastguard Worker auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3136*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3137*89c4ff92SAndroid Build Coastguard Worker
3138*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3139*89c4ff92SAndroid Build Coastguard Worker auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3140*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3141*89c4ff92SAndroid Build Coastguard Worker " RecurrentToForgetWeights");
3142*89c4ff92SAndroid Build Coastguard Worker
3143*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3144*89c4ff92SAndroid Build Coastguard Worker auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3145*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3146*89c4ff92SAndroid Build Coastguard Worker
3147*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3148*89c4ff92SAndroid Build Coastguard Worker auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3149*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3150*89c4ff92SAndroid Build Coastguard Worker
3151*89c4ff92SAndroid Build Coastguard Worker // Validate data types for MANDATORY weights tensors (all should match each other)
3152*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3153*89c4ff92SAndroid Build Coastguard Worker
3154*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3155*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "inputToCellWeights");
3156*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3157*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "inputToOutputWeights");
3158*89c4ff92SAndroid Build Coastguard Worker
3159*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3160*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "recurrentToForgeteights");
3161*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3162*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "recurrentToCellWeights");
3163*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3164*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "recurrentToOutputWeights");
3165*89c4ff92SAndroid Build Coastguard Worker
3166*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements for MANDATORY bias tensors
3167*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3168*89c4ff92SAndroid Build Coastguard Worker auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3169*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3170*89c4ff92SAndroid Build Coastguard Worker
3171*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_CellBias, descriptorName, "CellBias");
3172*89c4ff92SAndroid Build Coastguard Worker auto cellBiasInfo = m_CellBias->GetTensorInfo();
3173*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3174*89c4ff92SAndroid Build Coastguard Worker
3175*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3176*89c4ff92SAndroid Build Coastguard Worker auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3177*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3178*89c4ff92SAndroid Build Coastguard Worker
3179*89c4ff92SAndroid Build Coastguard Worker // Validate data types for MANDATORY bias tensors
3180*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3181*89c4ff92SAndroid Build Coastguard Worker
3182*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3183*89c4ff92SAndroid Build Coastguard Worker "forgetGateBias", "cellBias");
3184*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3185*89c4ff92SAndroid Build Coastguard Worker "forgetGateBias", "outputGateBias");
3186*89c4ff92SAndroid Build Coastguard Worker
3187*89c4ff92SAndroid Build Coastguard Worker // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3188*89c4ff92SAndroid Build Coastguard Worker const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3189*89c4ff92SAndroid Build Coastguard Worker !m_Parameters.m_CifgEnabled) ||
3190*89c4ff92SAndroid Build Coastguard Worker (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3191*89c4ff92SAndroid Build Coastguard Worker !m_InputGateBias && m_Parameters.m_CifgEnabled));
3192*89c4ff92SAndroid Build Coastguard Worker
3193*89c4ff92SAndroid Build Coastguard Worker if (!allCifgParamsPresentOrNot)
3194*89c4ff92SAndroid Build Coastguard Worker {
3195*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3196*89c4ff92SAndroid Build Coastguard Worker ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3197*89c4ff92SAndroid Build Coastguard Worker "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3198*89c4ff92SAndroid Build Coastguard Worker "set appropriately.");
3199*89c4ff92SAndroid Build Coastguard Worker }
3200*89c4ff92SAndroid Build Coastguard Worker
3201*89c4ff92SAndroid Build Coastguard Worker if (!m_Parameters.m_CifgEnabled)
3202*89c4ff92SAndroid Build Coastguard Worker {
3203*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements
3204*89c4ff92SAndroid Build Coastguard Worker auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3205*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3206*89c4ff92SAndroid Build Coastguard Worker
3207*89c4ff92SAndroid Build Coastguard Worker auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3208*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3209*89c4ff92SAndroid Build Coastguard Worker " RecurrentToInputWeights");
3210*89c4ff92SAndroid Build Coastguard Worker
3211*89c4ff92SAndroid Build Coastguard Worker auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3212*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3213*89c4ff92SAndroid Build Coastguard Worker
3214*89c4ff92SAndroid Build Coastguard Worker // Validate data types
3215*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3216*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "inputToInputWeights");
3217*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3218*89c4ff92SAndroid Build Coastguard Worker "inputToForgetWeights", "recurrentToInputWeights");
3219*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3220*89c4ff92SAndroid Build Coastguard Worker "forgetGateBias", "inputGateBias");
3221*89c4ff92SAndroid Build Coastguard Worker }
3222*89c4ff92SAndroid Build Coastguard Worker
3223*89c4ff92SAndroid Build Coastguard Worker // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3224*89c4ff92SAndroid Build Coastguard Worker bool allPeepholeWeightsPresentOrNot =
3225*89c4ff92SAndroid Build Coastguard Worker (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3226*89c4ff92SAndroid Build Coastguard Worker && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3227*89c4ff92SAndroid Build Coastguard Worker || (!m_CellToInputWeights && !m_CellToForgetWeights
3228*89c4ff92SAndroid Build Coastguard Worker && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3229*89c4ff92SAndroid Build Coastguard Worker
3230*89c4ff92SAndroid Build Coastguard Worker if (!allPeepholeWeightsPresentOrNot)
3231*89c4ff92SAndroid Build Coastguard Worker {
3232*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3233*89c4ff92SAndroid Build Coastguard Worker ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3234*89c4ff92SAndroid Build Coastguard Worker "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3235*89c4ff92SAndroid Build Coastguard Worker "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3236*89c4ff92SAndroid Build Coastguard Worker "appropriately.");
3237*89c4ff92SAndroid Build Coastguard Worker }
3238*89c4ff92SAndroid Build Coastguard Worker
3239*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_PeepholeEnabled)
3240*89c4ff92SAndroid Build Coastguard Worker {
3241*89c4ff92SAndroid Build Coastguard Worker auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3242*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3243*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3244*89c4ff92SAndroid Build Coastguard Worker
3245*89c4ff92SAndroid Build Coastguard Worker auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3246*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3247*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3248*89c4ff92SAndroid Build Coastguard Worker "cellToForgetWeight", "cellToOutputWeights");
3249*89c4ff92SAndroid Build Coastguard Worker
3250*89c4ff92SAndroid Build Coastguard Worker if (!m_Parameters.m_CifgEnabled)
3251*89c4ff92SAndroid Build Coastguard Worker {
3252*89c4ff92SAndroid Build Coastguard Worker auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3253*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3254*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3255*89c4ff92SAndroid Build Coastguard Worker "cellToForgetWeights", "cellToInputWeights");
3256*89c4ff92SAndroid Build Coastguard Worker }
3257*89c4ff92SAndroid Build Coastguard Worker }
3258*89c4ff92SAndroid Build Coastguard Worker
3259*89c4ff92SAndroid Build Coastguard Worker // Validate OPTIONAL params: Layer Norm Weights
3260*89c4ff92SAndroid Build Coastguard Worker bool allLayerNormWeightsPresentOrNot =
3261*89c4ff92SAndroid Build Coastguard Worker (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3262*89c4ff92SAndroid Build Coastguard Worker && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3263*89c4ff92SAndroid Build Coastguard Worker || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3264*89c4ff92SAndroid Build Coastguard Worker && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3265*89c4ff92SAndroid Build Coastguard Worker
3266*89c4ff92SAndroid Build Coastguard Worker if (!allLayerNormWeightsPresentOrNot)
3267*89c4ff92SAndroid Build Coastguard Worker {
3268*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3269*89c4ff92SAndroid Build Coastguard Worker ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3270*89c4ff92SAndroid Build Coastguard Worker "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3271*89c4ff92SAndroid Build Coastguard Worker "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3272*89c4ff92SAndroid Build Coastguard Worker "only be present when Layer Norm is enabled and CIFG is disabled. "
3273*89c4ff92SAndroid Build Coastguard Worker "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3274*89c4ff92SAndroid Build Coastguard Worker }
3275*89c4ff92SAndroid Build Coastguard Worker
3276*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_LayerNormEnabled)
3277*89c4ff92SAndroid Build Coastguard Worker {
3278*89c4ff92SAndroid Build Coastguard Worker auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3279*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3280*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3281*89c4ff92SAndroid Build Coastguard Worker
3282*89c4ff92SAndroid Build Coastguard Worker auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3283*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3284*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3285*89c4ff92SAndroid Build Coastguard Worker "forgetLayerNormWeights", "cellLayerNormWeights");
3286*89c4ff92SAndroid Build Coastguard Worker
3287*89c4ff92SAndroid Build Coastguard Worker auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3288*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3289*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3290*89c4ff92SAndroid Build Coastguard Worker "forgetLayerNormWeights", "outputLayerNormWeights");
3291*89c4ff92SAndroid Build Coastguard Worker
3292*89c4ff92SAndroid Build Coastguard Worker if (!m_Parameters.m_CifgEnabled)
3293*89c4ff92SAndroid Build Coastguard Worker {
3294*89c4ff92SAndroid Build Coastguard Worker auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3295*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3296*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3297*89c4ff92SAndroid Build Coastguard Worker "forgetLayerNormWeights", "inputLayerNormWeights");
3298*89c4ff92SAndroid Build Coastguard Worker }
3299*89c4ff92SAndroid Build Coastguard Worker }
3300*89c4ff92SAndroid Build Coastguard Worker
3301*89c4ff92SAndroid Build Coastguard Worker // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3302*89c4ff92SAndroid Build Coastguard Worker bool correctProjectionTensorsPresent =
3303*89c4ff92SAndroid Build Coastguard Worker ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3304*89c4ff92SAndroid Build Coastguard Worker (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3305*89c4ff92SAndroid Build Coastguard Worker (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3306*89c4ff92SAndroid Build Coastguard Worker
3307*89c4ff92SAndroid Build Coastguard Worker if (!correctProjectionTensorsPresent)
3308*89c4ff92SAndroid Build Coastguard Worker {
3309*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3310*89c4ff92SAndroid Build Coastguard Worker ": If projection is enabled, ProjectionWeights should be present and "
3311*89c4ff92SAndroid Build Coastguard Worker "ProjectionBias is optional. If projection is disabled, neither "
3312*89c4ff92SAndroid Build Coastguard Worker "ProjectionWeights nor ProjectionBias should be present.");
3313*89c4ff92SAndroid Build Coastguard Worker }
3314*89c4ff92SAndroid Build Coastguard Worker
3315*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_ProjectionEnabled)
3316*89c4ff92SAndroid Build Coastguard Worker {
3317*89c4ff92SAndroid Build Coastguard Worker auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3318*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3319*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3320*89c4ff92SAndroid Build Coastguard Worker
3321*89c4ff92SAndroid Build Coastguard Worker if (m_ProjectionBias)
3322*89c4ff92SAndroid Build Coastguard Worker {
3323*89c4ff92SAndroid Build Coastguard Worker auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
3324*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
3325*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3326*89c4ff92SAndroid Build Coastguard Worker }
3327*89c4ff92SAndroid Build Coastguard Worker
3328*89c4ff92SAndroid Build Coastguard Worker }
3329*89c4ff92SAndroid Build Coastguard Worker else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3330*89c4ff92SAndroid Build Coastguard Worker outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3331*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3332*89c4ff92SAndroid Build Coastguard Worker ": If projection is disabled, output quantization info (scale, offset) "
3333*89c4ff92SAndroid Build Coastguard Worker "should match HiddenStateScale and HiddenStateZeroPoint.");
3334*89c4ff92SAndroid Build Coastguard Worker }
3335*89c4ff92SAndroid Build Coastguard Worker
3336*89c4ff92SAndroid Build Coastguard Worker }
3337*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3338*89c4ff92SAndroid Build Coastguard Worker void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3339*89c4ff92SAndroid Build Coastguard Worker {
3340*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3341*89c4ff92SAndroid Build Coastguard Worker
3342*89c4ff92SAndroid Build Coastguard Worker // Validate number of inputs/outputs
3343*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 3);
3344*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 2);
3345*89c4ff92SAndroid Build Coastguard Worker
3346*89c4ff92SAndroid Build Coastguard Worker // Input/output tensor infos
3347*89c4ff92SAndroid Build Coastguard Worker auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3348*89c4ff92SAndroid Build Coastguard Worker auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3349*89c4ff92SAndroid Build Coastguard Worker auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3350*89c4ff92SAndroid Build Coastguard Worker
3351*89c4ff92SAndroid Build Coastguard Worker auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3352*89c4ff92SAndroid Build Coastguard Worker auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3353*89c4ff92SAndroid Build Coastguard Worker
3354*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> inputOutputSupportedTypes =
3355*89c4ff92SAndroid Build Coastguard Worker {
3356*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8
3357*89c4ff92SAndroid Build Coastguard Worker };
3358*89c4ff92SAndroid Build Coastguard Worker
3359*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> cellStateSupportedTypes =
3360*89c4ff92SAndroid Build Coastguard Worker {
3361*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
3362*89c4ff92SAndroid Build Coastguard Worker };
3363*89c4ff92SAndroid Build Coastguard Worker
3364*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> weightsSupportedTypes =
3365*89c4ff92SAndroid Build Coastguard Worker {
3366*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8
3367*89c4ff92SAndroid Build Coastguard Worker };
3368*89c4ff92SAndroid Build Coastguard Worker
3369*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> biasSupportedTypes =
3370*89c4ff92SAndroid Build Coastguard Worker {
3371*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3372*89c4ff92SAndroid Build Coastguard Worker };
3373*89c4ff92SAndroid Build Coastguard Worker
3374*89c4ff92SAndroid Build Coastguard Worker // Validate types of input/output tensors
3375*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3376*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3377*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3378*89c4ff92SAndroid Build Coastguard Worker
3379*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3380*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3381*89c4ff92SAndroid Build Coastguard Worker
3382*89c4ff92SAndroid Build Coastguard Worker // Validate matching types of input/output tensors
3383*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3384*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3385*89c4ff92SAndroid Build Coastguard Worker "outputStateIn", "outputStateOut");
3386*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3387*89c4ff92SAndroid Build Coastguard Worker
3388*89c4ff92SAndroid Build Coastguard Worker // Validate matching quantization info for input/output tensors
3389*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3390*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3391*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3392*89c4ff92SAndroid Build Coastguard Worker
3393*89c4ff92SAndroid Build Coastguard Worker // Infer number of batches, input size and output size from tensor dimensions
3394*89c4ff92SAndroid Build Coastguard Worker const uint32_t numBatches = inputInfo.GetShape()[0];
3395*89c4ff92SAndroid Build Coastguard Worker const uint32_t inputSize = inputInfo.GetShape()[1];
3396*89c4ff92SAndroid Build Coastguard Worker const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3397*89c4ff92SAndroid Build Coastguard Worker
3398*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements for input/output tensors
3399*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3400*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3401*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3402*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3403*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3404*89c4ff92SAndroid Build Coastguard Worker
3405*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements for weights tensors
3406*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3407*89c4ff92SAndroid Build Coastguard Worker auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3408*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3409*89c4ff92SAndroid Build Coastguard Worker
3410*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3411*89c4ff92SAndroid Build Coastguard Worker auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3412*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3413*89c4ff92SAndroid Build Coastguard Worker
3414*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3415*89c4ff92SAndroid Build Coastguard Worker auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3416*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3417*89c4ff92SAndroid Build Coastguard Worker
3418*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3419*89c4ff92SAndroid Build Coastguard Worker auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3420*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3421*89c4ff92SAndroid Build Coastguard Worker
3422*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3423*89c4ff92SAndroid Build Coastguard Worker auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3424*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3425*89c4ff92SAndroid Build Coastguard Worker
3426*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3427*89c4ff92SAndroid Build Coastguard Worker auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3428*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3429*89c4ff92SAndroid Build Coastguard Worker " RecurrentToForgetWeights");
3430*89c4ff92SAndroid Build Coastguard Worker
3431*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3432*89c4ff92SAndroid Build Coastguard Worker auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3433*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3434*89c4ff92SAndroid Build Coastguard Worker
3435*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3436*89c4ff92SAndroid Build Coastguard Worker auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3437*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3438*89c4ff92SAndroid Build Coastguard Worker
3439*89c4ff92SAndroid Build Coastguard Worker // Validate data types for weights tensors (all should match each other)
3440*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3441*89c4ff92SAndroid Build Coastguard Worker
3442*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3443*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "inputToForgetWeights");
3444*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3445*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "inputToCellWeights");
3446*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3447*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "inputToOutputWeights");
3448*89c4ff92SAndroid Build Coastguard Worker
3449*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3450*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "recurrentToInputWeights");
3451*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3452*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "recurrentToForgeteights");
3453*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3454*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "recurrentToCellWeights");
3455*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3456*89c4ff92SAndroid Build Coastguard Worker "inputToInputWeights", "recurrentToOutputWeights");
3457*89c4ff92SAndroid Build Coastguard Worker
3458*89c4ff92SAndroid Build Coastguard Worker // Validate matching quantization info for weight tensors (all should match each other)
3459*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3460*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "inputToForgetWeights");
3461*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3462*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "inputToCellWeights");
3463*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3464*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "inputToOutputWeights");
3465*89c4ff92SAndroid Build Coastguard Worker
3466*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3467*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3468*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3469*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3470*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3471*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3472*89c4ff92SAndroid Build Coastguard Worker ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3473*89c4ff92SAndroid Build Coastguard Worker descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3474*89c4ff92SAndroid Build Coastguard Worker
3475*89c4ff92SAndroid Build Coastguard Worker // Validate number of dimensions and number of elements in bias tensors
3476*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3477*89c4ff92SAndroid Build Coastguard Worker auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3478*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3479*89c4ff92SAndroid Build Coastguard Worker
3480*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3481*89c4ff92SAndroid Build Coastguard Worker auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3482*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3483*89c4ff92SAndroid Build Coastguard Worker
3484*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_CellBias, descriptorName, "CellBias");
3485*89c4ff92SAndroid Build Coastguard Worker auto cellBiasInfo = m_CellBias->GetTensorInfo();
3486*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3487*89c4ff92SAndroid Build Coastguard Worker
3488*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3489*89c4ff92SAndroid Build Coastguard Worker auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3490*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3491*89c4ff92SAndroid Build Coastguard Worker
3492*89c4ff92SAndroid Build Coastguard Worker // Validate data types for bias tensors (all should match each other)
3493*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3494*89c4ff92SAndroid Build Coastguard Worker
3495*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3496*89c4ff92SAndroid Build Coastguard Worker "inputGateBias", "forgetGateBias");
3497*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3498*89c4ff92SAndroid Build Coastguard Worker "inputGateBias", "cellBias");
3499*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3500*89c4ff92SAndroid Build Coastguard Worker "inputGateBias", "outputGateBias");
3501*89c4ff92SAndroid Build Coastguard Worker
3502*89c4ff92SAndroid Build Coastguard Worker // Validate bias tensor quantization info
3503*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3504*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3505*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3506*89c4ff92SAndroid Build Coastguard Worker ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3507*89c4ff92SAndroid Build Coastguard Worker }
3508*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3509*89c4ff92SAndroid Build Coastguard Worker void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3510*89c4ff92SAndroid Build Coastguard Worker {
3511*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"AbsQueueDescriptor"};
3512*89c4ff92SAndroid Build Coastguard Worker
3513*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3514*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3515*89c4ff92SAndroid Build Coastguard Worker
3516*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3517*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3518*89c4ff92SAndroid Build Coastguard Worker
3519*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3520*89c4ff92SAndroid Build Coastguard Worker
3521*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3522*89c4ff92SAndroid Build Coastguard Worker {
3523*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
3524*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3525*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3526*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3527*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3528*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
3529*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3530*89c4ff92SAndroid Build Coastguard Worker };
3531*89c4ff92SAndroid Build Coastguard Worker
3532*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3533*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3534*89c4ff92SAndroid Build Coastguard Worker }
3535*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3536*89c4ff92SAndroid Build Coastguard Worker void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3537*89c4ff92SAndroid Build Coastguard Worker {
3538*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"SliceQueueDescriptor"};
3539*89c4ff92SAndroid Build Coastguard Worker
3540*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3541*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3542*89c4ff92SAndroid Build Coastguard Worker
3543*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3544*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3545*89c4ff92SAndroid Build Coastguard Worker
3546*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3547*89c4ff92SAndroid Build Coastguard Worker
3548*89c4ff92SAndroid Build Coastguard Worker const unsigned int rank = inputTensorInfo.GetNumDimensions();
3549*89c4ff92SAndroid Build Coastguard Worker if (rank > 4)
3550*89c4ff92SAndroid Build Coastguard Worker {
3551*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3552*89c4ff92SAndroid Build Coastguard Worker }
3553*89c4ff92SAndroid Build Coastguard Worker
3554*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3555*89c4ff92SAndroid Build Coastguard Worker
3556*89c4ff92SAndroid Build Coastguard Worker // Check if m_Begin and m_Size have the expected length
3557*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Begin.size() != rank)
3558*89c4ff92SAndroid Build Coastguard Worker {
3559*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3560*89c4ff92SAndroid Build Coastguard Worker ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3561*89c4ff92SAndroid Build Coastguard Worker }
3562*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Size.size() != rank)
3563*89c4ff92SAndroid Build Coastguard Worker {
3564*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
3565*89c4ff92SAndroid Build Coastguard Worker ": Length of size descriptor must equal rank " + std::to_string(rank));
3566*89c4ff92SAndroid Build Coastguard Worker }
3567*89c4ff92SAndroid Build Coastguard Worker
3568*89c4ff92SAndroid Build Coastguard Worker // Check if the shape of the output tensor matches m_Size
3569*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = outputTensorInfo.GetShape();
3570*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < rank; ++i)
3571*89c4ff92SAndroid Build Coastguard Worker {
3572*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Size[i] != outputShape[i])
3573*89c4ff92SAndroid Build Coastguard Worker {
3574*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3575*89c4ff92SAndroid Build Coastguard Worker }
3576*89c4ff92SAndroid Build Coastguard Worker }
3577*89c4ff92SAndroid Build Coastguard Worker
3578*89c4ff92SAndroid Build Coastguard Worker // Check if the sum of begin offset and size in a given dimension
3579*89c4ff92SAndroid Build Coastguard Worker // does not exceed the size of corresponding input
3580*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = inputTensorInfo.GetShape();
3581*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i = 0u; i < rank; ++i)
3582*89c4ff92SAndroid Build Coastguard Worker {
3583*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
3584*89c4ff92SAndroid Build Coastguard Worker {
3585*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3586*89c4ff92SAndroid Build Coastguard Worker std::to_string(i) + " exceeds input size.");
3587*89c4ff92SAndroid Build Coastguard Worker }
3588*89c4ff92SAndroid Build Coastguard Worker }
3589*89c4ff92SAndroid Build Coastguard Worker }
3590*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3591*89c4ff92SAndroid Build Coastguard Worker void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3592*89c4ff92SAndroid Build Coastguard Worker {
3593*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3594*89c4ff92SAndroid Build Coastguard Worker
3595*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3596*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3597*89c4ff92SAndroid Build Coastguard Worker
3598*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3599*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3600*89c4ff92SAndroid Build Coastguard Worker
3601*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3602*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3603*89c4ff92SAndroid Build Coastguard Worker
3604*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3605*89c4ff92SAndroid Build Coastguard Worker {
3606*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
3607*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3608*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3609*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3610*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3611*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
3612*89c4ff92SAndroid Build Coastguard Worker };
3613*89c4ff92SAndroid Build Coastguard Worker
3614*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3615*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3616*89c4ff92SAndroid Build Coastguard Worker
3617*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3618*89c4ff92SAndroid Build Coastguard Worker
3619*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_BlockSize == 0)
3620*89c4ff92SAndroid Build Coastguard Worker {
3621*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3622*89c4ff92SAndroid Build Coastguard Worker }
3623*89c4ff92SAndroid Build Coastguard Worker
3624*89c4ff92SAndroid Build Coastguard Worker DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3625*89c4ff92SAndroid Build Coastguard Worker const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3626*89c4ff92SAndroid Build Coastguard Worker const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3627*89c4ff92SAndroid Build Coastguard Worker const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3628*89c4ff92SAndroid Build Coastguard Worker
3629*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = outputInfo.GetShape();
3630*89c4ff92SAndroid Build Coastguard Worker if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3631*89c4ff92SAndroid Build Coastguard Worker {
3632*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3633*89c4ff92SAndroid Build Coastguard Worker "must be divisible by block size.");
3634*89c4ff92SAndroid Build Coastguard Worker }
3635*89c4ff92SAndroid Build Coastguard Worker
3636*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = inputInfo.GetShape();
3637*89c4ff92SAndroid Build Coastguard Worker if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3638*89c4ff92SAndroid Build Coastguard Worker {
3639*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3640*89c4ff92SAndroid Build Coastguard Worker "must be divisible by the square of block size." );
3641*89c4ff92SAndroid Build Coastguard Worker }
3642*89c4ff92SAndroid Build Coastguard Worker }
3643*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3644*89c4ff92SAndroid Build Coastguard Worker void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3645*89c4ff92SAndroid Build Coastguard Worker {
3646*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ComparisonQueueDescriptor"};
3647*89c4ff92SAndroid Build Coastguard Worker
3648*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
3649*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3650*89c4ff92SAndroid Build Coastguard Worker
3651*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3652*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3653*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3654*89c4ff92SAndroid Build Coastguard Worker
3655*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3656*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
3657*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
3658*89c4ff92SAndroid Build Coastguard Worker descriptorName,
3659*89c4ff92SAndroid Build Coastguard Worker "input_0",
3660*89c4ff92SAndroid Build Coastguard Worker "input_1");
3661*89c4ff92SAndroid Build Coastguard Worker
3662*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Boolean)
3663*89c4ff92SAndroid Build Coastguard Worker {
3664*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3665*89c4ff92SAndroid Build Coastguard Worker }
3666*89c4ff92SAndroid Build Coastguard Worker }
3667*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3668*89c4ff92SAndroid Build Coastguard Worker void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3669*89c4ff92SAndroid Build Coastguard Worker {
3670*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3671*89c4ff92SAndroid Build Coastguard Worker
3672*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
3673*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3674*89c4ff92SAndroid Build Coastguard Worker
3675*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3676*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3677*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3678*89c4ff92SAndroid Build Coastguard Worker
3679*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3680*89c4ff92SAndroid Build Coastguard Worker {
3681*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
3682*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3683*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3684*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3685*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3686*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
3687*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3688*89c4ff92SAndroid Build Coastguard Worker };
3689*89c4ff92SAndroid Build Coastguard Worker
3690*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3691*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3692*89c4ff92SAndroid Build Coastguard Worker
3693*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3694*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3695*89c4ff92SAndroid Build Coastguard Worker }
3696*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3697*89c4ff92SAndroid Build Coastguard Worker void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3698*89c4ff92SAndroid Build Coastguard Worker {
3699*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3700*89c4ff92SAndroid Build Coastguard Worker
3701*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3702*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3703*89c4ff92SAndroid Build Coastguard Worker
3704*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3705*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3706*89c4ff92SAndroid Build Coastguard Worker
3707*89c4ff92SAndroid Build Coastguard Worker ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3708*89c4ff92SAndroid Build Coastguard Worker
3709*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3710*89c4ff92SAndroid Build Coastguard Worker {
3711*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
3712*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3713*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3714*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3715*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3716*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
3717*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3718*89c4ff92SAndroid Build Coastguard Worker };
3719*89c4ff92SAndroid Build Coastguard Worker
3720*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> logicalSupportedTypes =
3721*89c4ff92SAndroid Build Coastguard Worker {
3722*89c4ff92SAndroid Build Coastguard Worker DataType::Boolean
3723*89c4ff92SAndroid Build Coastguard Worker };
3724*89c4ff92SAndroid Build Coastguard Worker
3725*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3726*89c4ff92SAndroid Build Coastguard Worker {
3727*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3728*89c4ff92SAndroid Build Coastguard Worker }
3729*89c4ff92SAndroid Build Coastguard Worker else
3730*89c4ff92SAndroid Build Coastguard Worker {
3731*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3732*89c4ff92SAndroid Build Coastguard Worker }
3733*89c4ff92SAndroid Build Coastguard Worker
3734*89c4ff92SAndroid Build Coastguard Worker
3735*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3736*89c4ff92SAndroid Build Coastguard Worker }
3737*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3738*89c4ff92SAndroid Build Coastguard Worker void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3739*89c4ff92SAndroid Build Coastguard Worker {
3740*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"RankQueueDescriptor"};
3741*89c4ff92SAndroid Build Coastguard Worker
3742*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3743*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3744*89c4ff92SAndroid Build Coastguard Worker
3745*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3746*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3747*89c4ff92SAndroid Build Coastguard Worker
3748*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3749*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3750*89c4ff92SAndroid Build Coastguard Worker
3751*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3752*89c4ff92SAndroid Build Coastguard Worker {
3753*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
3754*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3755*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3756*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3757*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3758*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS8,
3759*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
3760*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3761*89c4ff92SAndroid Build Coastguard Worker };
3762*89c4ff92SAndroid Build Coastguard Worker
3763*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3764*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3765*89c4ff92SAndroid Build Coastguard Worker }
3766*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3767*89c4ff92SAndroid Build Coastguard Worker void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3768*89c4ff92SAndroid Build Coastguard Worker {
3769*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3770*89c4ff92SAndroid Build Coastguard Worker
3771*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
3772*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3773*89c4ff92SAndroid Build Coastguard Worker
3774*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3775*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3776*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3777*89c4ff92SAndroid Build Coastguard Worker
3778*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3779*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
3780*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
3781*89c4ff92SAndroid Build Coastguard Worker descriptorName,
3782*89c4ff92SAndroid Build Coastguard Worker "input_0",
3783*89c4ff92SAndroid Build Coastguard Worker "input_1");
3784*89c4ff92SAndroid Build Coastguard Worker
3785*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3786*89c4ff92SAndroid Build Coastguard Worker {
3787*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3788*89c4ff92SAndroid Build Coastguard Worker }
3789*89c4ff92SAndroid Build Coastguard Worker
3790*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3791*89c4ff92SAndroid Build Coastguard Worker {
3792*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3793*89c4ff92SAndroid Build Coastguard Worker }
3794*89c4ff92SAndroid Build Coastguard Worker
3795*89c4ff92SAndroid Build Coastguard Worker if (outputTensorInfo.GetDataType() != DataType::Boolean)
3796*89c4ff92SAndroid Build Coastguard Worker {
3797*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3798*89c4ff92SAndroid Build Coastguard Worker }
3799*89c4ff92SAndroid Build Coastguard Worker }
3800*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3801*89c4ff92SAndroid Build Coastguard Worker void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3802*89c4ff92SAndroid Build Coastguard Worker {
3803*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"ReduceQueueDescriptor"};
3804*89c4ff92SAndroid Build Coastguard Worker
3805*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 1);
3806*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
3807*89c4ff92SAndroid Build Coastguard Worker
3808*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3809*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3810*89c4ff92SAndroid Build Coastguard Worker
3811*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3812*89c4ff92SAndroid Build Coastguard Worker {
3813*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
3814*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
3815*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3816*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
3817*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
3818*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16,
3819*89c4ff92SAndroid Build Coastguard Worker DataType::Signed32
3820*89c4ff92SAndroid Build Coastguard Worker };
3821*89c4ff92SAndroid Build Coastguard Worker
3822*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3823*89c4ff92SAndroid Build Coastguard Worker ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3824*89c4ff92SAndroid Build Coastguard Worker }
3825*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const3826*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3827*89c4ff92SAndroid Build Coastguard Worker {
3828*89c4ff92SAndroid Build Coastguard Worker // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3829*89c4ff92SAndroid Build Coastguard Worker
3830*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3831*89c4ff92SAndroid Build Coastguard Worker
3832*89c4ff92SAndroid Build Coastguard Worker // check dimensions of all inputs and outputs
3833*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_InputTensorInfos.size() != 3)
3834*89c4ff92SAndroid Build Coastguard Worker {
3835*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3836*89c4ff92SAndroid Build Coastguard Worker }
3837*89c4ff92SAndroid Build Coastguard Worker if (workloadInfo.m_OutputTensorInfos.size() != 3)
3838*89c4ff92SAndroid Build Coastguard Worker {
3839*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3840*89c4ff92SAndroid Build Coastguard Worker }
3841*89c4ff92SAndroid Build Coastguard Worker
3842*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
3843*89c4ff92SAndroid Build Coastguard Worker {
3844*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
3845*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8
3846*89c4ff92SAndroid Build Coastguard Worker };
3847*89c4ff92SAndroid Build Coastguard Worker
3848*89c4ff92SAndroid Build Coastguard Worker // check for supported type of one input and match them with all the other input and output
3849*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3850*89c4ff92SAndroid Build Coastguard Worker
3851*89c4ff92SAndroid Build Coastguard Worker // Making sure clipping parameters have valid values.
3852*89c4ff92SAndroid Build Coastguard Worker // == 0 means no clipping
3853*89c4ff92SAndroid Build Coastguard Worker // > 0 means clipping
3854*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_ClippingThresCell < 0.0f)
3855*89c4ff92SAndroid Build Coastguard Worker {
3856*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3857*89c4ff92SAndroid Build Coastguard Worker }
3858*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_ClippingThresProj < 0.0f)
3859*89c4ff92SAndroid Build Coastguard Worker {
3860*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3861*89c4ff92SAndroid Build Coastguard Worker }
3862*89c4ff92SAndroid Build Coastguard Worker
3863*89c4ff92SAndroid Build Coastguard Worker unsigned int batchIndx = 0;
3864*89c4ff92SAndroid Build Coastguard Worker unsigned int inputIndx = 1;
3865*89c4ff92SAndroid Build Coastguard Worker uint32_t timeStep = 1;
3866*89c4ff92SAndroid Build Coastguard Worker unsigned int timeIndx = 1;
3867*89c4ff92SAndroid Build Coastguard Worker inputIndx = 2;
3868*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_TimeMajor)
3869*89c4ff92SAndroid Build Coastguard Worker {
3870*89c4ff92SAndroid Build Coastguard Worker batchIndx = 1;
3871*89c4ff92SAndroid Build Coastguard Worker timeIndx = 0;
3872*89c4ff92SAndroid Build Coastguard Worker
3873*89c4ff92SAndroid Build Coastguard Worker }
3874*89c4ff92SAndroid Build Coastguard Worker timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3875*89c4ff92SAndroid Build Coastguard Worker
3876*89c4ff92SAndroid Build Coastguard Worker // Inferring batch size, number of outputs and number of cells from the inputs.
3877*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3878*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3879*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3880*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3881*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3882*89c4ff92SAndroid Build Coastguard Worker const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3883*89c4ff92SAndroid Build Coastguard Worker
3884*89c4ff92SAndroid Build Coastguard Worker // input tensor
3885*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3886*89c4ff92SAndroid Build Coastguard Worker descriptorName + " input_0");
3887*89c4ff92SAndroid Build Coastguard Worker // outputStateInTensor
3888*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3889*89c4ff92SAndroid Build Coastguard Worker descriptorName + " input_1");
3890*89c4ff92SAndroid Build Coastguard Worker // outputStateInTensor
3891*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3892*89c4ff92SAndroid Build Coastguard Worker descriptorName + " input_2");
3893*89c4ff92SAndroid Build Coastguard Worker
3894*89c4ff92SAndroid Build Coastguard Worker // outputTensor
3895*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
3896*89c4ff92SAndroid Build Coastguard Worker descriptorName + " output_0");
3897*89c4ff92SAndroid Build Coastguard Worker
3898*89c4ff92SAndroid Build Coastguard Worker // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3899*89c4ff92SAndroid Build Coastguard Worker if ( m_InputToInputWeights )
3900*89c4ff92SAndroid Build Coastguard Worker {
3901*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3902*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_input), "InputLayerNormWeights");
3903*89c4ff92SAndroid Build Coastguard Worker }
3904*89c4ff92SAndroid Build Coastguard Worker
3905*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3906*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3907*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_input), "InputToForgetWeights");
3908*89c4ff92SAndroid Build Coastguard Worker
3909*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3910*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3911*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_input), "InputToCellWeights");
3912*89c4ff92SAndroid Build Coastguard Worker
3913*89c4ff92SAndroid Build Coastguard Worker if ( m_RecurrentToInputWeights )
3914*89c4ff92SAndroid Build Coastguard Worker {
3915*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3916*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "RecurrentToInputWeights");
3917*89c4ff92SAndroid Build Coastguard Worker }
3918*89c4ff92SAndroid Build Coastguard Worker
3919*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3920*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3921*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "RecurrentToForgetWeights");
3922*89c4ff92SAndroid Build Coastguard Worker
3923*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3924*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3925*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "RecurrentToCellWeights");
3926*89c4ff92SAndroid Build Coastguard Worker
3927*89c4ff92SAndroid Build Coastguard Worker // Make sure the input-gate's parameters are either both present (regular
3928*89c4ff92SAndroid Build Coastguard Worker // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3929*89c4ff92SAndroid Build Coastguard Worker bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3930*89c4ff92SAndroid Build Coastguard Worker !m_Parameters.m_CifgEnabled) ||
3931*89c4ff92SAndroid Build Coastguard Worker (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3932*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_CifgEnabled));
3933*89c4ff92SAndroid Build Coastguard Worker if (!cifg_weights_all_or_none)
3934*89c4ff92SAndroid Build Coastguard Worker {
3935*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3936*89c4ff92SAndroid Build Coastguard Worker "RecurrentToInputWeights must either both be present (regular LSTM) "
3937*89c4ff92SAndroid Build Coastguard Worker "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3938*89c4ff92SAndroid Build Coastguard Worker "accordingly.");
3939*89c4ff92SAndroid Build Coastguard Worker }
3940*89c4ff92SAndroid Build Coastguard Worker
3941*89c4ff92SAndroid Build Coastguard Worker if ( m_CellToInputWeights )
3942*89c4ff92SAndroid Build Coastguard Worker {
3943*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3944*89c4ff92SAndroid Build Coastguard Worker n_cell, "CellToInputWeights");
3945*89c4ff92SAndroid Build Coastguard Worker }
3946*89c4ff92SAndroid Build Coastguard Worker if ( m_CellToForgetWeights )
3947*89c4ff92SAndroid Build Coastguard Worker {
3948*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3949*89c4ff92SAndroid Build Coastguard Worker n_cell, "CellToForgetWeights");
3950*89c4ff92SAndroid Build Coastguard Worker }
3951*89c4ff92SAndroid Build Coastguard Worker if ( m_CellToOutputWeights )
3952*89c4ff92SAndroid Build Coastguard Worker {
3953*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3954*89c4ff92SAndroid Build Coastguard Worker n_cell, "CellToOutputWeights");
3955*89c4ff92SAndroid Build Coastguard Worker }
3956*89c4ff92SAndroid Build Coastguard Worker
3957*89c4ff92SAndroid Build Coastguard Worker // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3958*89c4ff92SAndroid Build Coastguard Worker bool peephole_weights_all_or_none =
3959*89c4ff92SAndroid Build Coastguard Worker (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3960*89c4ff92SAndroid Build Coastguard Worker && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3961*89c4ff92SAndroid Build Coastguard Worker || ( !m_CellToInputWeights && !m_CellToForgetWeights
3962*89c4ff92SAndroid Build Coastguard Worker && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3963*89c4ff92SAndroid Build Coastguard Worker if (!peephole_weights_all_or_none)
3964*89c4ff92SAndroid Build Coastguard Worker {
3965*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3966*89c4ff92SAndroid Build Coastguard Worker }
3967*89c4ff92SAndroid Build Coastguard Worker
3968*89c4ff92SAndroid Build Coastguard Worker // Make sure the input gate bias is present only when not a CIFG-LSTM.
3969*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_CifgEnabled)
3970*89c4ff92SAndroid Build Coastguard Worker {
3971*89c4ff92SAndroid Build Coastguard Worker if (m_InputGateBias)
3972*89c4ff92SAndroid Build Coastguard Worker {
3973*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3974*89c4ff92SAndroid Build Coastguard Worker }
3975*89c4ff92SAndroid Build Coastguard Worker }
3976*89c4ff92SAndroid Build Coastguard Worker else
3977*89c4ff92SAndroid Build Coastguard Worker {
3978*89c4ff92SAndroid Build Coastguard Worker if (!m_InputGateBias)
3979*89c4ff92SAndroid Build Coastguard Worker {
3980*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3981*89c4ff92SAndroid Build Coastguard Worker "must be present.");
3982*89c4ff92SAndroid Build Coastguard Worker }
3983*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3984*89c4ff92SAndroid Build Coastguard Worker n_cell, "InputGateBias");
3985*89c4ff92SAndroid Build Coastguard Worker }
3986*89c4ff92SAndroid Build Coastguard Worker
3987*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3988*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3989*89c4ff92SAndroid Build Coastguard Worker
3990*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
3991*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
3992*89c4ff92SAndroid Build Coastguard Worker
3993*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
3994*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
3995*89c4ff92SAndroid Build Coastguard Worker
3996*89c4ff92SAndroid Build Coastguard Worker if (m_ProjectionWeights)
3997*89c4ff92SAndroid Build Coastguard Worker {
3998*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
3999*89c4ff92SAndroid Build Coastguard Worker (n_cell * n_output), "ProjectionWeights");
4000*89c4ff92SAndroid Build Coastguard Worker }
4001*89c4ff92SAndroid Build Coastguard Worker if (m_ProjectionBias)
4002*89c4ff92SAndroid Build Coastguard Worker {
4003*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4004*89c4ff92SAndroid Build Coastguard Worker }
4005*89c4ff92SAndroid Build Coastguard Worker
4006*89c4ff92SAndroid Build Coastguard Worker // Making sure the projection tensors are consistent:
4007*89c4ff92SAndroid Build Coastguard Worker // 1) If projection weight is not present, then projection bias should not be
4008*89c4ff92SAndroid Build Coastguard Worker // present.
4009*89c4ff92SAndroid Build Coastguard Worker // 2) If projection weight is present, then projection bias is optional.
4010*89c4ff92SAndroid Build Coastguard Worker bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4011*89c4ff92SAndroid Build Coastguard Worker !m_Parameters.m_ProjectionEnabled)
4012*89c4ff92SAndroid Build Coastguard Worker || (m_ProjectionWeights && !m_ProjectionBias &&
4013*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_ProjectionEnabled)
4014*89c4ff92SAndroid Build Coastguard Worker || (m_ProjectionWeights && m_ProjectionBias &&
4015*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_ProjectionEnabled));
4016*89c4ff92SAndroid Build Coastguard Worker if (!projecton_tensors_consistent)
4017*89c4ff92SAndroid Build Coastguard Worker {
4018*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4019*89c4ff92SAndroid Build Coastguard Worker }
4020*89c4ff92SAndroid Build Coastguard Worker
4021*89c4ff92SAndroid Build Coastguard Worker // The four layer normalization weights either all have values or none of them have values. Additionally, if
4022*89c4ff92SAndroid Build Coastguard Worker // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4023*89c4ff92SAndroid Build Coastguard Worker // either all have values or none of them have values. Layer normalization is used when the values of all the
4024*89c4ff92SAndroid Build Coastguard Worker // layer normalization weights are present
4025*89c4ff92SAndroid Build Coastguard Worker if (m_InputLayerNormWeights)
4026*89c4ff92SAndroid Build Coastguard Worker {
4027*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4028*89c4ff92SAndroid Build Coastguard Worker }
4029*89c4ff92SAndroid Build Coastguard Worker if (m_ForgetLayerNormWeights)
4030*89c4ff92SAndroid Build Coastguard Worker {
4031*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4032*89c4ff92SAndroid Build Coastguard Worker }
4033*89c4ff92SAndroid Build Coastguard Worker if (m_CellLayerNormWeights)
4034*89c4ff92SAndroid Build Coastguard Worker {
4035*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4036*89c4ff92SAndroid Build Coastguard Worker }
4037*89c4ff92SAndroid Build Coastguard Worker if (m_OutputLayerNormWeights)
4038*89c4ff92SAndroid Build Coastguard Worker {
4039*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4040*89c4ff92SAndroid Build Coastguard Worker }
4041*89c4ff92SAndroid Build Coastguard Worker
4042*89c4ff92SAndroid Build Coastguard Worker if (m_Parameters.m_LayerNormEnabled)
4043*89c4ff92SAndroid Build Coastguard Worker {
4044*89c4ff92SAndroid Build Coastguard Worker if (!m_Parameters.m_CifgEnabled)
4045*89c4ff92SAndroid Build Coastguard Worker {
4046*89c4ff92SAndroid Build Coastguard Worker if (!m_InputLayerNormWeights)
4047*89c4ff92SAndroid Build Coastguard Worker {
4048*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4049*89c4ff92SAndroid Build Coastguard Worker "disabled but InputLayerNormWeights are not present");
4050*89c4ff92SAndroid Build Coastguard Worker }
4051*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4052*89c4ff92SAndroid Build Coastguard Worker 1, n_cell, "InputLayerNormWeights");
4053*89c4ff92SAndroid Build Coastguard Worker }
4054*89c4ff92SAndroid Build Coastguard Worker else if (m_InputLayerNormWeights)
4055*89c4ff92SAndroid Build Coastguard Worker {
4056*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4057*89c4ff92SAndroid Build Coastguard Worker "enabled");
4058*89c4ff92SAndroid Build Coastguard Worker }
4059*89c4ff92SAndroid Build Coastguard Worker
4060*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4061*89c4ff92SAndroid Build Coastguard Worker "ForgetLayerNormWeights");
4062*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4063*89c4ff92SAndroid Build Coastguard Worker
4064*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4065*89c4ff92SAndroid Build Coastguard Worker "OutputLayerNormWeights");
4066*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4067*89c4ff92SAndroid Build Coastguard Worker
4068*89c4ff92SAndroid Build Coastguard Worker ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4069*89c4ff92SAndroid Build Coastguard Worker "CellLayerNormWeights");
4070*89c4ff92SAndroid Build Coastguard Worker ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4071*89c4ff92SAndroid Build Coastguard Worker }
4072*89c4ff92SAndroid Build Coastguard Worker else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4073*89c4ff92SAndroid Build Coastguard Worker {
4074*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4075*89c4ff92SAndroid Build Coastguard Worker "normalisation weights are present.");
4076*89c4ff92SAndroid Build Coastguard Worker }
4077*89c4ff92SAndroid Build Coastguard Worker }
4078*89c4ff92SAndroid Build Coastguard Worker
Validate(const WorkloadInfo & workloadInfo) const4079*89c4ff92SAndroid Build Coastguard Worker void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4080*89c4ff92SAndroid Build Coastguard Worker {
4081*89c4ff92SAndroid Build Coastguard Worker const std::string descriptorName{"BatchMatMulDescriptor"};
4082*89c4ff92SAndroid Build Coastguard Worker
4083*89c4ff92SAndroid Build Coastguard Worker ValidateNumInputs(workloadInfo, descriptorName, 2);
4084*89c4ff92SAndroid Build Coastguard Worker ValidateNumOutputs(workloadInfo, descriptorName, 1);
4085*89c4ff92SAndroid Build Coastguard Worker
4086*89c4ff92SAndroid Build Coastguard Worker // Inputs must be: both 2D+
4087*89c4ff92SAndroid Build Coastguard Worker // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4088*89c4ff92SAndroid Build Coastguard Worker // axes N and I must be the same size
4089*89c4ff92SAndroid Build Coastguard Worker
4090*89c4ff92SAndroid Build Coastguard Worker const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4091*89c4ff92SAndroid Build Coastguard Worker const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4092*89c4ff92SAndroid Build Coastguard Worker const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4093*89c4ff92SAndroid Build Coastguard Worker // Output info has already been inferred
4094*89c4ff92SAndroid Build Coastguard Worker
4095*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> supportedTypes =
4096*89c4ff92SAndroid Build Coastguard Worker {
4097*89c4ff92SAndroid Build Coastguard Worker DataType::BFloat16,
4098*89c4ff92SAndroid Build Coastguard Worker DataType::Float16,
4099*89c4ff92SAndroid Build Coastguard Worker DataType::Float32,
4100*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmS8,
4101*89c4ff92SAndroid Build Coastguard Worker DataType::QAsymmU8,
4102*89c4ff92SAndroid Build Coastguard Worker DataType::QSymmS16
4103*89c4ff92SAndroid Build Coastguard Worker };
4104*89c4ff92SAndroid Build Coastguard Worker
4105*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4106*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4107*89c4ff92SAndroid Build Coastguard Worker ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4108*89c4ff92SAndroid Build Coastguard Worker
4109*89c4ff92SAndroid Build Coastguard Worker if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4110*89c4ff92SAndroid Build Coastguard Worker (inputYInfoBeforeParams.GetNumDimensions() < 2))
4111*89c4ff92SAndroid Build Coastguard Worker {
4112*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4113*89c4ff92SAndroid Build Coastguard Worker }
4114*89c4ff92SAndroid Build Coastguard Worker
4115*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputXInfoAfterParams;
4116*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputYInfoAfterParams;
4117*89c4ff92SAndroid Build Coastguard Worker
4118*89c4ff92SAndroid Build Coastguard Worker if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4119*89c4ff92SAndroid Build Coastguard Worker (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
4120*89c4ff92SAndroid Build Coastguard Worker {
4121*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4122*89c4ff92SAndroid Build Coastguard Worker ": Invalid descriptor parameters - Transpose and Adjoint "
4123*89c4ff92SAndroid Build Coastguard Worker "cannot both be true for a given input tensor.");
4124*89c4ff92SAndroid Build Coastguard Worker }
4125*89c4ff92SAndroid Build Coastguard Worker if(m_Parameters.m_TransposeX)
4126*89c4ff92SAndroid Build Coastguard Worker {
4127*89c4ff92SAndroid Build Coastguard Worker inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4128*89c4ff92SAndroid Build Coastguard Worker BatchMatMulDescriptor::GetPermuteVec(
4129*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_DataLayoutX,
4130*89c4ff92SAndroid Build Coastguard Worker inputXInfoBeforeParams.GetShape()));
4131*89c4ff92SAndroid Build Coastguard Worker }
4132*89c4ff92SAndroid Build Coastguard Worker else if(m_Parameters.m_AdjointX)
4133*89c4ff92SAndroid Build Coastguard Worker {
4134*89c4ff92SAndroid Build Coastguard Worker auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4135*89c4ff92SAndroid Build Coastguard Worker inputXInfoBeforeParams.GetShape());
4136*89c4ff92SAndroid Build Coastguard Worker if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4137*89c4ff92SAndroid Build Coastguard Worker inputXInfoBeforeParams.GetShape()[axesToMul.second])
4138*89c4ff92SAndroid Build Coastguard Worker {
4139*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4140*89c4ff92SAndroid Build Coastguard Worker ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4141*89c4ff92SAndroid Build Coastguard Worker }
4142*89c4ff92SAndroid Build Coastguard Worker // Shape remains the same as it's square
4143*89c4ff92SAndroid Build Coastguard Worker inputXInfoAfterParams = inputXInfoBeforeParams;
4144*89c4ff92SAndroid Build Coastguard Worker }
4145*89c4ff92SAndroid Build Coastguard Worker else
4146*89c4ff92SAndroid Build Coastguard Worker {
4147*89c4ff92SAndroid Build Coastguard Worker inputXInfoAfterParams = inputXInfoBeforeParams;
4148*89c4ff92SAndroid Build Coastguard Worker }
4149*89c4ff92SAndroid Build Coastguard Worker
4150*89c4ff92SAndroid Build Coastguard Worker if(m_Parameters.m_TransposeY)
4151*89c4ff92SAndroid Build Coastguard Worker {
4152*89c4ff92SAndroid Build Coastguard Worker inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4153*89c4ff92SAndroid Build Coastguard Worker BatchMatMulDescriptor::GetPermuteVec(
4154*89c4ff92SAndroid Build Coastguard Worker m_Parameters.m_DataLayoutY,
4155*89c4ff92SAndroid Build Coastguard Worker inputYInfoBeforeParams.GetShape()));
4156*89c4ff92SAndroid Build Coastguard Worker }
4157*89c4ff92SAndroid Build Coastguard Worker else if(m_Parameters.m_AdjointY)
4158*89c4ff92SAndroid Build Coastguard Worker {
4159*89c4ff92SAndroid Build Coastguard Worker auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4160*89c4ff92SAndroid Build Coastguard Worker inputYInfoBeforeParams.GetShape());
4161*89c4ff92SAndroid Build Coastguard Worker if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4162*89c4ff92SAndroid Build Coastguard Worker inputYInfoBeforeParams.GetShape()[axesToMul.second])
4163*89c4ff92SAndroid Build Coastguard Worker {
4164*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4165*89c4ff92SAndroid Build Coastguard Worker ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4166*89c4ff92SAndroid Build Coastguard Worker }
4167*89c4ff92SAndroid Build Coastguard Worker // Shape remains the same as it's square
4168*89c4ff92SAndroid Build Coastguard Worker inputYInfoAfterParams = inputYInfoBeforeParams;
4169*89c4ff92SAndroid Build Coastguard Worker }
4170*89c4ff92SAndroid Build Coastguard Worker else
4171*89c4ff92SAndroid Build Coastguard Worker {
4172*89c4ff92SAndroid Build Coastguard Worker inputYInfoAfterParams = inputYInfoBeforeParams;
4173*89c4ff92SAndroid Build Coastguard Worker }
4174*89c4ff92SAndroid Build Coastguard Worker
4175*89c4ff92SAndroid Build Coastguard Worker switch(m_Parameters.m_DataLayoutX)
4176*89c4ff92SAndroid Build Coastguard Worker {
4177*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCDHW:
4178*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NDHWC:
4179*89c4ff92SAndroid Build Coastguard Worker if(inputXInfoAfterParams.GetNumDimensions() < 3)
4180*89c4ff92SAndroid Build Coastguard Worker {
4181*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4182*89c4ff92SAndroid Build Coastguard Worker ": Input tensor X does not have the correct "
4183*89c4ff92SAndroid Build Coastguard Worker "number of dimensions for the Data Layout that it has been assigned.");
4184*89c4ff92SAndroid Build Coastguard Worker }
4185*89c4ff92SAndroid Build Coastguard Worker break;
4186*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
4187*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
4188*89c4ff92SAndroid Build Coastguard Worker default:
4189*89c4ff92SAndroid Build Coastguard Worker break;
4190*89c4ff92SAndroid Build Coastguard Worker }
4191*89c4ff92SAndroid Build Coastguard Worker
4192*89c4ff92SAndroid Build Coastguard Worker switch(m_Parameters.m_DataLayoutY)
4193*89c4ff92SAndroid Build Coastguard Worker {
4194*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCDHW:
4195*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NDHWC:
4196*89c4ff92SAndroid Build Coastguard Worker if(inputYInfoAfterParams.GetNumDimensions() < 3)
4197*89c4ff92SAndroid Build Coastguard Worker {
4198*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4199*89c4ff92SAndroid Build Coastguard Worker ": Input tensor Y does not have the correct "
4200*89c4ff92SAndroid Build Coastguard Worker "number of dimensions for the Data Layout that it has been assigned.");
4201*89c4ff92SAndroid Build Coastguard Worker }
4202*89c4ff92SAndroid Build Coastguard Worker break;
4203*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
4204*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
4205*89c4ff92SAndroid Build Coastguard Worker default:
4206*89c4ff92SAndroid Build Coastguard Worker break;
4207*89c4ff92SAndroid Build Coastguard Worker }
4208*89c4ff92SAndroid Build Coastguard Worker
4209*89c4ff92SAndroid Build Coastguard Worker auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4210*89c4ff92SAndroid Build Coastguard Worker inputXInfoAfterParams.GetShape());
4211*89c4ff92SAndroid Build Coastguard Worker auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4212*89c4ff92SAndroid Build Coastguard Worker inputXInfoBeforeParams.GetShape());
4213*89c4ff92SAndroid Build Coastguard Worker
4214*89c4ff92SAndroid Build Coastguard Worker if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4215*89c4ff92SAndroid Build Coastguard Worker != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4216*89c4ff92SAndroid Build Coastguard Worker {
4217*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4218*89c4ff92SAndroid Build Coastguard Worker ": The final axis of input tensor X must be the same size as "
4219*89c4ff92SAndroid Build Coastguard Worker "the second last axis of input tensor Y.");
4220*89c4ff92SAndroid Build Coastguard Worker }
4221*89c4ff92SAndroid Build Coastguard Worker
4222*89c4ff92SAndroid Build Coastguard Worker { // Separate scope so we don't pollute the rest of the scope with our temp variables
4223*89c4ff92SAndroid Build Coastguard Worker // e.g. NHWC isnt compatible with NCHW as of now
4224*89c4ff92SAndroid Build Coastguard Worker DataLayout xLayout = m_Parameters.m_DataLayoutX;
4225*89c4ff92SAndroid Build Coastguard Worker DataLayout yLayout = m_Parameters.m_DataLayoutY;
4226*89c4ff92SAndroid Build Coastguard Worker
4227*89c4ff92SAndroid Build Coastguard Worker if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4228*89c4ff92SAndroid Build Coastguard Worker {
4229*89c4ff92SAndroid Build Coastguard Worker if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4230*89c4ff92SAndroid Build Coastguard Worker {
4231*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4232*89c4ff92SAndroid Build Coastguard Worker ": Invalid input tensor data layout combination.");
4233*89c4ff92SAndroid Build Coastguard Worker }
4234*89c4ff92SAndroid Build Coastguard Worker }
4235*89c4ff92SAndroid Build Coastguard Worker if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4236*89c4ff92SAndroid Build Coastguard Worker {
4237*89c4ff92SAndroid Build Coastguard Worker if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4238*89c4ff92SAndroid Build Coastguard Worker {
4239*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(descriptorName +
4240*89c4ff92SAndroid Build Coastguard Worker ": Invalid input tensor data layout combination.");
4241*89c4ff92SAndroid Build Coastguard Worker }
4242*89c4ff92SAndroid Build Coastguard Worker }
4243*89c4ff92SAndroid Build Coastguard Worker }
4244*89c4ff92SAndroid Build Coastguard Worker
4245*89c4ff92SAndroid Build Coastguard Worker // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4246*89c4ff92SAndroid Build Coastguard Worker unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4247*89c4ff92SAndroid Build Coastguard Worker inputYInfoAfterParams.GetNumDimensions());
4248*89c4ff92SAndroid Build Coastguard Worker if(outputTensorDimSize-2 > 0)
4249*89c4ff92SAndroid Build Coastguard Worker {
4250*89c4ff92SAndroid Build Coastguard Worker TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4251*89c4ff92SAndroid Build Coastguard Worker DataType::Float32);
4252*89c4ff92SAndroid Build Coastguard Worker TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4253*89c4ff92SAndroid Build Coastguard Worker DataType::Float32);
4254*89c4ff92SAndroid Build Coastguard Worker TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4255*89c4ff92SAndroid Build Coastguard Worker DataType::Float32);
4256*89c4ff92SAndroid Build Coastguard Worker
4257*89c4ff92SAndroid Build Coastguard Worker auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4258*89c4ff92SAndroid Build Coastguard Worker {
4259*89c4ff92SAndroid Build Coastguard Worker auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4260*89c4ff92SAndroid Build Coastguard Worker
4261*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i = 0; i < sizeDiff; i++)
4262*89c4ff92SAndroid Build Coastguard Worker {
4263*89c4ff92SAndroid Build Coastguard Worker axisIndices.insert(axisIndices.begin(), 1);
4264*89c4ff92SAndroid Build Coastguard Worker }
4265*89c4ff92SAndroid Build Coastguard Worker
4266*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4267*89c4ff92SAndroid Build Coastguard Worker {
4268*89c4ff92SAndroid Build Coastguard Worker ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4269*89c4ff92SAndroid Build Coastguard Worker }
4270*89c4ff92SAndroid Build Coastguard Worker };
4271*89c4ff92SAndroid Build Coastguard Worker
4272*89c4ff92SAndroid Build Coastguard Worker auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4273*89c4ff92SAndroid Build Coastguard Worker inputXInfoAfterParams.GetShape());
4274*89c4ff92SAndroid Build Coastguard Worker auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4275*89c4ff92SAndroid Build Coastguard Worker inputYInfoAfterParams.GetShape());
4276*89c4ff92SAndroid Build Coastguard Worker
4277*89c4ff92SAndroid Build Coastguard Worker doAxisExtension(axesXNotMul, tiXNotMul);
4278*89c4ff92SAndroid Build Coastguard Worker doAxisExtension(axesYNotMul, tiYNotMul);
4279*89c4ff92SAndroid Build Coastguard Worker
4280*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4281*89c4ff92SAndroid Build Coastguard Worker {
4282*89c4ff92SAndroid Build Coastguard Worker tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4283*89c4ff92SAndroid Build Coastguard Worker tiYNotMul.GetShape()[i]);
4284*89c4ff92SAndroid Build Coastguard Worker }
4285*89c4ff92SAndroid Build Coastguard Worker
4286*89c4ff92SAndroid Build Coastguard Worker ValidateBroadcastTensorShapesMatch(tiXNotMul,
4287*89c4ff92SAndroid Build Coastguard Worker tiYNotMul,
4288*89c4ff92SAndroid Build Coastguard Worker tiOutNotMul,
4289*89c4ff92SAndroid Build Coastguard Worker descriptorName,
4290*89c4ff92SAndroid Build Coastguard Worker "input_X",
4291*89c4ff92SAndroid Build Coastguard Worker "input_Y");
4292*89c4ff92SAndroid Build Coastguard Worker }
4293*89c4ff92SAndroid Build Coastguard Worker }
4294*89c4ff92SAndroid Build Coastguard Worker
4295*89c4ff92SAndroid Build Coastguard Worker
4296*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn