xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/LstmUtils.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1  //
2  // Copyright © 2017 Arm Ltd. All rights reserved.
3  // SPDX-License-Identifier: MIT
4  //
5  
6  //#pragma once
7  
8  #include "LstmUtils.hpp"
9  #include "BaseIterator.hpp"
10  #include <armnn/backends/TensorHandle.hpp>
11  
12  
13  // Helper functions ported from the Android code base
14  // Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
15  
VectorBatchVectorAdd(armnn::Decoder<float> & vector,uint32_t vSize,armnn::Decoder<float> & batchVector,uint32_t nBatch,armnn::Encoder<float> & outResult)16  void VectorBatchVectorAdd(armnn::Decoder<float>& vector,
17                            uint32_t vSize,
18                            armnn::Decoder<float>& batchVector,
19                            uint32_t nBatch,
20                            armnn::Encoder<float>& outResult )
21  {
22      for (uint32_t b = 0; b < nBatch; b++)
23      {
24          for (uint32_t v = 0; v < vSize; v++)
25          {
26              outResult.Set(batchVector.Get() + vector.Get());
27              ++outResult;
28              ++vector;
29              ++batchVector;
30          }
31          vector -= vSize;
32      }
33      batchVector -= vSize * nBatch;
34      outResult -= vSize * nBatch;
35  }
36  
37  
38  // Layer norm for each batch.
39  // normalization_epsilon is added to avoid divergence.
MeanStddevNormalization(armnn::Decoder<float> & input_vector,armnn::Encoder<float> & output_vector,uint32_t v_size,uint32_t n_batch,float normalization_epsilon)40  void MeanStddevNormalization(armnn::Decoder<float>& input_vector,
41                               armnn::Encoder<float>& output_vector,
42                               uint32_t v_size,
43                               uint32_t n_batch,
44                               float normalization_epsilon)
45  {
46      for (uint32_t batch = 0; batch < n_batch; ++batch) {
47          float sum = 0.0f;
48          float sum_sq = 0.0f;
49          for (uint32_t i = 0; i < v_size; ++i) {
50              sum += input_vector.Get();
51              sum_sq += input_vector.Get() * input_vector.Get();
52              ++input_vector;
53          }
54          input_vector -= v_size;
55  
56          const float mean = sum / static_cast<float>(v_size);
57          float stddev_inv = 0.0f;
58          const float variance = sum_sq / static_cast<float>(v_size) - mean * mean;
59          if (variance == 0) {
60              stddev_inv = 1.0f / std::sqrt(normalization_epsilon);
61          } else {
62              stddev_inv = 1.0f / std::sqrt(variance);
63          }
64  
65          for (uint32_t i = 0; i < v_size; ++i) {
66              output_vector.Set((input_vector.Get() - mean) * stddev_inv);
67              ++output_vector;
68              ++input_vector;
69          }
70          // Don't reset iterator to handle next batch
71      }
72      output_vector -= v_size * n_batch;
73      input_vector -= v_size * n_batch;
74  }
75  
ZeroVector(armnn::Encoder<float> & vector,uint32_t vSize)76  void ZeroVector(armnn::Encoder<float>& vector,
77                  uint32_t vSize)
78  {
79      for (uint32_t v = 0; v < vSize; v++)
80      {
81          vector.Set(0.0f);
82          ++vector;
83      }
84      vector -= vSize;
85  }
86  
MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float> & matrix,uint32_t mRows,uint32_t mCols,armnn::Decoder<float> & vector,uint32_t nBatch,armnn::Encoder<float> & outResult)87  void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix,
88                                           uint32_t mRows,
89                                           uint32_t mCols,
90                                           armnn::Decoder<float>& vector,
91                                           uint32_t nBatch,
92                                           armnn::Encoder<float>& outResult)
93  {
94      for (uint32_t b = 0; b < nBatch; b++)
95      {
96          for (uint32_t r = 0; r < mRows; r++)
97          {
98              vector += b * mCols;
99              for (uint32_t c = 0; c < mCols; c++)
100              {
101                  outResult.Set(outResult.Get() + matrix.Get() * vector.Get());
102                  ++matrix;
103                  ++vector;
104              }
105              outResult += 1;
106              vector -= (b+1) * mCols;
107          }
108          matrix -= (mRows * mCols);
109      }
110      outResult -= (mRows * nBatch);
111  }
112  
VectorBatchVectorAssign(armnn::Decoder<float> & vector,uint32_t vSize,uint32_t nBatch,armnn::Encoder<float> & outBatchVector)113  void VectorBatchVectorAssign(armnn::Decoder<float>& vector,
114                               uint32_t vSize,
115                               uint32_t nBatch,
116                               armnn::Encoder<float>& outBatchVector)
117  {
118      for (uint32_t b = 0; b < nBatch; b++)
119      {
120          for (uint32_t v = 0; v < vSize; v++)
121          {
122              outBatchVector.Set(vector.Get());
123              ++outBatchVector;
124              ++vector;
125          }
126          vector -= vSize;
127      }
128      outBatchVector -= (nBatch * vSize);
129  }
130  
VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float> & vector,uint32_t vSize,armnn::Decoder<float> & batchVector,uint32_t nBatch,armnn::Encoder<float> & outResult)131  void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector,
132                                               uint32_t vSize,
133                                               armnn::Decoder<float>& batchVector,
134                                               uint32_t nBatch,
135                                               armnn::Encoder<float>& outResult)
136  {
137      for (uint32_t b = 0; b < nBatch; b++)
138      {
139          for (uint32_t v = 0; v < vSize; v++)
140          {
141              outResult.Set(outResult.Get() + vector.Get() * batchVector.Get());
142              ++outResult;
143              ++vector;
144              ++batchVector;
145          }
146          vector -= vSize;
147      }
148      batchVector -= vSize * nBatch;
149      outResult -= vSize * nBatch;
150  }
151  
VectorBatchVectorCwiseProduct(armnn::Decoder<float> & vector,uint32_t vSize,armnn::Decoder<float> & batchVector,uint32_t nBatch,armnn::Encoder<float> & outResult)152  void VectorBatchVectorCwiseProduct(armnn::Decoder<float>& vector,
153                                     uint32_t vSize,
154                                     armnn::Decoder<float>& batchVector,
155                                     uint32_t nBatch,
156                                     armnn::Encoder<float>& outResult)
157  {
158      for (uint32_t b = 0; b < nBatch; b++)
159      {
160          for (uint32_t v = 0; v < vSize; v++)
161          {
162              outResult.Set(vector.Get() * batchVector.Get());
163              ++outResult;
164              ++vector;
165              ++batchVector;
166          }
167          vector -= vSize;
168      }
169      batchVector -= vSize * nBatch;
170      outResult -= vSize * nBatch;
171  }
172  
Sub1Vector(armnn::Decoder<float> & vector,uint32_t vSize,armnn::Encoder<float> & result)173  void Sub1Vector(armnn::Decoder<float>& vector,
174                  uint32_t vSize,
175                  armnn::Encoder<float>& result)
176  {
177      for (uint32_t v = 0; v < vSize; v++)
178      {
179          result.Set(1.0f - vector.Get());
180          ++vector;
181          ++result;
182      }
183      vector -= vSize;
184      result -= vSize;
185  }
186  
VectorVectorCwiseProduct(armnn::Decoder<float> & vector1,armnn::Decoder<float> & vector2,uint32_t vSize,armnn::Encoder<float> & outResult)187  void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1,
188                                armnn::Decoder<float>& vector2,
189                                uint32_t vSize,
190                                armnn::Encoder<float>& outResult)
191  {
192      for (uint32_t v = 0; v < vSize; v++)
193      {
194          outResult.Set(vector1.Get() * vector2.Get());
195          ++outResult;
196          ++vector1;
197          ++vector2;
198      }
199      outResult -= vSize;
200      vector1 -= vSize;
201      vector2 -= vSize;
202  }
203  
VectorVectorCwiseProductAccumulate(armnn::Decoder<float> & vector1,armnn::Decoder<float> & vector2,uint32_t vSize,armnn::Encoder<float> & outResult)204  void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1,
205                                          armnn::Decoder<float>& vector2,
206                                          uint32_t vSize,
207                                          armnn::Encoder<float>& outResult)
208  {
209      for (uint32_t v = 0; v < vSize; v++)
210      {
211          outResult.Set(outResult.Get() + vector1.Get() * vector2.Get());
212          ++outResult;
213          ++vector1;
214          ++vector2;
215      }
216      outResult -= vSize;
217      vector1 -= vSize;
218      vector2 -= vSize;
219  }
220  
// Clamps f to [-absLimit, absLimit] (for a non-negative absLimit).
// The upper bound is applied first, then the lower bound. A NaN input is returned unchanged
// because both comparisons are false for NaN.
float Clip(float f,
           float absLimit)
{
    float clipped = f;
    if (f > absLimit)
    {
        clipped = absLimit;
    }
    if (clipped < -absLimit)
    {
        clipped = -absLimit;
    }
    return clipped;
}
228  
ClipVector(armnn::Decoder<float> & vector,uint32_t vSize,float absLimit,armnn::Encoder<float> & outResult)229  void ClipVector(armnn::Decoder<float>& vector,
230                  uint32_t vSize,
231                  float absLimit,
232                  armnn::Encoder<float>& outResult)
233  {
234      for (uint32_t v = 0; v < vSize; v++)
235      {
236          outResult.Set(Clip(vector.Get(), absLimit));
237          ++vector;
238          ++outResult;
239      }
240      vector -= vSize;
241      outResult -= vSize;
242  }
243  
CopyVector(armnn::Decoder<float> & vector,uint32_t vSize,armnn::Encoder<float> & outResult)244  void CopyVector(armnn::Decoder<float>& vector,
245                  uint32_t vSize,
246                  armnn::Encoder<float>& outResult)
247  {
248      for (uint32_t v = 0; v < vSize; v++)
249      {
250          outResult.Set(vector.Get());
251          ++outResult;
252          ++vector;
253      }
254      outResult -= vSize;
255      vector -= vSize;
256  }
257  
SetActivationParameters(uint32_t activation,armnn::ActivationFunction & outArmnnActivation,float & outA,float & outB)258  void SetActivationParameters(uint32_t activation,
259                               armnn::ActivationFunction& outArmnnActivation,
260                               float& outA,
261                               float& outB)
262  {
263      switch (activation)
264      {
265          case 0: // None
266              outA = 0;
267              outB = 0;
268              return;
269  
270          case 1: // Relu
271              outArmnnActivation = armnn::ActivationFunction::ReLu;
272              outA = 0;
273              outB = 0;
274              return;
275  
276          case 3: // Relu6
277              outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
278              outA = 6;
279              outB = 0;
280              return;
281  
282          case 4: // Tanh
283              outArmnnActivation = armnn::ActivationFunction::TanH;
284              outA = 1;
285              outB = 1;
286              return;
287  
288          case 6: // Sigmoid
289              outArmnnActivation = armnn::ActivationFunction::Sigmoid;
290              outA = 0;
291              outB = 0;
292              return;
293  
294          default:
295              throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
296      }
297  }
298  
AssignScopedTensorHandle(const armnn::ConstTensorHandle * ptr)299  std::unique_ptr<armnn::ScopedTensorHandle> AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
300  {
301      if (!ptr)
302      {
303          return nullptr;
304      }
305  
306      return std::make_unique<armnn::ScopedTensorHandle>(*ptr);
307  }
308