diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index c1bf4cc7..73434b98 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -575,10 +575,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( sync_state(st, reinterpret_cast(smem), smem_md); st.normalize(); - st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); - // write lse - if (lse != nullptr) { - lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse(); + if (tz == 0) { + st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); + // write lse + if (lse != nullptr) { + lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse(); + } } }