#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

namespace {

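// Vectorized helper kernels shared by the flash attention forward and
// backward kernels below. For reduced floating point types (BFloat16 / Half)
// the accumulation type is float; when one Vectorized<T2> carries twice as
// many lanes as one Vectorized<T1> (the reduced-precision mask case), the
// mask-fusion kernel loads a VectorizedN<T1, 2> so that one step lines up
// with a single Vectorized<T2>.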
// out = val * a + b
template <typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
    T1* a,
    T2* b,
    const int& size,
    T1* out,
    T1& val) {
  const auto vec_size1 = at::vec::Vectorized<T1>::size();
  const auto vec_size2 = at::vec::Vectorized<T2>::size();
  constexpr int64_t T1_n = (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
  constexpr int64_t T2_n = 1;
  auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
  int64_t i = 0;
  for (; i < size - (size % vec_size2); i += vec_size2) {
    auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
    auto b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
    auto b_n_convert = at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
    auto res = a_n * vec_scale + b_n_convert;
    res.store(out + i);
  }
  for (; i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = (T1) b[i];
    out[i] = tmp0 * val + tmp1;
  }
}

// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename T1, typename T2>
inline void _exp_reduce_sum_fusion_kernel(
    T1* a,
    const int& size,
    T2* out,
    T1& val) {
  auto vec_size = vec::Vectorized<T1>::size();
  auto vec_max = vec::Vectorized<T1>(val);
  T1 tmp_sum = 0;
  auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
    auto tmp1 = tmp0 - vec_max;
    auto tmp2 = tmp1.exp_u20();
    vec_tmp_sum += tmp2;
    _store(out + i, tmp2);
  }
  tmp_sum = vec::vec_reduce_all<T1>(
      [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) {
        return x + y;
      },
      vec_tmp_sum);
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 - val;
    auto tmp2 = exp(tmp1);
    tmp_sum += tmp2;
    out[i] = tmp2;
  }
  val = tmp_sum;
}

// 1) out = a * scale
// 2) max = max(out)
template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
    const scalar_t* a,
    const scalar_t& scale,
    const int& size,
    scalar_t* out,
    scalar_t& max) {
  auto vec_size = vec::Vectorized<scalar_t>::size();
  auto vec_scale = vec::Vectorized<scalar_t>(scale);
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
    _store(out + i, tmp1);
  }
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    tmp_max = std::max(tmp_max, tmp1);
    out[i] = tmp1;
  }
  max = std::max(
      tmp_max,
      vec::vec_reduce_all<scalar_t>(
          [](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
            return vec::maximum(x, y);
          },
          vec_tmp_max));
}

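// Select the buffer the post-softmax attention tile is written to: for
// float/double this is the float scratch buffer itself (ptr2 must be null),
// for BFloat16/Half it is the separate reduced-precision buffer that feeds
// the second gemm.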
template <typename scalar_t>
static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
  TORCH_CHECK(ptr2 == nullptr);
  return ptr;
}

template <typename scalar_t,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) {
  return ptr2;
}

template <typename scalar_t>
inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
  using Vec = Vectorized<scalar_t>;
  Vec data_vec = Vec(val);
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    data_vec.store(data + d);
  }
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
  for (; d < size; d++) {
    data[d] = val;
  }
}

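// Broadcast the attention mask to 4d (Batch x Num_heads x Q_seq_len x KV_seq_len).
// Batch / head dimensions that broadcast are kept at size 1 (the callers then
// use a stride of 0 for them), while broadcast query / key dimensions are
// expand()ed, which gives them a stride of 0 directly.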
void reshape_attn_mask_to_4d(
    Tensor& attn_mask,
    int64_t batchSize,
    int64_t num_head,
    int64_t qSize,
    int64_t kvSize) {
  // Support mask shapes:
  // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
  // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
  // Guaranteed in check_attn_mask_shape
  int64_t attn_mask_size_0 = 1;
  int64_t attn_mask_size_1 = 1;
  if (attn_mask.dim() == 4) {
    if (attn_mask.size(0) == batchSize) {
      attn_mask_size_0 = batchSize;
    }
    if (attn_mask.size(1) == num_head) {
      attn_mask_size_1 = num_head;
    }
  }
  attn_mask = attn_mask
                  .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)})
                  .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
}

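// Blocked flash attention forward.
// The query sequence is split into blocks of qSplitSize rows and the key/value
// sequence into blocks of kvSplitSize columns. For each (batch, head, q-block)
// a thread works on a float scratch buffer holding:
//   qk     : [qSplitSize x kvSplitSize] attention scores of the current tile
//   qk_max : [qSplitSize] running row-wise max (online softmax)
//   qk_sum : [qSplitSize] running row-wise sum of exp(qk - qk_max)
//   dst    : [qSplitSize x headSize] unnormalized output accumulator
// Key/value blocks are visited left to right; whenever a block raises a row
// max, that row's sum and dst are rescaled. The final output is dst / qk_sum,
// and logsumexp (= qk_max + log(qk_sum)) is saved for the backward pass.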
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
  //    -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
  // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
  // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
  at::Tensor query = q.transpose(1, 2);
  at::Tensor key = k.transpose(1, 2);
  at::Tensor value = v.transpose(1, 2);

  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
      "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t headSize = query.size(3);

  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = output.stride(0);
  int64_t oStrideM = output.stride(1);
  int64_t oStrideH = output.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);
  int64_t mStrideB =
      (has_attn_mask && attn_mask.value().size(0) > 1)
      ? attn_mask.value().stride(0)
      : 0;
  int64_t mStrideH =
      (has_attn_mask && attn_mask.value().size(1) > 1)
      ? attn_mask.value().stride(1)
      : 0;
  int64_t mStrideM =
      has_attn_mask ? attn_mask.value().stride(2) : 0;

  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t qSlice = (qSize - 1) / qSplitSize + 1;
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* qk     */ qSplitSize * kvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * headSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
  at::Tensor buf_reduced = at::empty({num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, query.options());
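  // buf_reduced holds the post-softmax attention tile converted back to the
  // input dtype (BFloat16/Half) so the second gemm (attn @ V) runs on
  // reduced-precision inputs; it is empty for float/double.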

  // Data ptrs
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  mask_t* mask_data = has_attn_mask
      ? attn_mask.value().data_ptr<mask_t>()
      : nullptr;
  scalar_t* out_data = output.data_ptr<scalar_t>();
  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0, k = 0;
    data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* qk_data = buf_ptr;
    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
    accum_t* qk_sum_data = qk_max_data + qSplitSize;
    accum_t* dst_data = qk_sum_data + qSplitSize;
    scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize : nullptr;

    for (const auto z : c10::irange(begin, end)) {
      (void)z; // Suppress unused variable warning
      int64_t m = k * qSplitSize;
      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
      // Initialize max and sum
      fill_stub(qk_max_data,
          -std::numeric_limits<accum_t>::infinity(), qBlockSize);
      fill_stub(qk_sum_data,
          static_cast<accum_t>(0), qBlockSize);
      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        // Calculate scale * q @ k.T
        cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            static_cast<accum_t>(1),
            k_data + i * kStrideB + j * kStrideH +
                n * kStrideN,
            kStrideN,
            q_data + i * qStrideB + j * qStrideH +
                m * qStrideM,
            qStrideM,
            static_cast<accum_t>(0),
            qk_data,
            kvBlockSize);
        // Apply causal mask, fill unused with -inf
        if (is_causal && num_keys - n <= kvSplitSize) {
          for (const auto row : c10::irange(qBlockSize)) {
            int64_t last_col = m + row - n;
            accum_t* row_ptr = qk_data + row * kvBlockSize;
            fill_stub(row_ptr + last_col + 1,
                -std::numeric_limits<accum_t>::infinity(),
                kvBlockSize - last_col - 1);
          }
        }
        // Update attention weights with attention mask
        // And apply scaling factor
        // qk <- qk * scaling + attn_mask
        if (has_attn_mask) {
          for (int64_t row = 0; row < qBlockSize; ++row) {
            _scale_attn_mask_fusion_kernel(
                qk_data + row * kvBlockSize,
                mask_data + i * mStrideB + j * mStrideH +
                    (m + row) * mStrideM + n,
                kvBlockSize,
                qk_data + row * kvBlockSize,
                scaling_factor);
          }
        }
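        // Online softmax update for this key/value block. With
        //   m_new   = max(qk_max[row], rowmax(qk_row))
        //   exp_tmp = exp(qk_max[row] - m_new)
        // the loop below computes
        //   qk_row      <- exp(qk_row - m_new)
        //   qk_sum[row] <- rowsum(qk_row) + exp_tmp * qk_sum[row]
        //   qk_max[row] <- m_new
        //   dst_row     <- exp_tmp * dst_row   (rescale previous blocks)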
        // Update coefficients with Softmax
        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
        for (int64_t row = 0; row < qBlockSize; ++row) {
          if (has_attn_mask) {
            // max per row
            tmp_max = at::vec::reduce_all<accum_t>(
                [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
                qk_data + row * kvBlockSize,
                kvBlockSize);
          } else {
            // apply scaling factor and max per row in fusion
            _mul_reduce_max_fusion_kernel(
                qk_data + row * kvBlockSize,
                scaling_factor,
                kvBlockSize,
                qk_data + row * kvBlockSize,
                tmp_max);
          }
          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
          if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
            // to avoid `nan = exp2f(-inf - (-inf))`
            fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize,
                static_cast<scalar_t>(0), kvBlockSize);
          } else {
            tmp_sum = tmp_max;
            // qk <- exp(qk - max) and sum per row
            _exp_reduce_sum_fusion_kernel(
                qk_data + row * kvBlockSize, kvBlockSize,
                conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize,
                tmp_sum);
            // exp_tmp <- exp(max[row] - max)
            exp_tmp = std::exp(qk_max_data[row] - tmp_max);
            // sum[row] <- sum + exp_tmp * sum[row]
            qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
            // max[row] <- max
            qk_max_data[row] = tmp_max;
            // dst <- dst * exp_tmp
            if (n > 0) {
              vec::map<accum_t>(
                  [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
                  dst_data + row * headSize, dst_data + row * headSize, headSize);
            }
          }
        }
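        // Accumulate the current block into dst: beta is 0 for the first
        // key/value block (n == 0) and 1 afterwards, matching the exp_tmp
        // rescaling of dst above.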
        // Calculate Softmax(q @ k.T) @ v
        cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::NoTranspose,
            headSize,
            qBlockSize,
            kvBlockSize,
            static_cast<accum_t>(1),
            v_data + i * vStrideB + j * vStrideH +
                n * vStrideN,
            vStrideN,
            conditional_data_ptr(qk_data, qk_reduced_data),
            kvBlockSize,
            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
            dst_data,
            headSize);
      }
      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
            [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
            out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
            dst_data + row * headSize,
            headSize);
      }
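      // lse = qk_max + log(qk_sum) = log(rowsum(exp(scaled qk))), the softmax
      // normalizer of the full row; the backward pass uses it to recompute the
      // attention probabilities without storing them.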
      // Store logsumexp for backward
      accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
      for (const auto row : c10::irange(qBlockSize)) {
        lse_ptr[row * lStrideM] = qk_max_data[row]
            + std::log(qk_sum_data[row]);
      }
      // Move to the next query
      data_index_step(i, batchSize, j, num_head, k, qSlice);
    }
  });
}

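// Blocked flash attention backward.
// For every (batch, head) pair, query blocks of qSplitSize rows and key/value
// blocks of kvSplitSize columns are revisited. The attention probabilities
//   P = exp(scale * q @ k.T + mask - logsumexp)
// are recomputed from the saved logsumexp instead of being stored, then
//   dV += P.T @ dO
//   dP  = dO @ V.T
//   dS  = P * (dP - dsum)        with dsum = rowsum(dO * O)
//   dQ += scale * dS @ K
//   dK += scale * dS.T @ Q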
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention_backward(
    const at::Tensor& grad_q,
    const at::Tensor& grad_k,
    const at::Tensor& grad_v,
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
      "scaled_dot_product_attention_flash_attention_backward: Q/K/V should have the same head size");
  // Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
  // Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
  // Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t headSize = query.size(3);

  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = out.stride(0);
  int64_t oStrideM = out.stride(1);
  int64_t oStrideH = out.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);
  int64_t mStrideB =
      (has_attn_mask && attn_mask.value().size(0) > 1)
      ? attn_mask.value().stride(0)
      : 0;
  int64_t mStrideH =
      (has_attn_mask && attn_mask.value().size(1) > 1)
      ? attn_mask.value().stride(1)
      : 0;
  int64_t mStrideM =
      has_attn_mask ? attn_mask.value().stride(2) : 0;

  int64_t grad_qStrideB = grad_q.stride(0);
  int64_t grad_qStrideM = grad_q.stride(1);
  int64_t grad_qStrideH = grad_q.stride(2);
  int64_t grad_kStrideB = grad_k.stride(0);
  int64_t grad_kStrideN = grad_k.stride(1);
  int64_t grad_kStrideH = grad_k.stride(2);
  int64_t grad_vStrideB = grad_v.stride(0);
  int64_t grad_vStrideN = grad_v.stride(1);
  int64_t grad_vStrideH = grad_v.stride(2);
  int64_t grad_oStrideB = grad_out.stride(0);
  int64_t grad_oStrideM = grad_out.stride(1);
  int64_t grad_oStrideH = grad_out.stride(2);

  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* attn */ qSplitSize * kvSplitSize +
      /* grad_attn */ qSplitSize * kvSplitSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));

  // allocate per thread temp buf_reduced (scalar type)
  // buf2 is only needed for bfloat16 and float16
  int64_t size_per_thread_reduced =
      /* attn_reduced */ qSplitSize * kvSplitSize +
      /* grad_attn_reduced */ qSplitSize * kvSplitSize;

  at::Tensor buf_reduced = at::empty({num_thread, is_reduced_type ? size_per_thread_reduced : 0}, query.options());

  scalar_t* grad_q_data = grad_q.data_ptr<scalar_t>();
  scalar_t* grad_k_data = grad_k.data_ptr<scalar_t>();
  scalar_t* grad_v_data = grad_v.data_ptr<scalar_t>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  mask_t* mask_data = has_attn_mask
      ? attn_mask.value().data_ptr<mask_t>()
      : nullptr;
  const scalar_t* out_data = out.const_data_ptr<scalar_t>();
  const accum_t* lse_data = logsumexp.const_data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0;
    data_index_init(begin, i, batchSize, j, num_head);
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* attn_data = buf_ptr;
    accum_t* grad_attn_data = attn_data + qSplitSize * kvSplitSize;
    scalar_t* buf_reduced_ptr = is_reduced_type ? buf_reduced_data + ompIdx * size_per_thread_reduced : nullptr;
    scalar_t* attn_reduced_data = is_reduced_type ? buf_reduced_ptr : nullptr;
    scalar_t* grad_attn_reduced_data = is_reduced_type ? attn_reduced_data + qSplitSize * kvSplitSize : nullptr;

    at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype));
    accum_t* dsum_data = dsum.data_ptr<accum_t>();
    for (const auto z : c10::irange(begin, end)) {
      (void)z; // Suppress unused variable warning
      // rowsum of grad_out * out
      for (int64_t m = 0; m < qSize; m += qSplitSize) {
        int64_t qBlockSize = std::min(qSplitSize, qSize - m);
        // dsum <- rowsum(grad_out * out)
        for (const auto row : c10::irange(qBlockSize)) {
          *(dsum_data + row) = vec::map2_reduce_all<scalar_t>(
              [](Vec x, Vec y) { return x * y; },
              [](Vec x, Vec y) { return x + y; },
              grad_out_data + i * grad_oStrideB + j * grad_oStrideH + (m + row) * grad_oStrideM,
              out_data + i * oStrideB + j * oStrideH + (m + row) * oStrideM,
              headSize);
        }
        int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
        for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
          int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
          // attn <- scale * q @ k.T
          cpublas::gemm(
              TransposeType::Transpose,
              TransposeType::NoTranspose,
              kvBlockSize,
              qBlockSize,
              headSize,
              scaling_factor,
              k_data + i * kStrideB + j * kStrideH +
                  n * kStrideN,
              kStrideN,
              q_data + i * qStrideB + j * qStrideH +
                  m * qStrideM,
              qStrideM,
              static_cast<accum_t>(0),
              attn_data,
              kvBlockSize);
          // attn <- attn + mask
          if (has_attn_mask) {
            accum_t one = accum_t(1);
            for (const auto row : c10::irange(qBlockSize)) {
              _scale_attn_mask_fusion_kernel(
                  attn_data + row * kvBlockSize,
                  mask_data + i * mStrideB + j * mStrideH +
                      (m + row) * mStrideM + n,
                  kvBlockSize,
                  attn_data + row * kvBlockSize,
                  one);
            }
          }
          // restore self attention after softmax from logsumexp
          // attn <- exp(attn - normalizer)
          for (const auto row : c10::irange(qBlockSize)) {
            accum_t normalizer = lse_data[i * lStrideB + j * lStrideH + (m + row) * lStrideM];
            vec::map<accum_t>(
                [normalizer](Vec x) { return (x - Vec(normalizer)).exp(); },
                attn_data + row * kvBlockSize,
                attn_data + row * kvBlockSize,
                kvBlockSize);
          }
          // Apply causal mask, fill unused with 0
          if (is_causal && num_keys - n <= kvSplitSize) {
            for (const auto row : c10::irange(qBlockSize)) {
              int64_t last_col = m + row - n;
              accum_t* row_ptr = attn_data + row * kvBlockSize;
              fill_stub(row_ptr + last_col + 1, static_cast<accum_t>(0), kvBlockSize - last_col - 1);
            }
          }
          if (is_reduced_type) {
            for (const auto row : c10::irange(qBlockSize)) {
              convert<accum_t, scalar_t>(
                  attn_data + row * kvBlockSize,
                  attn_reduced_data + row * kvBlockSize,
                  kvBlockSize);
            }
          }
          // grad_v <- grad_v + attn.T @ grad_out
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::Transpose,
              headSize,
              kvBlockSize,
              qBlockSize,
              static_cast<accum_t>(1),
              grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
                  m * grad_oStrideM,
              grad_oStrideM,
              conditional_data_ptr(attn_data, attn_reduced_data),
              kvBlockSize,
              static_cast<accum_t>(1),
              grad_v_data + i * grad_vStrideB + j * grad_vStrideH +
                  n * grad_vStrideN,
              grad_vStrideN);
          // grad_attn <- grad_out @ v.T
          cpublas::gemm(
              TransposeType::Transpose,
              TransposeType::NoTranspose,
              kvBlockSize,
              qBlockSize,
              headSize,
              static_cast<accum_t>(1),
              v_data + i * vStrideB + j * vStrideH +
                  n * vStrideN,
              vStrideN,
              grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
                  m * grad_oStrideM,
              grad_oStrideM,
              static_cast<accum_t>(0),
              grad_attn_data,
              kvBlockSize);
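          // Softmax backward: dS = P * (dP - rowsum(P * dP)). Since O = P @ V
          // and dP = dO @ V.T, rowsum(P * dP) equals rowsum(dO * O), which is
          // exactly the dsum computed above.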
          // grad_attn <- attn * (grad_attn - dsum)
          for (const auto row : c10::irange(qBlockSize)) {
            accum_t d = *(dsum_data + row);
            vec::map2<accum_t>(
                [d](Vec attn, Vec grad_attn) { return attn * (grad_attn - Vec(d)); },
                grad_attn_data + row * kvBlockSize,
                attn_data + row * kvBlockSize,
                grad_attn_data + row * kvBlockSize,
                kvBlockSize);
          }
          if (is_reduced_type) {
            for (const auto row : c10::irange(qBlockSize)) {
              convert<accum_t, scalar_t>(
                  grad_attn_data + row * kvBlockSize,
                  grad_attn_reduced_data + row * kvBlockSize,
                  kvBlockSize);
            }
          }
          // grad_q <- grad_q + scale * grad_attn @ k
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::NoTranspose,
              headSize,
              qBlockSize,
              kvBlockSize,
              scaling_factor,
              k_data + i * kStrideB + j * kStrideH +
                  n * kStrideN,
              kStrideN,
              conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
              kvBlockSize,
              static_cast<accum_t>(1),
              grad_q_data + i * grad_qStrideB + j * grad_qStrideH +
                  m * grad_qStrideM,
              grad_qStrideM);
          // grad_k <- grad_k + scale * grad_attn.T @ q
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::Transpose,
              headSize,
              kvBlockSize,
              qBlockSize,
              scaling_factor,
              q_data + i * qStrideB + j * qStrideH +
                  m * qStrideM,
              qStrideM,
              conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
              kvBlockSize,
              static_cast<accum_t>(1),
              grad_k_data + i * grad_kStrideB + j * grad_kStrideH +
                  n * grad_kStrideN,
              grad_kStrideN);
        }
      }
      // Move to the next query
      data_index_step(i, batchSize, j, num_head);
    }
  });
}

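// Dispatch over the supported attention mask dtypes (Bool, Float, Double,
// BFloat16, Half), binding the selected type to mask_t inside the lambda.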
#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...)          \
  AT_DISPATCH_SWITCH(                                    \
      TYPE,                                              \
      NAME,                                              \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                   \
          at::ScalarType::Bool, mask_t, __VA_ARGS__)     \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                   \
          at::ScalarType::Float, mask_t, __VA_ARGS__)    \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                   \
          at::ScalarType::Double, mask_t, __VA_ARGS__)   \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                   \
          at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE_USING_HINT(                   \
          at::ScalarType::Half, mask_t, __VA_ARGS__))

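// The query block size is picked from the query sequence length (256 / 64 / 32
// rows) with a fixed key/value block size of 512. When no mask is given,
// mask_t is instantiated as scalar_t purely as a placeholder; the mask path is
// never taken in that case.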
void flash_attention_kernel_impl(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  auto q_seq_len = query.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention", [&] {
    if (!attn_mask.has_value()) {
      if (q_seq_len >= 768) {
        cpu_flash_attention<scalar_t, scalar_t, 256, 512>(
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
      } else if (q_seq_len >= 192) {
        cpu_flash_attention<scalar_t, scalar_t, 64, 512>(
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
      } else {
        cpu_flash_attention<scalar_t, scalar_t, 32, 512>(
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
      }
    } else {
      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask", [&]() {
        if (q_seq_len >= 768) {
          cpu_flash_attention<scalar_t, mask_t, 256, 512>(
              output, logsumexp, query, key, value,
              dropout_p, is_causal, attn_mask, scale);
        } else if (q_seq_len >= 192) {
          cpu_flash_attention<scalar_t, mask_t, 64, 512>(
              output, logsumexp, query, key, value,
              dropout_p, is_causal, attn_mask, scale);
        } else {
          cpu_flash_attention<scalar_t, mask_t, 32, 512>(
              output, logsumexp, query, key, value,
              dropout_p, is_causal, attn_mask, scale);
        }
      });
    }
  });
}

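// Backward dispatch mirrors the forward: block sizes are picked from the query
// sequence length, and mask_t falls back to the accumulation type when no
// (defined) mask is passed.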
void flash_attention_backward_kernel_impl(
    const at::Tensor& grad_q,
    const at::Tensor& grad_k,
    const at::Tensor& grad_v,
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // Make sure grad_out has no zero strides (broadcasted dimensions)
  // before calling gemm: a zero stride in the leading dimension would
  // hit a slow gemm path.
  auto grad_out_contig = grad_out.contiguous();
  auto q_seq_len = query.size(1);

  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention_backward", [&] {
    if (!attn_mask.has_value() || !attn_mask.value().defined()) {
      using accum_t = at::opmath_type<scalar_t>;
      if (q_seq_len >= 768) {
        cpu_flash_attention_backward<scalar_t, accum_t, 256, 512>(
            grad_q, grad_k, grad_v, grad_out_contig,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
      } else if (q_seq_len >= 192) {
        cpu_flash_attention_backward<scalar_t, accum_t, 64, 512>(
            grad_q, grad_k, grad_v, grad_out_contig,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
      } else {
        cpu_flash_attention_backward<scalar_t, accum_t, 32, 512>(
            grad_q, grad_k, grad_v, grad_out_contig,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
      }
    } else {
      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask_backward", [&]() {
        if (q_seq_len >= 768) {
          cpu_flash_attention_backward<scalar_t, mask_t, 256, 512>(
              grad_q, grad_k, grad_v, grad_out_contig,
              query, key, value, out, logsumexp,
              dropout_p, is_causal, attn_mask, scale);
        } else if (q_seq_len >= 192) {
          cpu_flash_attention_backward<scalar_t, mask_t, 64, 512>(
              grad_q, grad_k, grad_v, grad_out_contig,
              query, key, value, out, logsumexp,
              dropout_p, is_causal, attn_mask, scale);
        } else {
          cpu_flash_attention_backward<scalar_t, mask_t, 32, 512>(
              grad_q, grad_k, grad_v, grad_out_contig,
              query, key, value, out, logsumexp,
              dropout_p, is_causal, attn_mask, scale);
        }
      });
    }
  });
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(flash_attention_kernel, &flash_attention_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_backward_kernel, &flash_attention_backward_kernel_impl);
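// Illustrative usage sketch (not part of this file): these kernels are reached
// through the CPU flash attention path of scaled_dot_product_attention,
// subject to backend selection in sdp_utils_cpp, e.g.
//   auto q = at::randn({1, 8, 128, 64});
//   auto k = at::randn({1, 8, 128, 64});
//   auto v = at::randn({1, 8, 128, 64});
//   auto out = at::scaled_dot_product_attention(
//       q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);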

} // at::native