/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/utils/math_utils.h>
12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/utils/unroll.h>
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/parallel/thread_parallel.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/portable_type/bfloat16.h>
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker #include <array>
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
20*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
21*523fa7a6SAndroid Build Coastguard Worker
/// Scales an m-by-n column-major matrix in place: a(i, j) *= alpha.
///
/// Element (i, j) is stored at a[j * lda + i].
/// - alpha == 1 returns immediately (identity).
/// - alpha == 0 stores zeros instead of multiplying, so non-finite values
///   already present in `a` do not survive.
///
/// @param m     Number of rows to touch; m <= 0 touches nothing.
/// @param n     Number of columns to touch; n <= 0 touches nothing.
/// @param alpha Scale factor.
/// @param a     Matrix data, modified in place.
/// @param lda   Leading dimension (distance between consecutive columns).
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t* a, int64_t lda) {
  if (alpha == opmath_t(1)) {
    return; // identity
  }

  // Loop counters are int64_t to match the extents: an unsigned counter
  // would convert a negative m/n into a huge bound and write out of range.
  if (alpha == opmath_t(0)) {
    for (int64_t j = 0; j < n; ++j) {
      for (int64_t i = 0; i < m; ++i) {
        a[j * lda + i] = scalar_t(0);
      }
    }
    return;
  }

  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      a[j * lda + i] *= alpha;
    }
  }
}
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker template <typename Func>
sum(int64_t N,Func f)45*523fa7a6SAndroid Build Coastguard Worker auto sum(int64_t N, Func f) {
46*523fa7a6SAndroid Build Coastguard Worker constexpr int ilp_factor = 4;
47*523fa7a6SAndroid Build Coastguard Worker using acc_t = decltype(f(0));
48*523fa7a6SAndroid Build Coastguard Worker
49*523fa7a6SAndroid Build Coastguard Worker // Calculate independent partial sums then add together at the end
50*523fa7a6SAndroid Build Coastguard Worker std::array<acc_t, ilp_factor> partial_sums{};
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker size_t i = 0;
53*523fa7a6SAndroid Build Coastguard Worker for (; i + ilp_factor <= N; i += ilp_factor) {
54*523fa7a6SAndroid Build Coastguard Worker utils::ForcedUnroll<ilp_factor>{}(
55*523fa7a6SAndroid Build Coastguard Worker [&i, &f, &partial_sums](int k) { partial_sums[k] += f(i + k); });
56*523fa7a6SAndroid Build Coastguard Worker }
57*523fa7a6SAndroid Build Coastguard Worker for (; i < N; ++i) {
58*523fa7a6SAndroid Build Coastguard Worker partial_sums[0] += f(i);
59*523fa7a6SAndroid Build Coastguard Worker }
60*523fa7a6SAndroid Build Coastguard Worker for (int k = 1; k < ilp_factor; ++k) {
61*523fa7a6SAndroid Build Coastguard Worker partial_sums[0] += partial_sums[k];
62*523fa7a6SAndroid Build Coastguard Worker }
63*523fa7a6SAndroid Build Coastguard Worker return partial_sums[0];
64*523fa7a6SAndroid Build Coastguard Worker }
65*523fa7a6SAndroid Build Coastguard Worker
/// Computes c = beta * c + alpha * (a @ b) when the storage type is also the
/// arithmetic type (scalar_t == opmath_t).
///
/// All matrices are column-major with leading dimensions lda/ldb/ldc:
///   a: m x k, a(i, l) = a[l * lda + i]
///   b: k x n, b(l, j) = b[j * ldb + l]
///   c: m x n, c(i, j) = c[j * ldc + i]
///
/// c is first scaled by beta via scale_ (beta == 0 overwrites c with zeros).
/// Then, for every (l, j), column l of a times alpha * b(l, j) is added into
/// column j of c. The innermost loop therefore walks contiguous memory in
/// both a and c.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b)
  // Rows are processed four at a time; the remainder is finished one by one.
  // The bound is loop-invariant, so it is computed once.
  const int64_t m_unrolled = (m / 4) * 4;
  // int64_t counters match the int64_t extents (no signed/unsigned mix).
  for (int64_t l = 0; l < k; ++l) {
    for (int64_t j = 0; j < n; ++j) {
      const opmath_t val = b[l + j * ldb] * alpha;
      int64_t i = 0;
      for (; i < m_unrolled; i += 4) {
        c[j * ldc + i + 0] += a[i + 0 + l * lda] * val;
        c[j * ldc + i + 1] += a[i + 1 + l * lda] * val;
        c[j * ldc + i + 2] += a[i + 2 + l * lda] * val;
        c[j * ldc + i + 3] += a[i + 3 + l * lda] * val;
      }
      for (; i < m; ++i) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}
101*523fa7a6SAndroid Build Coastguard Worker
// Reduced-precision variant, selected when scalar_t != opmath_t (e.g. BFloat16
// or Half storage with a wider arithmetic type).
//
// Computes c = beta * c + alpha * (a @ b) for column-major matrices:
//   a: m x k, a(i, l) = a[l * lda + i]
//   b: k x n, b(l, j) = b[j * ldb + l]
//   c: m x n, c(i, j) = c[j * ldc + i]
// Each output element comes from a dot product accumulated in opmath_t by
// sum(). The result is written back to c once, rather than accumulating into
// c in reduced precision. When beta == 0 the old contents of c are not read.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // int64_t counters match the int64_t extents (no signed/unsigned mix).
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[j * ldb + l]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
132*523fa7a6SAndroid Build Coastguard Worker
// clang-format off
// Computes c = alpha * (a.T @ b) + beta * c with column-major storage:
//   a is stored k x m, so row i of a.T is the contiguous run a + i * lda
//   b is stored k x n, so column j of b is the contiguous run b + j * ldb
//   c is m x n, c(i, j) = c[j * ldc + i]
// Each dot product walks two contiguous runs and is accumulated in opmath_t
// by sum(). When beta == 0 the old contents of c are not read.
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  const scalar_t *a_ = a;
  // int64_t counters match the int64_t extents (no signed/unsigned mix).
  for (int64_t i = 0; i < m; ++i) {
    const scalar_t *b_ = b;
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a_[l]) * static_cast<opmath_t>(b_[l]);
      });
      b_ += ldb;
      if (beta == opmath_t(0)) {
        c[j*ldc+i] = alpha*dot;
      } else {
        c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
      }
    }
    a_ += lda;
  }
}
160*523fa7a6SAndroid Build Coastguard Worker
#ifdef __aarch64__
namespace internal {
// Dot product of two length-`len` BFloat16 vectors, returned as float.
// Defined out of line; the name suggests fp32 accumulation (definition not
// visible here).
float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len);
} // namespace internal

