fix: batch decode kernel redundant store output to gmem (#505)
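This is a minor fix. When bdz is greater than 1, the warps with tz > 0
repeat the output stores to gmem that the tz == 0 warps already perform,
so those stores are redundant. The stores are now guarded by
'if (tz == 0)'. We could also check 'if (tx == 0)' when storing the lse
value, but since bdx is 32 most of the time, the remaining duplicate lse
stores come from the threads of a single warp, and I think that is fine.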

Co-authored-by: tsu-bin <[email protected]>
tsu-bin authored Sep 25, 2024
1 parent 33ef957 commit 90e42a7
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions include/flashinfer/attention/decode.cuh
@@ -575,10 +575,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
   st.normalize();

-  st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
-  // write lse
-  if (lse != nullptr) {
-    lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
+  if (tz == 0) {
+    st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
+    // write lse
+    if (lse != nullptr) {
+      lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
+    }
   }
 }
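
To make the effect of the guard concrete, here is a minimal host-side C++ sketch that counts the global-memory stores issued for one output row. It is not FlashInfer code: `StoreCounts`, `count_stores`, and both `guard_*` flags are hypothetical names introduced for illustration, and the model ignores the y dimension of the block (an assumption made to keep it small). It only mirrors the indexing visible in the diff, where each tx writes its own slice of o and every thread targets the same lse element.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side model of the store pattern; not FlashInfer code.
struct StoreCounts {
  std::size_t o_stores;    // vectorized stores to the output tensor o
  std::size_t lse_stores;  // scalar stores to the lse element
};

// Counts the stores issued for one (batch_idx, qo_head_idx) pair.
// bdx and bdz stand in for blockDim.x and blockDim.z; the y dimension
// is left out of the model (assumption).
StoreCounts count_stores(unsigned bdx, unsigned bdz,
                         bool guard_tz, bool guard_tx_for_lse) {
  assert(bdx > 0 && bdz > 0);
  StoreCounts c{0, 0};
  for (unsigned tz = 0; tz < bdz; ++tz) {
    // Models the `if (tz == 0)` guard added by the commit.
    if (guard_tz && tz != 0) continue;
    for (unsigned tx = 0; tx < bdx; ++tx) {
      // Each tx stores its own vec_size-wide slice of o.
      ++c.o_stores;
      // Every tx targets the same lse element; the optional tx guard
      // models the `if (tx == 0)` check mentioned in the commit message.
      if (!guard_tx_for_lse || tx == 0) ++c.lse_stores;
    }
  }
  return c;
}
```

Under these assumptions, bdx = 32 and bdz = 4 give 128 stores to o without the guard and 32 with it. The 32 lse stores that remain all target the same element, which is the case the commit message chooses to leave alone; an additional `tx == 0` check would reduce them to one.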
