Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: fix batch_prefill.cu in AOT mode after #554 #559

Merged
merged 1 commit into from
Oct 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
torch::Tensor BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse) {
std::optional<torch::Tensor> maybe_lse) {
PrefillPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
Expand All @@ -98,10 +98,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
int64_t nnz_qo = q.size(0);
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}

void* float_buffer_ptr = float_workspace_buffer.data_ptr();
Expand Down Expand Up @@ -140,7 +141,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
: nullptr,
/*q_offset=*/nullptr,
/*k_rope_pos_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
rope_theta);
Expand Down Expand Up @@ -187,22 +188,18 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
return o;
}

std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
torch::Tensor BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse) {
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
PrefillPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
Expand All @@ -221,10 +218,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
int64_t nnz_qo = q.size(0);
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}

void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
Expand Down Expand Up @@ -277,7 +275,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
maybe_qk_indptr.has_value() ? static_cast<IdType*>(maybe_qk_indptr->data_ptr())
: nullptr,
/*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);

Expand Down Expand Up @@ -323,9 +321,5 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
return o;
}