1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/backends/TensorHandle.hpp> 9 10 #include <armnn/Tensor.hpp> 11 #include <armnn/Types.hpp> 12 #include <armnn/utility/PolymorphicDowncast.hpp> 13 14 #include <reference/RefTensorHandle.hpp> 15 16 #include <BFloat16.hpp> 17 #include <Half.hpp> 18 19 namespace armnn 20 { 21 22 //////////////////////////////////////////// 23 /// float32 helpers 24 //////////////////////////////////////////// 25 26 template <typename TensorHandleType = RefTensorHandle> GetTensorInfo(const ITensorHandle * tensorHandle)27 inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle) 28 { 29 // We know that reference workloads use RefTensorHandles for inputs and outputs 30 const TensorHandleType* refTensorHandle = 31 PolymorphicDowncast<const TensorHandleType*>(tensorHandle); 32 return refTensorHandle->GetTensorInfo(); 33 } 34 35 template <typename DataType, typename PayloadType> GetInputTensorData(unsigned int idx,const PayloadType & data)36 const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data) 37 { 38 const ITensorHandle* tensorHandle = data.m_Inputs[idx]; 39 return reinterpret_cast<const DataType*>(tensorHandle->Map()); 40 } 41 42 template <typename DataType, typename PayloadType> GetOutputTensorData(unsigned int idx,const PayloadType & data)43 DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data) 44 { 45 ITensorHandle* tensorHandle = data.m_Outputs[idx]; 46 return reinterpret_cast<DataType*>(tensorHandle->Map()); 47 } 48 49 template <typename DataType> GetOutputTensorData(ITensorHandle * tensorHandle)50 DataType* GetOutputTensorData(ITensorHandle* tensorHandle) 51 { 52 return reinterpret_cast<DataType*>(tensorHandle->Map()); 53 } 54 55 template <typename PayloadType> GetInputTensorDataFloat(unsigned int idx,const PayloadType & data)56 const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data) 57 { 58 return GetInputTensorData<float>(idx, data); 59 } 60 61 template <typename PayloadType> GetOutputTensorDataFloat(unsigned int idx,const PayloadType & data)62 float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data) 63 { 64 return GetOutputTensorData<float>(idx, data); 65 } 66 67 template <typename PayloadType> GetInputTensorDataHalf(unsigned int idx,const PayloadType & data)68 const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data) 69 { 70 return GetInputTensorData<Half>(idx, data); 71 } 72 73 template <typename PayloadType> GetOutputTensorDataHalf(unsigned int idx,const PayloadType & data)74 Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data) 75 { 76 return GetOutputTensorData<Half>(idx, data); 77 } 78 79 template <typename PayloadType> GetInputTensorDataBFloat16(unsigned int idx,const PayloadType & data)80 const BFloat16* GetInputTensorDataBFloat16(unsigned int idx, const PayloadType& data) 81 { 82 return GetInputTensorData<BFloat16>(idx, data); 83 } 84 85 template <typename PayloadType> GetOutputTensorDataBFloat16(unsigned int idx,const PayloadType & data)86 BFloat16* GetOutputTensorDataBFloat16(unsigned int idx, const PayloadType& data) 87 { 88 return GetOutputTensorData<BFloat16>(idx, data); 89 } 90 91 //////////////////////////////////////////// 92 /// u8 helpers 93 //////////////////////////////////////////// 94 95 template<typename T> Dequantize(const T * quant,const TensorInfo & info)96 std::vector<float> Dequantize(const T* quant, const TensorInfo& info) 97 { 98 std::vector<float> ret(info.GetNumElements()); 99 for (size_t i = 0; i < info.GetNumElements(); i++) 100 { 101 ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset()); 102 } 103 return ret; 104 } 105 106 template<typename T> Dequantize(const T * inputData,float * outputData,const TensorInfo & info)107 inline void Dequantize(const T* inputData, float* outputData, const TensorInfo& info) 108 { 109 for (unsigned int i = 0; i < info.GetNumElements(); i++) 110 { 111 outputData[i] = Dequantize<T>(inputData[i], info.GetQuantizationScale(), info.GetQuantizationOffset()); 112 } 113 } 114 Quantize(uint8_t * quant,const float * dequant,const TensorInfo & info)115 inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info) 116 { 117 for (size_t i = 0; i < info.GetNumElements(); i++) 118 { 119 quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset()); 120 } 121 } 122 123 } //namespace armnn 124