diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 90054658..f69d767d 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -423,6 +423,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( ParamsT params(q, paged_kv, /*custom_mask=*/nullptr, qo_indptr, /*qk_indptr=*/nullptr, q_offset, o, lse, /*alibi_slopes=*/nullptr, num_qo_heads, + /*q_stride_n*/num_qo_heads*HEAD_DIM, /*q_stride_h*/HEAD_DIM, /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); params.request_indices = handler->GetRequestIndices(); @@ -525,6 +526,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi=*/false)>; ParamsT params(q, q_offset, paged_kv, o, lse, /*alibi_slopes=*/nullptr, num_qo_heads, + /*q_stride_n*/num_qo_heads*HEAD_DIM, /*q_stride_h*/HEAD_DIM, /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); params.request_indices = handler->GetRequestIndices();