// AArch64 BFloat16 specialization of gemm_transa_:
//   c = alpha * (a.T @ b) + beta * c
// Layout is as in the generic version: row i of a.T starts at a + i * lda,
// column j of b at b + j * ldb, and c(i, j) = c[j * ldc + i].
// Rows i are split across parallel_for tasks. Each task writes only the c
// entries of its own rows.
template <>
inline void gemm_transa_<torch::executor::BFloat16, torch::executor::BFloat16>(
    int64_t m, int64_t n, int64_t k,
    torch::executor::BFloat16 alpha,
    const torch::executor::BFloat16 *a, int64_t lda,
    const torch::executor::BFloat16 *b, int64_t ldb,
    torch::executor::BFloat16 beta,
    torch::executor::BFloat16 *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  // Fast path: alpha == 1 && beta == 0 stores the float dot product directly
  // (no scaling, c not read). The flag is loop-invariant.
  const bool plain_product = (alpha == 1 && beta == 0);
  executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    // int64_t counters: `int` would truncate begin/end/n beyond INT_MAX.
    for (int64_t i = begin; i < end; ++i) {
      const auto *b_ = b;
      for (int64_t j = 0; j < n; ++j) {
        const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
        b_ += ldb;
        if (plain_product) {
          c[j*ldc+i] = dot;
        } else if (beta == 0) {
          c[j*ldc+i] = alpha*dot;
        } else {
          c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
        }
      }
      a_ += lda;
    }
  });
}
#endif
208*523fa7a6SAndroid Build Coastguard Worker
// clang-format on
210*523fa7a6SAndroid Build Coastguard Worker
/// Computes c = beta * c + alpha * (a @ b.T) when the storage type is also the
/// arithmetic type (scalar_t == opmath_t).
///
/// All matrices are column-major with leading dimensions lda/ldb/ldc:
///   a: m x k,            a(i, l)   = a[l * lda + i]
///   b: stored n x k, so  b.T(l, j) = b[l * ldb + j]
///   c: m x n,            c(i, j)   = c[j * ldc + i]
///
/// c is first scaled by beta via scale_ (beta == 0 overwrites c with zeros).
/// Then, for every (l, j), column l of a times alpha * b.T(l, j) is added into
/// column j of c. The innermost loop therefore walks contiguous memory in
/// both a and c.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b.T)
  // Rows are processed four at a time; the remainder is finished one by one.
  // The bound is loop-invariant, so it is computed once.
  const int64_t m_unrolled = (m / 4) * 4;
  // int64_t counters match the int64_t extents (no signed/unsigned mix).
  for (int64_t l = 0; l < k; ++l) {
    for (int64_t j = 0; j < n; ++j) {
      const opmath_t val = b[j + l * ldb] * alpha;
      int64_t i = 0;
      for (; i < m_unrolled; i += 4) {
        c[j * ldc + i + 0] += a[i + 0 + l * lda] * val;
        c[j * ldc + i + 1] += a[i + 1 + l * lda] * val;
        c[j * ldc + i + 2] += a[i + 2 + l * lda] * val;
        c[j * ldc + i + 3] += a[i + 3 + l * lda] * val;
      }
      for (; i < m; ++i) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}
246*523fa7a6SAndroid Build Coastguard Worker
// Reduced-precision variant, selected when scalar_t != opmath_t (e.g. BFloat16
// or Half storage with a wider arithmetic type).
//
// Computes c = beta * c + alpha * (a @ b.T) for column-major matrices:
//   a: m x k,            a(i, l)   = a[l * lda + i]
//   b: stored n x k, so  b.T(l, j) = b[l * ldb + j]
//   c: m x n,            c(i, j)   = c[j * ldc + i]
// Each output element comes from a dot product accumulated in opmath_t by
// sum(). The result is written back to c once. When beta == 0 the old
// contents of c are not read.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // int64_t counters match the int64_t extents (no signed/unsigned mix).
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[l * ldb + j]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
277*523fa7a6SAndroid Build Coastguard Worker
// clang-format off
// Computes c = beta * c + alpha * (a.T @ b.T) with column-major storage:
//   a is stored k x m, so a.T(i, l) = a[i * lda + l]
//   b is stored n x k, so b.T(l, j) = b[l * ldb + j]
//   c is m x n,           c(i, j)   = c[j * ldc + i]
// Each dot product is accumulated in opmath_t by sum(). When beta == 0 the
// old contents of c are not read.
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = beta * c + alpha * (a.T @ b.T)
  // int64_t counters match the int64_t extents (no signed/unsigned mix).
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[i * lda + l]) *
            static_cast<opmath_t>(b[l * ldb + j]);
      });

      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
// clang-format on
304*523fa7a6SAndroid Build Coastguard Worker
305*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
306*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
307