// flash_attention.cu
#include "attention_api.cuh"
#include <cassert>
#include <cmath>
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <torch/extension.h>
#include <torch/python.h>
#include <vector>
#include "static_switch.h"
#include "kernel_traits.h"
#include "flash.h"
#include "utils.h"
namespace flash {
using namespace cute;
template <int kBlockM, int kBlockN, int kNWarps,typename Engine, typename Layout>
inline __device__ void mask_within_nblock(Tensor<Engine, Layout> &tensor, const int m_block, const int nbi) {
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
// NOTE:
// Working out the index of an element inside one MMA tile is one of the tricky parts.
// (nrow=(2, MMA_M), ncol=(2, MMA_N)) looks like:
//    T1.V0 T1.V1
//    T1.V0 T1.V1
// The col and row values are read off the mma_tile layout diagram.
// NOTE:
// Compute the range this thread is responsible for and mask out everything beyond it.
//
// NOTE:
// % 32 groups threads by 32, because the largest thread id inside
// SM80_16x8x16_F32F16F16F32_TN _1_2_1 is 32 (one warp).
// (lane_id % 4) * 2 selects which "color" column (thread column) we are in; the * 2 shifts to
// the right-hand element, i.e. selects which of the two values this thread owns.
// So col_idx_offset is which of the 4 columns of a single Atom this thread handles.
// lane_id is the thread's index within one MMA tile's "thread group".
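// A small worked example of the mapping below (illustrative only; assumes nbi = 0 and m_block = 0):
//   lane_id = 5  -> col_idx_offset = (5 % 4) * 2 = 2,  row_idx_offset = 5 / 4 = 1 (plus 16 per warp group)
//   lane_id = 12 -> col_idx_offset = (12 % 4) * 2 = 0, row_idx_offset = 12 / 4 = 3
// i.e. every group of 4 lanes shares a row, and each lane owns a pair of adjacent columns.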
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = kBlockN * nbi + (lane_id % 4) * 2;
const int nrow_group = threadIdx.x / 32;
const int row_idx_offset = kBlockM * m_block + lane_id / 4 + nrow_group * 16 /* 2*8 */;
// (2, nrow), 2*8 for each
const int group_stride = kNWarps * 16;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
// Within one group of SM80_16x8x16_F32F16F16F32_TN, the 4 threads in a row handle 8 values.
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
// j selects between value 1 and value 2 to get the corresponding column.
// col_idx is the final column index of the value handled by this thread.
const int col_idx = col_idx_base + j;
// Mask out the out-of-range part of scores (the result of QK).
// Compare the column index against the row index.
// Without the "make_coord" we get wrong results
// for nrow(2, MMA_M)
#pragma unroll
for (int mi = 0; mi < size<0, 0>(tensor); ++mi) {
#pragma unroll
for (int mj = 0; mj < size<0, 1>(tensor); ++mj) {
const int row_idx = row_idx_offset + mi * 8 + mj * group_stride;
if (col_idx > row_idx) {
tensor(make_coord(mi, mj), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}
// NOTE: gemm wrapper for the case where the A matrix is already in registers.
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
// NOTE: follows the M N K convention: A[M, K] @ B[N, K] = C[M, N]
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
// NOTE: retile to the shape required by the copy.
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
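// Software pipeline: the k-slice for step 0 is prefetched above; inside the loop,
// slice i+1 is copied from smem to registers while the MMA for slice i executes.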
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
template<typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm_smem(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
// NOTE: s -> reg
cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
// copy from S to D with tiled_copy
// TODO: support skipping copies in causal mode.
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
// TODO: in the original version, identity_MN is used here to skip whole blocks,
// and a predicate is used to skip copies within a block.
// TODO: add the predicate logic so useless copies can be skipped.
// if (get<0>(identity_MN(0, m, 0)) < max_MN)
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
}
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
// TD [2023-08-13]: Same error as above on Cutlass 3.2
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
// get<0, 1>(l),
// get<1, 1, 1>(l));
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
get<1>(get<0>(l)),
get<1>(get<1>(get<1>(l))));
};
// TODO: not working yet
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
// TODO:
// https://github.com/NVIDIA/cutlass/issues/802
// TODO: is the converted data guaranteed to be in registers?
template <typename Fragment>
inline __device__ auto convert_type_f32_to_f16(Fragment const &acc_fp32) {
Tensor acc_fp16 = make_tensor<cute::half_t>(shape(acc_fp32));
{
Tensor acc_fp32x2 = recast< float2>(acc_fp32);
Tensor acc_fp16x2 = recast<__half2>(acc_fp16);
for (int i = 0; i < size(acc_fp32x2); ++i) { acc_fp16x2(i) = __float22half2_rn(acc_fp32x2(i)); }
}
return acc_fp16;
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// The max is pre-multiplied by the scale so that the multiply and subtract inside the
// exponent can be fused into a single ffma instruction instead of separate fadd and fmul.
// (The original flash-attention computes exp2(x * log_2(e) * scale - max * log_2(e) * scale);
// this version calls expf with the plain scale instead.)
tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled);
}
}
}
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// TODO: work out the exact mathematical meaning of (nrow=(2, MMA_M), ncol=(2, MMA_N))
// after convert_layout_acc_rowcol.
// Intuitively, it turns
//    T1.V0
//    T1.V1
//    T1.V0
//    T1.V1
// into
//    T1.V0 T1.V1
//    T1.V0 T1.V1
// which matches the row/column intuition of an MMA tile.
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
};
template<bool Is_first, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
// NOTE: scores comes from acc_s: Q@K.T
// acc_s stores the QK and softmax results [seqlen, seqlen]
// acc_o stores the numerator of softmax(QK), which gets rescaled
// The streaming computation keeps rescaling with the scores computed for the current block.
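// Illustrative recurrence (notation only, not code) for one row, with running max m,
// denominator l, unnormalized output O, and a new block of scores S:
//   m_new = max(m_old, rowmax(S))
//   l_new = exp(scale * (m_old - m_new)) * l_old + rowsum(exp(scale * (S - m_new)))
//   O_new = exp(scale * (m_old - m_new)) * O_old + exp(scale * (S - m_new)) @ V
// The exp and rescale parts of this recurrence are what the two branches below implement;
// the @ V accumulation happens in the caller.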
if (Is_first) {
// NOTE: optimization: the first softmax needs no rescale, just record the numerator, max and sum.
reduce_max</*zero_init=*/true>(scores, scores_max);
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
reduce_sum(scores, scores_sum);
} else {
// Record the previous max.
Tensor scores_max_prev = make_fragment_like(scores_max);
cute::copy(scores_max, scores_max_prev);
// TODO: study how reduce is implemented
// NOTE: compute the new max into scores_max
// reduce_max consists of two steps:
// 1. per-thread max: a plain loop
// 2. reduce the max across threads: a shuffle (shift) based reduction
reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// Reinterpret acc_o into the 2D-intuitive (nrow, ncol) shape.
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scores_max); ++mi) {
// NOTE: helper variable: the current max
float scores_max_cur = scores_max(mi);
// NOTE: compute the rescale factor for the old scores
// NOTE: the QK computation (which determines max) did not apply softmax_scale, so it is applied here.
float scores_scale = expf((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
// NOTE: rescale the old denominator.
scores_sum(mi) *= scores_scale;
// NOTE: rescale the old numerator.
// acc_o_rowcol.shape = (nrow, ncol)
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
// NOTE: compute the new numerator: rescale all scores.
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
// NOTE: accumulate the new denominator.
Tensor scores_sum_cur = make_fragment_like(scores_sum);
// NOTE: use the new numerator to accumulate the new denominator
// 1. per-thread accumulation: a plain loop
// 2. cross-thread accumulation: a shuffle (shift) based reduction
reduce_sum(scores, scores_sum_cur);
// NOTE: add the new denominator onto the old one.
#pragma unroll
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
}
};
} // namespace flash
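// Expected layout (inferred from the size/stride reads below): q, k, v and out are
// [batch, heads, seqlen, head_dim] tensors, with all strides taken from q.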
void set_params_fprop(Flash_fwd_params &params,
// device pointers
const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v,
torch::Tensor out,
void *softmax_lse_d,
float softmax_scale,
bool is_causal) {
memset(¶ms, 0, sizeof(params));
params.bs = q.size(0);
params.head = q.size(1);
params.q_seqlen = q.size(2);
params.dim = q.size(3);
params.k_head = k.size(1);
params.k_seqlen = k.size(2);
params.bs_stride = q.stride(0);
params.head_stride = q.stride(1);
params.seqlen_stride = q.stride(2);
params.dim_stride = q.stride(3);
params.softmax_scale = softmax_scale;
// TODO: use log2 for the scale.
params.softmax_scale_log2 = softmax_scale * M_LOG2E;
params.is_causal = is_causal;
params.is_bf16 = q.dtype() == torch::kBFloat16;
// LogSumExp save for backward
params.softmax_lse_ptr = softmax_lse_d;
// TODO: get ptr
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
params.out_ptr = out.data_ptr();
}
// Shared Storage with Aligned addresses.
template <class ElementType, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutV>
struct SharedStorage {
// TODO: with the aligned arrays, is the smem size computation still correct?
cute::array_aligned<ElementType, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<ElementType, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<ElementType, cute::cosize_v<SmemLayoutV>> smem_v;
};
template <typename Kernel_traits, bool Is_causal=false, typename Params>
__global__ void flash_attention_v2_cutlass_kernel(const Params params) {
using namespace cute;
// m block index
const int m_block = blockIdx.x;
// bs * head
const int base_id = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
// TODO: pass these in as template parameters
// NOTE: small trick
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
// using TiledMMA = typename Kernel_traits::MMA;
using TiledMMA = typename Kernel_traits::TiledMma;
using index_t = typename Kernel_traits::index_t;
using SmemLayoutQ = typename Kernel_traits::SmemLayoutQ;
using SmemLayoutK = typename Kernel_traits::SmemLayoutKV;
using SmemLayoutV = typename Kernel_traits::SmemLayoutKV;
using SmemLayoutVt = typename Kernel_traits::SmemLayoutVtransposed;
using SmemLayoutVtNoSwizzle = typename Kernel_traits::SmemLayoutVtransposedNoSwizzle;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
// Shared memory.
extern __shared__ char smem_[];
using SharedStorage = SharedStorage<Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV>;
SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(smem_);
const int bs_head_offset = base_id * params.head_stride;
// TODO: base offset for MHA
// NOTE: convert C pointer to Tensor for convenience
Tensor Q = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + bs_head_offset),
make_shape(params.q_seqlen, Int<kHeadDim>{}),
make_stride(Int<kHeadDim>{}, Int<1>{}));
Tensor K = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + bs_head_offset),
make_shape(params.k_seqlen, Int<kHeadDim>{}),
make_stride(Int<kHeadDim>{}, Int<1>{}));
Tensor V = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + bs_head_offset),
make_shape(params.k_seqlen, Int<kHeadDim>{}),
make_stride(Int<kHeadDim>{}, Int<1>{}));
Tensor O = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.out_ptr) + bs_head_offset),
make_shape(params.q_seqlen, Int<kHeadDim>{}),
make_stride(Int<kHeadDim>{}, Int<1>{}));
// TODO:
Tensor LSE = make_tensor(
make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + base_id * params.q_seqlen),
// Shape<Int<kBlockM>, Stride<_1>{}>{},
make_shape(params.q_seqlen),
make_stride(Int<1>{}));
// Load the Q, K, V tiles
// (kBlockM, kHeadDim, num_tile_n)
Tensor gQ = local_tile(Q, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));
// (kBlockN, kHeadDim, num_tile_n)
// NOTE: load pipeline: the initial load of the K, V tiles that are needed first
Tensor gK = local_tile(K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));
Tensor gV = local_tile(V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));
// Get the MMA abstraction
TiledMMA tiled_mma;
auto thr_mma = tiled_mma.get_slice(tidx);
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
// Tensor for V Transpose; used in GEMM-II.
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
Tensor sVtNoSwizzle = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVtNoSwizzle{});
// NOTE: copy abstractions
// NOTE: abstraction for the QKV gmem -> smem copy
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
// NOTE: define src and dst for the gmem -> smem copy
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0));
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
// NOTE: define dst for the smem -> reg copy
// partition_fragment is like partition, but returns a register representation
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
//
// Copy Atom retiling
//
// TODO: understand the atom retiling here
// NOTE: prepare the copy objects for Q, K, V in smem (smem -> reg)
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
// TODO: transpose during the copy
// NOTE: smem -> reg copy of Vt
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// NOTE: naming convention: t stands for "to", s/g mark the location (smem, gmem)
// Retiling is done when loading from smem
// tKgK is K in gmem, used as the src of gmem -> smem
// tKsK is K in smem, used as the dst of gmem -> smem
// tSsK is K in smem, used as the src of smem -> reg
// NOTE: make_identity_tensor creates a shape-only tensor used during copies
// to skip whole blocks
// // TODO: cQ etc. are for causal mode, unused for now
// // Construct identity layout for sQ and sK
// Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// // Repeat the partitioning with identity layouts
// Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Pipelined load of the initial Q, K
// Load Q into smem
flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ);
// Load K into smem
flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
// Kick off the async copies
cute::cp_async_fence();
Tensor rAccOut = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
// step 1: slice-k, compute one QK block
// Q[kBlockM, kHeadDim] @ K[kBlockN, kHeadDim].T = S[kBlockM, kBlockN]
//
// step2:
// advance K, V
// NOTE: number of K, V blocks: the range to process
const int n_block_min = 0;
// NOTE: 1. mask between N BLOCKs if is causal mode
int seqlen_start = m_block * kBlockM;
int seqlen_end = (m_block + 1) * kBlockM;
int n_block_max = Is_causal ? cute::ceil_div(seqlen_end, kBlockN) : cute::ceil_div(params.k_seqlen, kBlockN);
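// Worked example (illustrative): with kBlockM = kBlockN = 64 and m_block = 2, seqlen_end = 192,
// so in causal mode n_block_max = ceil(192 / 64) = 3: only the first 3 K/V blocks can contain
// unmasked positions for these query rows.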
// NOTE: running max that must be tracked
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(rAccOut)>>{});
// NOTE: running denominator that must be tracked
Tensor scores_sum = make_fragment_like(scores_max);
clear(rAccOut);
for (int nbi = n_block_min; nbi < n_block_max; nbi++) {
auto rAccScore = partition_fragment_C(tiled_mma, make_shape(Int<kBlockM>{}, Int<kBlockN>{}));
clear(rAccScore);
// Wait for the gmem -> smem copies of Q and K to finish, i.e. Q and K are ready
// wait<0> means wait until 0 copy groups remain outstanding
flash::cp_async_wait<0>();
__syncthreads();
// Asynchronously load V while the gemm runs
gV = local_tile(V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(nbi, _));
tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
// Asynchronously load V into smem
flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV);
// Kick off the async copy
cute::cp_async_fence();
// O = Q@K.T
// NOTE: load data from smem into registers and then do the gemm; **retiling happens during the load**
flash::gemm_smem(rAccScore, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
Tensor scores = make_tensor(rAccScore.data(), flash::convert_layout_acc_rowcol(rAccScore.layout()));
// NOTE: 2. mask within N BLOCKs
if (Is_causal == true && nbi * kBlockN >= seqlen_start) {
flash::mask_within_nblock<kBlockM, kBlockN, kNWarps>(scores, m_block, nbi);
}
// NOTE: wait for V to finish loading, and set up the initial state for loading the next K
flash::cp_async_wait<0>();
__syncthreads();
// advance K
if (nbi != n_block_max - 1) {
gK = local_tile(K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(nbi + 1, _));
tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
cute::cp_async_fence();
}
// Compute the softmax
// NOTE: rAccOut accumulates all the softmax numerators
nbi == 0 ? flash::softmax_rescale_o</*Is_first=*/true>(scores, scores_max, scores_sum, rAccOut, params.softmax_scale) :
flash::softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, rAccOut, params.softmax_scale);
// Actually compute QK @ V
// (score AKA rAccScore): QK[M, N] @ V[N, dim]
// NOTE: DABC: F32F16F16F32, convert D type(F32) to A type(F16)
// TODO: convert_type is hard-coded for now
Tensor rP = flash::convert_type_f32_to_f16(rAccScore);
// NOTE: Convert from layout C to layout A
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<TiledMMA>(scores.layout()));
flash::gemm_A_in_regs(rAccOut, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
// NOTE: finally divide everything by the denominator in one pass
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// AKA reshape to (nrow, ncol) but with specific MMA layout
Tensor acc_o_rowcol = make_tensor(rAccOut.data(), flash::convert_layout_acc_rowcol(rAccOut.layout()));
// NOTE: save the lse for the backward pass
Tensor lse = make_fragment_like(scores_sum);
// for row
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
// compute lse
// NOTE: here we use max * scale
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.softmax_scale + __logf(sum);
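// NOTE: since sum = sum_j exp(s_j * scale - max * scale), this equals log(sum_j exp(s_j * scale)),
// i.e. the per-row log-sum-exp the backward pass needs.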
float scale = inv_sum;
// for col
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scale;
}
}
// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type_f32_to_f16(rAccOut);
// Reuse sQ's smem as sO for the copy out
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
// TODO: review
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// NOTE: copy to smem first
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
Tensor gO = local_tile(O, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));
// Create the smem -> gmem copy
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO(_, _, 0));
__syncthreads();
// NOTE: then copy to gmem
// TODO: review what the two copies here are for
Tensor tOrO = make_tensor<Element>(shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
flash::copy(gmem_tiled_copy_O, tOrO, tOgO);
// NOTE: write back the lse
Tensor gLSE = local_tile(LSE, make_tile(Int<kBlockM>{}), make_coord(m_block));
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0>(taccOcO))::value == 4);
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
// TODO: review this shape
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
// TODO: understand the logic here
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
// if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
gLSE(row) = lse(mi);
}
}
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// TODO: check that this works with the default stream (0)
using Element = typename Kernel_traits::Element;
using SmemLayoutQ = typename Kernel_traits::SmemLayoutQ;
using SmemLayoutK = typename Kernel_traits::SmemLayoutKV;
using SmemLayoutV = typename Kernel_traits::SmemLayoutKV;
const int num_m_block =
(params.q_seqlen + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.bs * params.head, 1);
dim3 block(Kernel_traits::kNThreads);
int smem_size = int(sizeof(SharedStorage<Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV>));
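// Rough size check (an estimate only; assumes the Q/K/V smem layout cosizes are exactly
// kBlockM*kHeadDim and kBlockN*kHeadDim, which kernel_traits.h may change with padding/swizzling):
// for 64x64 tiles with kHeadDim = 64 in fp16 this is 3 * 64 * 64 * 2 B = 24 KB, under the 48 KB static limit.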
auto kernel = &flash_attention_v2_cutlass_kernel<Kernel_traits, Is_causal, Flash_fwd_params>;
// NOTE: must be set when the smem requirement is too large (over 48 KB)
if (smem_size >= 48 * 1024) {
CUDA_ERROR_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, block, smem_size, stream>>>(params);
}
template<typename T, int Headdim>
void run_flash_fwd_(Flash_fwd_params &params, cudaStream_t stream);
// TODO: write out the per-headdim specializations; for now a generic template is used.
// e.g. run_flash_fwd_hdim32 would specialize hdim=32.
// That would allow tuning the kBlockM/kBlockN combination per case and would also speed up compilation.
template<typename T, int Headdim>
void run_flash_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, /*kBlockM_=*/128, /*kBlockN_=*/128, /*kNWarps_=*/4, T>, Is_causal>(params, stream);
// TODO: tune the kBlockM, kBlockN combination
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, /*kBlockM_=*/64, /*kBlockN_=*/64, /*kNWarps_=*/4, T>, Is_causal>(params, stream);
});
}
// entry point of flash attention
void run_flash_attn_cutlass(Flash_fwd_params &params, cudaStream_t stream) {
// FP16_SWITCH yield elem_type namespace
FP16_SWITCH(!params.is_bf16, [&] {
// FWD_HEADDIM_SWITCH yield kHeadDim constexpr
FWD_HEADDIM_SWITCH(params.dim, [&] {
run_flash_fwd_<elem_type, kHeadDim>(params, stream);
});
});
}
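// Inputs (as read below): q, k, v are 4-D [bs, head, seqlen, dim] tensors validated by CHECK_INPUT,
// in fp16 or bf16 (see FP16_SWITCH above). softmax_scale defaults to 1 here; for standard scaled
// dot-product attention a caller would typically pass 1/sqrt(dim).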
std::vector<torch::Tensor> flash_attention_v2_cutlass(torch::Tensor q, torch::Tensor k,
torch::Tensor v, bool is_causal = false, float softmax_scale=1) {
CHECK_INPUT(q);
CHECK_INPUT(k);
CHECK_INPUT(v);
// batch size
int bs = q.size(0);
// head number
int head = q.size(1);
// seqlen
int seqlen = q.size(2);
// dim
int dim = q.size(3);
auto out = torch::empty_like(q);
auto opts = q.options();
auto softmax_lse = torch::empty({bs, head, seqlen}, opts.dtype(torch::kFloat32));
Flash_fwd_params params;
set_params_fprop(params, q, k, v, out,
softmax_lse.data_ptr(), softmax_scale, is_causal);
run_flash_attn_cutlass(params, 0);
// Wait until kernel finish.
cudaDeviceSynchronize();
CUDA_ERROR_CHECK(cudaGetLastError());
return {out, softmax_lse};
}
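// A minimal binding sketch (illustrative only; the actual binding for this extension may live
// elsewhere, e.g. alongside attention_api.cuh). The module and argument names below are assumptions:
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("flash_attention_v2_cutlass", &flash_attention_v2_cutlass,
//         py::arg("q"), py::arg("k"), py::arg("v"),
//         py::arg("is_causal") = false, py::arg("softmax_scale") = 1.0f);
// }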