xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefWorkloadUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1  //
2  // Copyright © 2017 Arm Ltd. All rights reserved.
3  // SPDX-License-Identifier: MIT
4  //
5  
6  #pragma once
7  
8  #include <armnn/backends/TensorHandle.hpp>
9  
10  #include <armnn/Tensor.hpp>
11  #include <armnn/Types.hpp>
12  #include <armnn/utility/PolymorphicDowncast.hpp>
13  
14  #include <reference/RefTensorHandle.hpp>
15  
16  #include <BFloat16.hpp>
17  #include <Half.hpp>
18  
19  namespace armnn
20  {
21  
22  ////////////////////////////////////////////
/// tensor info and typed data-pointer helpers (generic, float32, Half, BFloat16)
24  ////////////////////////////////////////////
25  
26  template <typename TensorHandleType = RefTensorHandle>
GetTensorInfo(const ITensorHandle * tensorHandle)27  inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
28  {
29      // We know that reference workloads use RefTensorHandles for inputs and outputs
30      const TensorHandleType* refTensorHandle =
31          PolymorphicDowncast<const TensorHandleType*>(tensorHandle);
32      return refTensorHandle->GetTensorInfo();
33  }
34  
35  template <typename DataType, typename PayloadType>
GetInputTensorData(unsigned int idx,const PayloadType & data)36  const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data)
37  {
38      const ITensorHandle* tensorHandle = data.m_Inputs[idx];
39      return reinterpret_cast<const DataType*>(tensorHandle->Map());
40  }
41  
42  template <typename DataType, typename PayloadType>
GetOutputTensorData(unsigned int idx,const PayloadType & data)43  DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data)
44  {
45      ITensorHandle* tensorHandle = data.m_Outputs[idx];
46      return reinterpret_cast<DataType*>(tensorHandle->Map());
47  }
48  
49  template <typename DataType>
GetOutputTensorData(ITensorHandle * tensorHandle)50  DataType* GetOutputTensorData(ITensorHandle* tensorHandle)
51  {
52      return reinterpret_cast<DataType*>(tensorHandle->Map());
53  }
54  
55  template <typename PayloadType>
GetInputTensorDataFloat(unsigned int idx,const PayloadType & data)56  const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data)
57  {
58      return GetInputTensorData<float>(idx, data);
59  }
60  
61  template <typename PayloadType>
GetOutputTensorDataFloat(unsigned int idx,const PayloadType & data)62  float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data)
63  {
64      return GetOutputTensorData<float>(idx, data);
65  }
66  
67  template <typename PayloadType>
GetInputTensorDataHalf(unsigned int idx,const PayloadType & data)68  const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data)
69  {
70      return GetInputTensorData<Half>(idx, data);
71  }
72  
73  template <typename PayloadType>
GetOutputTensorDataHalf(unsigned int idx,const PayloadType & data)74  Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data)
75  {
76      return GetOutputTensorData<Half>(idx, data);
77  }
78  
79  template <typename PayloadType>
GetInputTensorDataBFloat16(unsigned int idx,const PayloadType & data)80  const BFloat16* GetInputTensorDataBFloat16(unsigned int idx, const PayloadType& data)
81  {
82      return GetInputTensorData<BFloat16>(idx, data);
83  }
84  
85  template <typename PayloadType>
GetOutputTensorDataBFloat16(unsigned int idx,const PayloadType & data)86  BFloat16* GetOutputTensorDataBFloat16(unsigned int idx, const PayloadType& data)
87  {
88      return GetOutputTensorData<BFloat16>(idx, data);
89  }
90  
91  ////////////////////////////////////////////
/// quantization helpers (Dequantize for any quantized type, Quantize to u8)
93  ////////////////////////////////////////////
94  
95  template<typename T>
Dequantize(const T * quant,const TensorInfo & info)96  std::vector<float> Dequantize(const T* quant, const TensorInfo& info)
97  {
98      std::vector<float> ret(info.GetNumElements());
99      for (size_t i = 0; i < info.GetNumElements(); i++)
100      {
101          ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
102      }
103      return ret;
104  }
105  
106  template<typename T>
Dequantize(const T * inputData,float * outputData,const TensorInfo & info)107  inline void Dequantize(const T* inputData, float* outputData, const TensorInfo& info)
108  {
109      for (unsigned int i = 0; i < info.GetNumElements(); i++)
110      {
111          outputData[i] = Dequantize<T>(inputData[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
112      }
113  }
114  
Quantize(uint8_t * quant,const float * dequant,const TensorInfo & info)115  inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info)
116  {
117      for (size_t i = 0; i < info.GetNumElements(); i++)
118      {
119          quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
120      }
121  }
122  
123  } //namespace armnn
124