diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/flashinfer-aot/csrc_aot/batch_prefill.cu
index 0289269f..d192514d 100644
--- a/flashinfer-aot/csrc_aot/batch_prefill.cu
+++ b/flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -71,14 +71,14 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
   return plan_info.ToVector();
 }
 
-std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
+torch::Tensor BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
     std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
     torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
    int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
-    bool return_lse) {
+    std::optional<torch::Tensor> maybe_lse) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(plan_info_vec);
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -98,10 +98,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  int64_t nnz_qo = q.size(0);
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   void* float_buffer_ptr = float_workspace_buffer.data_ptr();
@@ -140,7 +141,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
                                   : nullptr,
             /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
             static_cast<DTypeO*>(o.data_ptr()),
-            /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+            /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
             /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
             kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
             rope_theta);
@@ -187,14 +188,10 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }
 
-std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
+torch::Tensor BatchPrefillWithPagedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
@@ -202,7 +199,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
     torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
     torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
     unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
-    float rope_scale, float rope_theta, bool return_lse) {
+    float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(plan_info_vec);
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -221,10 +218,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
   cudaStream_t torch_current_stream =
       c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  int64_t nnz_qo = q.size(0);
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
@@ -277,7 +275,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
             maybe_qk_indptr.has_value() ? static_cast<int32_t*>(maybe_qk_indptr->data_ptr())
                                         : nullptr,
             /*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
-            /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+            /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
             /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
             logits_soft_cap, sm_scale, rope_scale,
             rope_theta);
@@ -323,9 +321,5 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }