Skip to content

Commit

Permalink
Add num_splits in flash_attn backward api to support determistic resu…
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG authored Aug 30, 2023
1 parent b5bdb79 commit e6b9d0d
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 14 deletions.
32 changes: 21 additions & 11 deletions csrc/capi/flash_attn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ void set_params_fprop(Flash_fwd_params &params,
float p_dropout,
float softmax_scale,
bool is_causal,
bool is_bf16) {
bool is_bf16) {
// Reset the parameters
memset(&params, 0, sizeof(params));

params.is_bf16 = is_bf16;

// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
Expand Down Expand Up @@ -185,7 +184,8 @@ void set_params_dgrad(Flash_bwd_params &params,
float p_dropout,
float softmax_scale,
bool is_causal,
bool is_bf16) {
bool is_bf16,
const int num_splits=0) {

set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
Expand All @@ -196,7 +196,8 @@ void set_params_dgrad(Flash_bwd_params &params,
softmax_lse_d,
p_dropout,
softmax_scale,
is_causal, is_bf16);
is_causal,
is_bf16);

// Set the pointers and strides.
params.do_ptr = dout;
Expand Down Expand Up @@ -225,6 +226,7 @@ void set_params_dgrad(Flash_bwd_params &params,

// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;
params.num_splits = num_splits;
}

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
Expand Down Expand Up @@ -431,6 +433,7 @@ bool flash_attn_bwd(const void * const dout,
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
const int num_splits,
cudaStream_t stream,
uint64_t seed,
uint64_t offset) {
Expand Down Expand Up @@ -465,11 +468,13 @@ bool flash_attn_bwd(const void * const dout,
num_heads, num_heads_k,
head_size, head_size_rounded,
const_cast<void *>(q),
const_cast<void *>(k),
const_cast<void *>(v),
const_cast<void *>(out),
const_cast<void *>(k),
const_cast<void *>(v),
const_cast<void *>(out),
const_cast<void *>(dout),
dq, dk, dv,
dq,
dk,
dv,
nullptr,
nullptr,
loop ? dq_accum : nullptr,
Expand All @@ -480,7 +485,8 @@ bool flash_attn_bwd(const void * const dout,
p_dropout,
softmax_scale,
is_causal,
is_bf16);
is_bf16,
num_splits);

auto launch = &run_mha_bwd;

Expand Down Expand Up @@ -527,6 +533,7 @@ bool flash_attn_varlen_bwd(const void * const dout,
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
const int num_splits,
cudaStream_t stream,
uint64_t seed,
uint64_t offset) {
Expand Down Expand Up @@ -562,7 +569,9 @@ bool flash_attn_varlen_bwd(const void * const dout,
const_cast<void*>(v),
const_cast<void*>(out),
const_cast<void*>(dout),
dq, dk, dv,
dq,
dk,
dv,
const_cast<int32_t*>(cu_seqlens_q),
const_cast<int32_t*>(cu_seqlens_k),
loop ? dq_accum : nullptr,
Expand All @@ -573,7 +582,8 @@ bool flash_attn_varlen_bwd(const void * const dout,
p_dropout,
softmax_scale,
is_causal,
is_bf16);
is_bf16,
num_splits);

auto launch = &run_mha_bwd;

Expand Down
2 changes: 2 additions & 0 deletions csrc/capi/flash_attn.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
const int num_splits,
cudaStream_t stream,
uint64_t seed,
uint64_t offset);
Expand Down Expand Up @@ -116,6 +117,7 @@ bool flash_attn_varlen_bwd(const void * const dout, // total_q x num_heads, x h
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
const int num_splits,
cudaStream_t stream,
uint64_t seed,
uint64_t offset);
Expand Down
1 change: 1 addition & 0 deletions csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ struct Flash_bwd_params : public Flash_fwd_params {

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
int num_splits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
12 changes: 10 additions & 2 deletions csrc/flash_attn/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -1508,8 +1508,16 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;

compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
constexpr int kBlockN = Kernel_traits::kBlockN;
if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
int loop_step_x = 0;
for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
loop_step_x += 1;
}
} else {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h);
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
const int num_n_block = params.num_splits == 1 ? params.num_splits : (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
dim3 grid_n(num_n_block, params.b, params.h);

flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
Expand Down

0 comments on commit e6b9d0d

Please sign in to comment.