Merge pull request PaddlePaddle#1 from kircle888/smask_up
Smask up ready
kircle888 authored May 24, 2024
2 parents 94b7725 + 07791f4 commit c3212be
Showing 6 changed files with 147 additions and 25 deletions.
6 changes: 4 additions & 2 deletions csrc/capi/flash_attn.cu
@@ -564,6 +564,7 @@ bool flash_attn_fwd(const void * const q,
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const void * const attn_mask_end_row_indices,
const int attn_mask_start_row,
const int q_row_stride,
const int k_row_stride,
@@ -619,7 +620,7 @@ bool flash_attn_fwd(const void * const q,
false/*varlen_padded_input=*/,
const_cast<void *>(attn_mask),
const_cast<void *>(attn_mask_start_row_indices),
nullptr,
const_cast<void *>(attn_mask_end_row_indices),
attn_mask_start_row,
mask_head_mod_size,
mask_seq_q_mod_size);
@@ -798,6 +799,7 @@ bool flash_attn_bwd(const void * const dout,
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const void * const attn_mask_end_row_indices,
const int attn_mask_start_row,
const int q_row_stride,
const int k_row_stride,
@@ -890,7 +892,7 @@ bool flash_attn_bwd(const void * const dout,
num_splits,
const_cast<void *>(attn_mask),
const_cast<void *>(attn_mask_start_row_indices),
nullptr,
const_cast<void *>(attn_mask_end_row_indices),
attn_mask_start_row,
mask_head_mod_size,
mask_seq_q_mod_size);
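
Note on the change above: flash_attn_fwd and flash_attn_bwd now accept and forward the caller's attn_mask_end_row_indices pointer instead of passing a hard-coded nullptr into the kernel parameters. The sketch below is illustrative only (the helper name build_dense_mask is hypothetical, not part of this repo); it spells out the assumed per-column semantics: query row r of key column c is masked when r >= start[c] or r < end[c], so the visible rows form the band [end[c], start[c]).

// Illustrative only; assumes each key column c keeps query rows in the band
// [end_rows[c], start_rows[c]) visible and masks everything else.
#include <cstdint>
#include <vector>

std::vector<uint8_t> build_dense_mask(const std::vector<int32_t> &start_rows,
                                      const std::vector<int32_t> &end_rows,
                                      int seqlen_q) {
  const int seqlen_k = static_cast<int>(start_rows.size());
  std::vector<uint8_t> masked(static_cast<size_t>(seqlen_q) * seqlen_k, 0);
  for (int c = 0; c < seqlen_k; ++c) {
    for (int r = 0; r < seqlen_q; ++r) {
      // "Down" mask: rows at or past the start index; "up" mask: rows before the end index.
      const bool is_masked = (r >= start_rows[c]) || (r < end_rows[c]);
      masked[static_cast<size_t>(r) * seqlen_k + c] = static_cast<uint8_t>(is_masked);
    }
  }
  return masked;
}

Under this convention, a start index of seqlen_q together with an end index of 0 would leave a column fully visible.
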
2 changes: 2 additions & 0 deletions csrc/capi/flash_attn.h
@@ -37,6 +37,7 @@ bool flash_attn_fwd(const void * const q, // batch_size x seqlen_q x num
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const void * const attn_mask_end_row_indices,
const int attn_mask_start_row,
const int q_row_stride,
const int k_row_stride,
@@ -128,6 +129,7 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const void * const attn_mask_end_row_indices,
const int attn_mask_start_row,
const int q_row_stride,
const int k_row_stride,
37 changes: 32 additions & 5 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -436,6 +436,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

// Shared memory.
__shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_up[Kernel_traits::kBlockN];
extern __shared__ char smem_[];

// The thread index.
@@ -459,13 +460,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

const index_t row_offset_sparsemask_nblock =
(bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * cute::ceil_div(params.seqlen_k, kBlockN);
const int* gSparseMaskDownMax = reinterpret_cast<int32_t*>(params.attn_sparsemask_down_nblockmax)+row_offset_sparsemask_nblock;
const int* gSparseMaskDownMin = reinterpret_cast<int32_t*>(params.attn_sparsemask_down_nblockmin)+row_offset_sparsemask_nblock;
const int *gSparseMaskDownMax =
reinterpret_cast<int32_t *>(params.attn_sparsemask_down_nblockmax) +
row_offset_sparsemask_nblock;
const int *gSparseMaskDownMin =
reinterpret_cast<int32_t *>(params.attn_sparsemask_down_nblockmin) +
row_offset_sparsemask_nblock;
const int *gSparseMaskUpMax =
reinterpret_cast<int32_t *>(params.attn_sparsemask_up_nblockmax) +
row_offset_sparsemask_nblock;
const int *gSparseMaskUpMin =
reinterpret_cast<int32_t *>(params.attn_sparsemask_up_nblockmin) +
row_offset_sparsemask_nblock;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
int attn_mask_end_row = 0;
if (Is_sparse_attn_mask) {
m_block_max = min(m_block_max, cute::ceil_div(gSparseMaskDownMax[n_block],kBlockM));
m_block_max = min(m_block_max,
cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
attn_mask_start_row = gSparseMaskDownMin[n_block];
attn_mask_end_row = gSparseMaskUpMax[n_block];
}
const int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);

@@ -525,6 +539,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
make_stride(params.seqlen_k, _1{}));
Tensor gSparseMask = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_start_row_indices_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gSparseMaskUp = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_end_row_indices_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQdO{});
@@ -549,6 +565,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// sP and sdQ share the same memory so be careful
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskUp = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_up)), Shape<Int<kBlockN>>{});
Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});

@@ -697,6 +714,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
if (!Is_causal && Is_sparse_attn_mask) {
m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
}

// We might need to exit early and write 0 to dK and dV.
// Otherwise we get wrong result for the case where we don't enter the for loop.
@@ -817,6 +837,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (Is_sparse_attn_mask) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
if(!Is_causal)
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
}
__syncthreads();
}
@@ -865,7 +887,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tPgMask.data() = tPgMask.data() + (-kBlockM * params.seqlen_k);
}
if (!Is_causal) {
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
if (Is_sparse_attn_mask &&
((m_block + 1) * kBlockM >= attn_mask_start_row || m_block * kBlockM < attn_mask_end_row)){
flash::apply_sparse_mask(scores, sSparseMask, sSparseMaskUp, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
AtomLayoutMS * 16, n_block * kBlockN);
} else if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
flash::apply_mask(scores, binfo.actual_seqlen_k,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
}
@@ -875,7 +902,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
// But we still want to mask out elements not beyond actual_seqlen_k.

if (Is_sparse_attn_mask && (m_block+1) * kBlockM >= attn_mask_start_row) {
if (Is_sparse_attn_mask && (m_block + 1) * kBlockM >= attn_mask_start_row) {
flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
AtomLayoutMS * 16, n_block * kBlockN);
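
For readability, the query-block pruning introduced in this backward-kernel diff can be paraphrased on the host side as follows (a sketch, not the kernel code; prune_m_blocks and its argument names are hypothetical): m_block_max is clamped by the largest start row in key block n, and in the non-causal case m_block_min is raised to the smallest end row in that block, so fully masked query blocks are never visited.

// Host-side paraphrase of the query-block pruning (illustrative only).
#include <algorithm>

struct MBlockRange { int m_block_min; int m_block_max; };  // kernel visits [min, max)

MBlockRange prune_m_blocks(int actual_seqlen_q, int kBlockM,
                           int sparse_down_max,   // max start row over this key block
                           int sparse_up_min,     // min end row over this key block
                           bool is_causal, int causal_m_block_min) {
  int m_block_max = (actual_seqlen_q + kBlockM - 1) / kBlockM;  // ceil_div
  m_block_max = std::min(m_block_max, (sparse_down_max + kBlockM - 1) / kBlockM);
  int m_block_min = is_causal ? causal_m_block_min : 0;
  if (!is_causal) {
    m_block_min = std::max(m_block_min, sparse_up_min / kBlockM);
  }
  return {m_block_min, m_block_max};
}
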
88 changes: 72 additions & 16 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -130,6 +130,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

// Shared memory.
__shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
__shared__ int32_t sparse_mask_smem_up[Kernel_traits::kBlockN];
extern __shared__ char smem_[];

// The thread index.
@@ -203,8 +204,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

Tensor gSparseMask = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_start_row_indices_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
const int* gSparseMaskDownMax = reinterpret_cast<int32_t*>(params.attn_sparsemask_down_nblockmax)+row_offset_sparsemask_nblock;
const int* gSparseMaskDownMin = reinterpret_cast<int32_t*>(params.attn_sparsemask_down_nblockmin)+row_offset_sparsemask_nblock;
Tensor gSparseMaskUp = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_end_row_indices_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
const int* gSparseMaskDownMax = reinterpret_cast<int32_t*>(params.attn_sparsemask_down_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskDownMin = reinterpret_cast<int32_t*>(params.attn_sparsemask_down_nblockmin) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpMax = reinterpret_cast<int32_t*>(params.attn_sparsemask_up_nblockmax) + row_offset_sparsemask_nblock;
const int* gSparseMaskUpMin = reinterpret_cast<int32_t*>(params.attn_sparsemask_up_nblockmin) + row_offset_sparsemask_nblock;

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
@@ -215,6 +220,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
Tensor sSparseMaskUp = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_up)), Shape<Int<kBlockN>>{});

typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -407,7 +413,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal) {
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
if (Is_sparse_attn_mask &&
((m_block + 1) * kBlockM >= gSparseMaskDownMin[n_block] || m_block * kBlockM < gSparseMaskUpMax[n_block])){
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
}
__syncthreads();
flash::apply_sparse_mask(scores, sSparseMask, sSparseMaskUp, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16, n_block * kBlockN);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
} else if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
@@ -420,7 +439,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
if (Is_sparse_attn_mask && (m_block+1) * kBlockM >= gSparseMaskDownMin[n_block]) {
if (Is_sparse_attn_mask && (m_block + 1) * kBlockM >= gSparseMaskDownMin[n_block]) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
}
@@ -439,9 +458,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}
if (Is_sparse_attn_mask){
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
}
}

if (Is_sparse_attn_mask){
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
if (!Is_causal)
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
}

flash::cp_async_wait<0>();
@@ -456,10 +478,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}

// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

if(Is_sparse_attn_mask){
// We must check inf if use sparse_attn_mask
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); }
else{
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
}
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -495,8 +523,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

// These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) {
if (Is_sparse_attn_mask && m_block * kBlockM >= gSparseMaskDownMax[n_block]) {
continue;
if (Is_sparse_attn_mask) {
if ((m_block * kBlockM >= gSparseMaskDownMax[n_block]) ||
(!Is_causal &&
(m_block + 1) * kBlockM < gSparseMaskUpMin[n_block])) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
if (!Is_causal)
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
if(Return_softmax)
tPgP.data() = tPgP.data() + (-kBlockN);
continue;
}
}
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
@@ -533,7 +572,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
params.unscale_softmax);
tPgMask.data() = tPgMask.data() + (-kBlockN);
}
if (Is_causal && Is_sparse_attn_mask && (m_block+1) * kBlockM >= gSparseMaskDownMin[n_block]) {
if (!Is_causal && Is_sparse_attn_mask &&
((m_block + 1) * kBlockM >= gSparseMaskDownMin[n_block] || m_block * kBlockM < gSparseMaskUpMax[n_block])){
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
sSparseMaskUp(tidx) = gSparseMaskUp(tidx);
}
__syncthreads();
flash::apply_sparse_mask(scores, sSparseMask, sSparseMaskUp, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16, n_block * kBlockN);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
} else if (Is_causal && Is_sparse_attn_mask && (m_block + 1) * kBlockM >= gSparseMaskDownMin[n_block]) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
}
@@ -546,11 +598,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}

if (Is_causal && Is_sparse_attn_mask) {
if (Is_sparse_attn_mask) {
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
if (!Is_causal)
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
}

if (Is_equal_seq_qk) {
if(Is_sparse_attn_mask){
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
} else if (Is_equal_seq_qk) {
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
} else {
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
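
In the forward kernel, the non-causal path now stages both gSparseMask and gSparseMaskUp into shared memory and calls flash::apply_sparse_mask with both bounds, fully masked key blocks are skipped in the non-masking loop (with the K/V and mask pointers rewound), and softmax_rescale_o is instantiated with Check_inf=true whenever the sparse mask is active. The snippet below only illustrates why that guard matters (masked_softmax_term is a hypothetical name, not the library's code): when every score in a row is masked to -inf, the running row max is -inf, and an unguarded exp(score - max) yields NaN.

// Illustrative only: the effect of the Check_inf guard on a fully masked row.
#include <cmath>
#include <limits>

float masked_softmax_term(float score, float row_max) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  // Without the guard: exp(-inf - (-inf)) == exp(NaN) == NaN.
  if (row_max == neg_inf) {
    return 0.0f;  // fully masked row contributes zero probability mass
  }
  return std::exp(score - row_max);
}
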
2 changes: 0 additions & 2 deletions csrc/flash_attn/src/mask.h
@@ -73,8 +73,6 @@ int *prepare_sparsemask(Flash_fwd_params &params, cudaStream_t stream) {
params.attn_sparsemask_down_nblockmin = nblock_smask + nblock_masklen;
params.attn_sparsemask_up_nblockmax = nblock_smask + 2 * nblock_masklen;
params.attn_sparsemask_up_nblockmin = nblock_smask + 3 * nblock_masklen;
params.attn_mask_end_row_indices_ptr =
nullptr; // TODO: up mask not enable now, will be ignored
if (params.attn_mask_start_row_indices_ptr != nullptr) {
scanMaxMinGpu<kBlockN>(
static_cast<const int *>(params.attn_mask_start_row_indices_ptr),
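
Removing this TODO means params.attn_mask_end_row_indices_ptr is no longer forced to nullptr, so the attn_sparsemask_up_nblockmax / attn_sparsemask_up_nblockmin buffers can carry real per-block bounds for the end-row ("up") mask. Below is a hedged CPU reference of what the per-key-block scan is assumed to produce (scan_max_min_per_block is a hypothetical name; the actual scanMaxMinGpu kernel is not part of this diff):

// Hedged CPU reference of the assumed per-key-block max/min scan.
#include <algorithm>
#include <limits>
#include <vector>

void scan_max_min_per_block(const std::vector<int> &row_indices, int kBlockN,
                            std::vector<int> &block_max, std::vector<int> &block_min) {
  const int n = static_cast<int>(row_indices.size());
  const int n_blocks = (n + kBlockN - 1) / kBlockN;
  block_max.assign(n_blocks, std::numeric_limits<int>::min());
  block_min.assign(n_blocks, std::numeric_limits<int>::max());
  for (int c = 0; c < n; ++c) {
    const int b = c / kBlockN;  // key block that column c belongs to
    block_max[b] = std::max(block_max[b], row_indices[c]);
    block_min[b] = std::min(block_min[b], row_indices[c]);
  }
}
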