#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

namespace {

// out = val * a + b
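// Fused "scale and add mask" kernel. The accumulation type T1 (float) may
// differ from the mask type T2 (e.g. bf16/fp16); in that case one T2 vector
// holds twice as many lanes as a T1 vector, so two T1 vectors (T1_n == 2) are
// processed per iteration and the mask lanes are up-converted to T1 before
// the multiply-add. A scalar loop handles the remainder.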
template <typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
    T1* a,
    T2* b,
    const int& size,
    T1* out,
    T1& val) {
  const auto vec_size1 = at::vec::Vectorized<T1>::size();
  const auto vec_size2 = at::vec::Vectorized<T2>::size();
  constexpr int64_t T1_n = (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
  constexpr int64_t T2_n = 1;
  auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
  int64_t i = 0;
  for (; i < size - (size % vec_size2); i += vec_size2) {
    auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
    auto b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
    auto b_n_convert = at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
    auto res = a_n * vec_scale + b_n_convert;
    res.store(out + i);
  }
  for (; i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = (T1) b[i];
    out[i] = tmp0 * val + tmp1;
  }
}

// 1) out = exp(a - val)
// 2) val = sum(out)
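// Vectorized softmax numerator: subtracts the running max `val`, exponentiates
// with exp_u20() (a faster, reduced-accuracy vector exp), stores the result,
// and accumulates the row sum; the scalar tail uses plain exp(). On return
// `val` holds the sum of `out`.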
template <typename T1, typename T2>
inline void _exp_reduce_sum_fusion_kernel(
    T1* a,
    const int& size,
    T2* out,
    T1& val) {
  auto vec_size = vec::Vectorized<T1>::size();
  auto vec_max = vec::Vectorized<T1>(val);
  T1 tmp_sum = 0;
  auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
    auto tmp1 = tmp0 - vec_max;
    auto tmp2 = tmp1.exp_u20();
    vec_tmp_sum += tmp2;
    _store(out + i, tmp2);
  }
  tmp_sum = vec::vec_reduce_all<T1>(
      [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) {
        return x + y;
      },
      vec_tmp_sum);
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 - val;
    auto tmp2 = exp(tmp1);
    tmp_sum += tmp2;
    out[i] = tmp2;
  }
  val = tmp_sum;
}

// 1) out = a * scale
// 2) max = max(out)
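// Fused "scale and row max": multiplies each element by `scale`, stores the
// scaled row, and tracks the running maximum (vector lanes are reduced at the
// end, with a separate scalar tail).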
template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
    const scalar_t* a,
    const scalar_t& scale,
    const int& size,
    scalar_t* out,
    scalar_t& max) {
  auto vec_size = vec::Vectorized<scalar_t>::size();
  auto vec_scale = vec::Vectorized<scalar_t>(scale);
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
    _store(out + i, tmp1);
  }
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    tmp_max = std::max(tmp_max, tmp1);
    out[i] = tmp1;
  }
  max = std::max(
      tmp_max,
      vec::vec_reduce_all<scalar_t>(
          [](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
            return vec::maximum(x, y);
          },
          vec_tmp_max));
}

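// conditional_data_ptr: pick the buffer that matches the dtype expected by the
// subsequent gemm. For full-precision scalar_t the float accumulation buffer is
// used directly (ptr2 must be null); for reduced types (bf16/fp16) the
// down-converted buffer ptr2 is returned instead.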
template <typename scalar_t>
static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
  TORCH_CHECK(ptr2 == nullptr);
  return ptr;
}

template <typename scalar_t,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) {
  return ptr2;
}

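// Fill `size` elements of `data` with `val` (vectorized stores plus a scalar tail).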
template <typename scalar_t>
inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
  using Vec = Vectorized<scalar_t>;
  Vec data_vec = Vec(val);
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    data_vec.store(data + d);
  }
  #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
  # pragma unroll
  #endif
  for (; d < size; d++) {
    data[d] = val;
  }
}

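// View/expand the (already validated) 2d or 4d attention mask as a
// broadcastable 4d tensor of shape
//   [batchSize or 1, num_head or 1, qSize, kvSize].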
void reshape_attn_mask_to_4d(
    Tensor& attn_mask,
    int64_t batchSize,
    int64_t num_head,
    int64_t qSize,
    int64_t kvSize) {
  // Support mask shapes:
  // 2d: ({Q_seq_len, 1}  x {KV_seq_len, 1})
  // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1}  x {KV_seq_len, 1})
  // Guaranteed in check_attn_mask_shape
  int64_t attn_mask_size_0 = 1;
  int64_t attn_mask_size_1 = 1;
  if (attn_mask.dim() == 4) {
    if (attn_mask.size(0) == batchSize) {
      attn_mask_size_0 = batchSize;
    }
    if (attn_mask.size(1) == num_head) {
      attn_mask_size_1 = num_head;
    }
  }
  attn_mask = attn_mask
                .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)})
                .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
}

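// Blocked flash-attention forward for CPU. The query is tiled into blocks of
// q_split_size rows and the key/value into blocks of kv_split_size columns;
// each (batch, head, q-block) triple is processed independently in
// parallel_for using a per-thread scratch buffer. Softmax is computed online:
// a running max and running sum are kept per query row, and the partial
// output is rescaled whenever the max grows (see the comments inside the
// kv loop below).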
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // Query (Batch x Num_heads  x Q_seq_len  x Dim_per_head)
  //    -> (Batch x Q_seq_len  x Num_heads  x Dim_per_head)
  // Key   (Batch x Num_heads  x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
  // Value (Batch x Num_heads  x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
  at::Tensor query = q.transpose(1, 2);
  at::Tensor key = k.transpose(1, 2);
  at::Tensor value = v.transpose(1, 2);

  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
        "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t headSize = query.size(3);

  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = output.stride(0);
  int64_t oStrideM = output.stride(1);
  int64_t oStrideH = output.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);
  int64_t mStrideB =
      (has_attn_mask && attn_mask.value().size(0) > 1)
      ? attn_mask.value().stride(0)
      : 0;
  int64_t mStrideH =
      (has_attn_mask && attn_mask.value().size(1) > 1)
      ? attn_mask.value().stride(1)
      : 0;
  int64_t mStrideM =
      has_attn_mask ? attn_mask.value().stride(2) : 0;

  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t qSlice = (qSize - 1) / qSplitSize + 1;
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* qk     */ qSplitSize * kvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * headSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
  at::Tensor buf_reduced = at::empty({num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, query.options());

  // Data ptrs
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  mask_t* mask_data = has_attn_mask
      ? attn_mask.value().data_ptr<mask_t>()
      : nullptr;
  scalar_t* out_data = output.data_ptr<scalar_t>();
  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0, k = 0;
    data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* qk_data = buf_ptr;
    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
    accum_t* qk_sum_data = qk_max_data + qSplitSize;
    accum_t* dst_data = qk_sum_data + qSplitSize;
    scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize : nullptr;

    for (const auto z : c10::irange(begin, end)) {
      (void)z; // Suppress unused variable
      int64_t m = k * qSplitSize;
      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
      // Initialize max and sum
      fill_stub(qk_max_data,
          -std::numeric_limits<accum_t>::infinity(), qBlockSize);
      fill_stub(qk_sum_data,
          static_cast<accum_t>(0), qBlockSize);
      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        // Calculate q @ k.T (the scaling factor is applied below)
        cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            static_cast<accum_t>(1),
            k_data + i * kStrideB + j * kStrideH +
                n * kStrideN,
            kStrideN,
            q_data + i * qStrideB + j * qStrideH +
                m * qStrideM,
            qStrideM,
            static_cast<accum_t>(0),
            qk_data,
            kvBlockSize);
        // Apply causal mask, fill unused with -inf
        if (is_causal && num_keys - n <= kvSplitSize) {
          for (const auto row : c10::irange(qBlockSize)) {
            int64_t last_col = m + row - n;
            accum_t* row_ptr = qk_data + row * kvBlockSize;
            fill_stub(row_ptr + last_col + 1,
                -std::numeric_limits<accum_t>::infinity(),
                kvBlockSize - last_col - 1);
          }
        }
        // Update attention weights with attention mask
        // And apply scaling factor
        // qk <- qk * scaling + attn_mask
        if (has_attn_mask) {
          for (int64_t row = 0; row < qBlockSize; ++row) {
            _scale_attn_mask_fusion_kernel(
                qk_data + row * kvBlockSize,
                mask_data + i * mStrideB + j * mStrideH +
                        (m + row) * mStrideM + n,
                kvBlockSize,
                qk_data + row * kvBlockSize,
                scaling_factor);
          }
        }
        // Update coefficients with Softmax
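        // Online softmax update per row, with this block's scaled scores s:
        //   max_new    = max(max[row], max_j s_j)
        //   sum_new    = sum_j exp(s_j - max_new) + exp(max[row] - max_new) * sum[row]
        //   dst[row]  *= exp(max[row] - max_new)   (rescale previously accumulated output)
        //   max[row], sum[row] = max_new, sum_new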
        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
        for (int64_t row = 0; row < qBlockSize; ++row) {
          if (has_attn_mask) {
            // max per row
            tmp_max = at::vec::reduce_all<accum_t>(
                [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
                qk_data + row * kvBlockSize,
                kvBlockSize);
          } else {
            // apply scaling factor and max per row in fusion
            _mul_reduce_max_fusion_kernel(
                qk_data + row * kvBlockSize,
                scaling_factor,
                kvBlockSize,
                qk_data + row * kvBlockSize,
                tmp_max);
          }
          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
          if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
            // to avoid `nan = exp2f(-inf - (-inf))`
            fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize,
              static_cast<scalar_t>(0), kvBlockSize);
          } else {
            tmp_sum = tmp_max;
            // qk <- exp(qk - max) and sum per row
            _exp_reduce_sum_fusion_kernel(
                qk_data + row * kvBlockSize, kvBlockSize,
                conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize,
                tmp_sum);
            // exp_tmp <- exp(max[row] - max)
            exp_tmp = std::exp(qk_max_data[row] - tmp_max);
            // sum[row] <- sum + exp_tmp * sum[row]
            qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
            // max[row] <- max
            qk_max_data[row] = tmp_max;
            // dst <- dst * exp_tmp
            if (n > 0) {
              vec::map<accum_t>(
                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
                dst_data + row * headSize, dst_data + row * headSize, headSize);
            }
          }
        }
        // Calculate Softmax(q @ k.T) @ v
        cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::NoTranspose,
            headSize,
            qBlockSize,
            kvBlockSize,
            static_cast<accum_t>(1),
            v_data + i * vStrideB + j * vStrideH +
                n * vStrideN,
            vStrideN,
            conditional_data_ptr(qk_data, qk_reduced_data),
            kvBlockSize,
            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
            dst_data,
            headSize);
      }
      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
          [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
          out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
          dst_data + row * headSize,
          headSize);
      }
      // Store logsumexp for backward
      accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
      for (const auto row : c10::irange(qBlockSize)) {
        lse_ptr[row * lStrideM] = qk_max_data[row]
            + std::log(qk_sum_data[row]);
      }
      // Move to the next query
      data_index_step(i, batchSize, j, num_head, k, qSlice);
    }
  });

}

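// Blocked flash-attention backward for CPU. The softmax probabilities are not
// saved by the forward pass; each (batch, head) pair recomputes them block by
// block from q, k and the stored logsumexp, then accumulates
//   grad_v += attn.T @ grad_out
//   grad_attn = attn * (grad_out @ v.T - rowsum(grad_out * out))
//   grad_q  += scale * grad_attn @ k
//   grad_k  += scale * grad_attn.T @ q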
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention_backward(
    const at::Tensor& grad_q,
    const at::Tensor& grad_k,
    const at::Tensor& grad_v,
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
        "scaled_dot_product_attention_flash_attention_backward: Q/K/V should have the same head size");
  // Query (Batch x Q_seq_len  x Num_heads x Dim_per_head)
  // Key   (Batch x KV_seq_len x Num_heads x Dim_per_head)
  // Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t headSize = query.size(3);

  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = out.stride(0);
  int64_t oStrideM = out.stride(1);
  int64_t oStrideH = out.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);
  int64_t mStrideB =
      (has_attn_mask && attn_mask.value().size(0) > 1)
      ? attn_mask.value().stride(0)
      : 0;
  int64_t mStrideH =
      (has_attn_mask && attn_mask.value().size(1) > 1)
      ? attn_mask.value().stride(1)
      : 0;
  int64_t mStrideM =
      has_attn_mask ? attn_mask.value().stride(2) : 0;

  int64_t grad_qStrideB = grad_q.stride(0);
  int64_t grad_qStrideM = grad_q.stride(1);
  int64_t grad_qStrideH = grad_q.stride(2);
  int64_t grad_kStrideB = grad_k.stride(0);
  int64_t grad_kStrideN = grad_k.stride(1);
  int64_t grad_kStrideH = grad_k.stride(2);
  int64_t grad_vStrideB = grad_v.stride(0);
  int64_t grad_vStrideN = grad_v.stride(1);
  int64_t grad_vStrideH = grad_v.stride(2);
  int64_t grad_oStrideB = grad_out.stride(0);
  int64_t grad_oStrideM = grad_out.stride(1);
  int64_t grad_oStrideH = grad_out.stride(2);

  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* attn      */ qSplitSize * kvSplitSize +
      /* grad_attn */ qSplitSize * kvSplitSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));

  // allocate per thread temp buf_reduced (scalar type)
  // buf2 is only needed for bfloat16 and float16
  int64_t size_per_thread_reduced =
      /* attn_reduced      */ qSplitSize * kvSplitSize +
      /* grad_attn_reduced */ qSplitSize * kvSplitSize;

  at::Tensor buf_reduced = at::empty({num_thread, is_reduced_type ? size_per_thread_reduced : 0}, query.options());

  scalar_t* grad_q_data = grad_q.data_ptr<scalar_t>();
  scalar_t* grad_k_data = grad_k.data_ptr<scalar_t>();
  scalar_t* grad_v_data = grad_v.data_ptr<scalar_t>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  mask_t* mask_data = has_attn_mask
      ? attn_mask.value().data_ptr<mask_t>()
      : nullptr;
  const scalar_t* out_data = out.const_data_ptr<scalar_t>();
  const accum_t* lse_data = logsumexp.const_data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0;
    data_index_init(begin, i, batchSize, j, num_head);
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* attn_data = buf_ptr;
    accum_t* grad_attn_data = attn_data + qSplitSize * kvSplitSize;
    scalar_t* buf_reduced_ptr = is_reduced_type ? buf_reduced_data + ompIdx * size_per_thread_reduced : nullptr;
    scalar_t* attn_reduced_data = is_reduced_type ? buf_reduced_ptr : nullptr;
    scalar_t* grad_attn_reduced_data = is_reduced_type ? attn_reduced_data + qSplitSize * kvSplitSize : nullptr;

    at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype));
    accum_t* dsum_data = dsum.data_ptr<accum_t>();
    for (const auto z : c10::irange(begin, end)) {
      (void)z; // Suppress unused variable
      // rowsum of grad_out * out
      for (int64_t m = 0; m < qSize; m += qSplitSize) {
        int64_t qBlockSize = std::min(qSplitSize, qSize - m);
        // dsum <- rowsum(grad_out * out)
        for (const auto row : c10::irange(qBlockSize)) {
          *(dsum_data + row) = vec::map2_reduce_all<scalar_t>(
            [](Vec x, Vec y) { return x * y; },
            [](Vec x, Vec y) { return x + y; },
            grad_out_data + i * grad_oStrideB + j * grad_oStrideH + (m + row) * grad_oStrideM,
            out_data + i * oStrideB + j * oStrideH + (m + row) * oStrideM,
            headSize);
        }
        int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
        for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
          int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
          // attn <- scale * q @ k.T
          cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            scaling_factor,
            k_data + i * kStrideB + j * kStrideH +
                n * kStrideN,
            kStrideN,
            q_data + i * qStrideB + j * qStrideH +
                m * qStrideM,
            qStrideM,
            static_cast<accum_t>(0),
            attn_data,
            kvBlockSize);
          // attn <- attn + mask
          if (has_attn_mask) {
            accum_t one = accum_t(1);
            for (const auto row : c10::irange(qBlockSize)) {
              _scale_attn_mask_fusion_kernel(
                  attn_data + row * kvBlockSize,
                  mask_data + i * mStrideB + j * mStrideH +
                      (m + row) * mStrideM + n,
                  kvBlockSize,
                  attn_data + row * kvBlockSize,
                  one);
            }
          }
          // restore self attention after softmax from logsumexp
          // attn <- exp(attn - normalizer)
          for (const auto row : c10::irange(qBlockSize)) {
            accum_t normalizer = lse_data[i * lStrideB + j * lStrideH + (m + row) * lStrideM];
            vec::map<accum_t>(
              [normalizer](Vec x) { return (x - Vec(normalizer)).exp(); },
              attn_data + row * kvBlockSize,
              attn_data + row * kvBlockSize,
              kvBlockSize);
          }
          // Apply causal mask, fill unused with 0
          if (is_causal && num_keys - n <= kvSplitSize) {
            for (const auto row : c10::irange(qBlockSize)) {
              int64_t last_col = m + row - n;
              accum_t* row_ptr = attn_data + row * kvBlockSize;
              fill_stub(row_ptr + last_col + 1, static_cast<accum_t>(0), kvBlockSize - last_col - 1);
            }
          }
          if (is_reduced_type) {
            for (const auto row : c10::irange(qBlockSize)) {
              convert<accum_t, scalar_t>(
                attn_data + row * kvBlockSize,
                attn_reduced_data + row * kvBlockSize,
                kvBlockSize);
            }
          }
          // grad_v <- grad_v + attn.T @ grad_out
          cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::Transpose,
            headSize,
            kvBlockSize,
            qBlockSize,
            static_cast<accum_t>(1),
            grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
                m * grad_oStrideM,
            grad_oStrideM,
            conditional_data_ptr(attn_data, attn_reduced_data),
            kvBlockSize,
            static_cast<accum_t>(1),
            grad_v_data + i * grad_vStrideB + j * grad_vStrideH +
                n * grad_vStrideN,
            grad_vStrideN);
          // grad_attn <- grad_out @ v.T
          cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            static_cast<accum_t>(1),
            v_data + i * vStrideB + j * vStrideH +
                n * vStrideN,
            vStrideN,
            grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
                m * grad_oStrideM,
            grad_oStrideM,
            static_cast<accum_t>(0),
            grad_attn_data,
            kvBlockSize);
          // grad_attn <- attn * (grad_attn - dsum)
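          // (softmax backward: dS = P * (dP - rowsum(dP * P)), where
          //  rowsum(dP * P) == rowsum(grad_out * out) == dsum)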
          for (const auto row : c10::irange(qBlockSize)) {
            accum_t d = *(dsum_data + row);
            vec::map2<accum_t>(
              [d](Vec attn, Vec grad_attn) { return attn * (grad_attn - Vec(d)); },
              grad_attn_data + row * kvBlockSize,
              attn_data + row * kvBlockSize,
              grad_attn_data + row * kvBlockSize,
              kvBlockSize);
          }
          if (is_reduced_type) {
            for (const auto row : c10::irange(qBlockSize)) {
              convert<accum_t, scalar_t>(
                grad_attn_data + row * kvBlockSize,
                grad_attn_reduced_data + row * kvBlockSize,
                kvBlockSize);
            }
          }
          // grad_q <- grad_q + scale * grad_attn @ k
          cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::NoTranspose,
            headSize,
            qBlockSize,
            kvBlockSize,
            scaling_factor,
            k_data + i * kStrideB + j * kStrideH +
                n * kStrideN,
            kStrideN,
            conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
            kvBlockSize,
            static_cast<accum_t>(1),
            grad_q_data + i * grad_qStrideB + j * grad_qStrideH +
                m * grad_qStrideM,
            grad_qStrideM);
          // grad_k <- grad_k + scale * grad_attn.T @ q
          cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::Transpose,
            headSize,
            kvBlockSize,
            qBlockSize,
            scaling_factor,
            q_data + i * qStrideB + j * qStrideH +
                m * qStrideM,
            qStrideM,
            conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
            kvBlockSize,
            static_cast<accum_t>(1),
            grad_k_data + i * grad_kStrideB + j * grad_kStrideH +
                n * grad_kStrideN,
            grad_kStrideN);
        }
      }
      // Move to the next query
      data_index_step(i, batchSize, j, num_head);
    }
  });
}

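// Dispatch over the supported attention-mask dtypes (Bool, Float, Double,
// BFloat16, Half), binding the selected type to `mask_t` inside the lambda.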
#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...)            \
  AT_DISPATCH_SWITCH(                                      \
      TYPE,                                                \
      NAME,                                                \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                     \
          at::ScalarType::Bool, mask_t, __VA_ARGS__)       \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                     \
          at::ScalarType::Float, mask_t, __VA_ARGS__)      \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                     \
          at::ScalarType::Double, mask_t, __VA_ARGS__)     \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                     \
          at::ScalarType::BFloat16, mask_t, __VA_ARGS__)   \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                     \
          at::ScalarType::Half, mask_t, __VA_ARGS__))

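// Entry point registered as the CPU flash_attention kernel. The q block size
// is chosen from the query sequence length (256 for >= 768, 64 for >= 192,
// otherwise 32) with a fixed kv block size of 512; when a mask is given, its
// dtype is dispatched separately so mixed mask/input dtypes are supported.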
void flash_attention_kernel_impl(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  auto q_seq_len = query.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention", [&] {
    if (!attn_mask.has_value()) {
      if (q_seq_len >= 768) {
        cpu_flash_attention<scalar_t, scalar_t, 256, 512>(
          output, logsumexp, query, key, value,
          dropout_p, is_causal, attn_mask, scale);
      } else if (q_seq_len >= 192) {
        cpu_flash_attention<scalar_t, scalar_t, 64, 512>(
          output, logsumexp, query, key, value,
          dropout_p, is_causal, attn_mask, scale);
      } else {
        cpu_flash_attention<scalar_t, scalar_t, 32, 512>(
          output, logsumexp, query, key, value,
          dropout_p, is_causal, attn_mask, scale);
      }
    } else {
      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask", [&]() {
        if (q_seq_len >= 768) {
          cpu_flash_attention<scalar_t, mask_t, 256, 512>(
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
        } else if (q_seq_len >= 192) {
          cpu_flash_attention<scalar_t, mask_t, 64, 512>(
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
        } else {
          cpu_flash_attention<scalar_t, mask_t, 32, 512>(
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
        }
      });
    }
  });
}

void flash_attention_backward_kernel_impl(
    const at::Tensor& grad_q,
    const at::Tensor& grad_k,
    const at::Tensor& grad_v,
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // Make sure grad_out has no zero strides (broadcasted dimensions),
  // since we are going to call gemm next;
  // a zero stride in the leading dimension would force gemm onto a slow path.
  auto grad_out_contig = grad_out.contiguous();
  auto q_seq_len = query.size(1);

  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention_backward", [&] {
    if (!attn_mask.has_value() || !attn_mask.value().defined()) {
      using accum_t = at::opmath_type<scalar_t>;
      if (q_seq_len >= 768) {
        cpu_flash_attention_backward<scalar_t, accum_t, 256, 512>(
          grad_q, grad_k, grad_v, grad_out_contig,
          query, key, value, out, logsumexp,
          dropout_p, is_causal, attn_mask, scale);
      } else if (q_seq_len >= 192) {
        cpu_flash_attention_backward<scalar_t, accum_t, 64, 512>(
          grad_q, grad_k, grad_v, grad_out_contig,
          query, key, value, out, logsumexp,
          dropout_p, is_causal, attn_mask, scale);
      } else {
        cpu_flash_attention_backward<scalar_t, accum_t, 32, 512>(
          grad_q, grad_k, grad_v, grad_out_contig,
          query, key, value, out, logsumexp,
          dropout_p, is_causal, attn_mask, scale);
      }
    } else {
      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask_backward", [&]() {
        if (q_seq_len >= 768) {
          cpu_flash_attention_backward<scalar_t, mask_t, 256, 512>(
            grad_q, grad_k, grad_v, grad_out_contig,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
        } else if (q_seq_len >= 192) {
          cpu_flash_attention_backward<scalar_t, mask_t, 64, 512>(
            grad_q, grad_k, grad_v, grad_out_contig,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
        } else {
          cpu_flash_attention_backward<scalar_t, mask_t, 32, 512>(
            grad_q, grad_k, grad_v, grad_out_contig,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
        }
      });
    }
  });
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(flash_attention_kernel, &flash_attention_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_backward_kernel, &flash_attention_backward_kernel_impl);

} // at::native