feat: simplify prefill JIT compilation (#605)
Compile all three mask modes (causal/non-causal/custom) together
instead of compiling them one by one.
yzh119 authored Nov 11, 2024
1 parent bb67144 commit fe4f898
Showing 10 changed files with 190 additions and 122 deletions.
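
To make the mechanism concrete, here is a minimal, self-contained sketch of the dispatch pattern the templates below adopt. It is illustrative only: the simplified DISPATCH_MASK_MODE macro, the stand-in kernel, and all names are assumptions made for this example, not FlashInfer's actual implementation. The idea is that a runtime mask-mode code is switched onto a compile-time constant, so one compiled module carries the causal, non-causal, and custom-mask specializations and the caller selects among them at run time.

// Minimal sketch (illustrative names; not FlashInfer's real macro or kernels).
#include <cstdio>
#include <stdexcept>

enum class MaskMode { kNone = 0, kCausal = 1, kCustom = 2 };

// Stand-in for an attention kernel specialized on the mask mode at compile time.
template <MaskMode MASK_MODE, bool USE_CUSTOM_MASK>
void PrefillKernel() {
  std::printf("mask mode %d, custom mask %d\n",
              static_cast<int>(MASK_MODE), static_cast<int>(USE_CUSTOM_MASK));
}

// Runtime-to-compile-time dispatch: all three mask modes are instantiated in the
// same translation unit, and a runtime code picks one of the instantiations.
#define DISPATCH_MASK_MODE(mode, MASK_MODE, ...)            \
  switch (mode) {                                           \
    case MaskMode::kNone: {                                 \
      constexpr MaskMode MASK_MODE = MaskMode::kNone;       \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    case MaskMode::kCausal: {                               \
      constexpr MaskMode MASK_MODE = MaskMode::kCausal;     \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    case MaskMode::kCustom: {                               \
      constexpr MaskMode MASK_MODE = MaskMode::kCustom;     \
      __VA_ARGS__                                           \
      break;                                                \
    }                                                       \
    default:                                                \
      throw std::invalid_argument("unsupported mask mode"); \
  }

// Mirrors the shape of the new entry points: the mask mode arrives as a runtime
// argument instead of being baked in when the module is compiled.
void Run(unsigned int mask_mode_code) {
  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
    PrefillKernel<MASK_MODE, use_custom_mask>();
  });
}

int main() {
  Run(0);  // non-causal
  Run(1);  // causal
  Run(2);  // custom mask
  return 0;
}

Because the mask mode is no longer part of a module's identity, it also drops out of the JIT cache URI (see the jit/attention.py changes below), and the Python wrappers now pass MaskMode.*.value as the first argument to run/paged_run instead of supplying it when the module is built.
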
12 changes: 7 additions & 5 deletions python/flashinfer/decode.py
@@ -400,12 +400,12 @@ def single_decode_with_kv_cache(
q.dtype,
head_dim,
PosEncodingMode[pos_encoding_mode].value,
MaskMode.NON_CAUSAL.value,
window_left != -1, # use_sliding_window
logits_soft_cap > 0, # use_logits_soft_cap
False, # allow_fp16_qk_reduction
)
.run(
MaskMode.NON_CAUSAL.value,
q.unsqueeze(0),
k,
v,
@@ -418,7 +418,7 @@ def single_decode_with_kv_cache(
sm_scale,
rope_scale,
rope_theta,
False, # return_lse
None, # maybe_lse
)[0]
.squeeze(0)
)
@@ -743,8 +743,10 @@ def plan(

indptr_host = indptr.to("cpu")
if data_type is not None:
q_data_type = data_type
kv_data_type = data_type
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type

q_data_type = canonicalize_torch_dtype(q_data_type)
if kv_data_type is None:
@@ -761,7 +763,6 @@ def plan(
indptr.dtype,
head_dim,
PosEncodingMode[pos_encoding_mode].value,
MaskMode.NON_CAUSAL.value,
window_left != -1, # use_sliding_window
logits_soft_cap > 0, # use_logits_soft_cap
False, # allow_fp16_qk_reduction
@@ -938,6 +939,7 @@ def run(

if self.use_tensor_cores:
out = self._cached_module.paged_run(
MaskMode.NON_CAUSAL.value,
self._float_workspace_buffer,
self._int_workspace_buffer,
self._plan_info,
11 changes: 0 additions & 11 deletions python/flashinfer/jit/attention.py
@@ -33,7 +33,6 @@
from .utils import (
dtype_map,
filename_safe_dtype_map,
mask_mode_literal,
pos_encoding_mode_literal,
write_if_different,
)
@@ -216,7 +215,6 @@ def get_single_prefill_cu_str(
dtype_o: torch.dtype,
head_dim: int,
pos_encoding_mode: int,
mask_mode: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_fp16_qk_reduction: bool,
@@ -228,7 +226,6 @@ def get_single_prefill_cu_str(
dtype_o=dtype_map[dtype_o],
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode],
mask_mode=mask_mode_literal[mask_mode],
use_sliding_window="true" if use_sliding_window else "false",
use_logits_soft_cap="true" if use_logits_soft_cap else "false",
use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false",
@@ -241,7 +238,6 @@ def get_single_prefill_uri(
dtype_o: torch.dtype,
head_dim: int,
pos_encoding_mode: int,
mask_mode: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_fp16_qk_reduction: bool,
@@ -252,7 +248,6 @@ def get_single_prefill_uri(
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
f"head_dim_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"mask_{mask_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
f"f16qk_{use_fp16_qk_reduction}"
@@ -280,7 +275,6 @@ def get_batch_prefill_cu_str(
dtype_idx: torch.dtype,
head_dim: int,
pos_encoding_mode: int,
mask_mode: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_fp16_qk_reduction: bool,
@@ -293,7 +287,6 @@ def get_batch_prefill_cu_str(
dtype_idx=dtype_map[dtype_idx],
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode],
mask_mode=mask_mode_literal[mask_mode],
use_sliding_window="true" if use_sliding_window else "false",
use_logits_soft_cap="true" if use_logits_soft_cap else "false",
use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false",
@@ -307,7 +300,6 @@ def get_batch_prefill_uri(
dtype_idx: torch.dtype,
head_dim: int,
pos_encoding_mode: int,
mask_mode: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_fp16_qk_reduction: bool,
@@ -319,7 +311,6 @@ def get_batch_prefill_uri(
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
f"head_dim_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"mask_{mask_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
f"f16qk_{use_fp16_qk_reduction}"
@@ -424,7 +415,6 @@ def get_customize_single_prefill_cu_str(
dtype_kv: torch.dtype,
dtype_o: torch.dtype,
head_dim: int,
mask_mode: int,
additional_input_tensor_var_names: List[str],
additional_input_tensor_var_types: List[str],
additional_input_scalar_var_names: List[str],
@@ -489,7 +479,6 @@ def get_customize_single_prefill_cu_str(
dtype_kv=dtype_map[dtype_kv],
dtype_o=dtype_map[dtype_o],
head_dim=head_dim,
mask_mode=mask_mode_literal[mask_mode],
additional_params_decl=additional_params_decl,
additional_params=additional_params,
additional_params_init=additional_params_init,
41 changes: 26 additions & 15 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -25,12 +25,9 @@
using namespace flashinfer;
{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %}
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
std::vector<int64_t> BatchPrefillWithKVCachePlan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
@@ -68,6 +65,7 @@
}
torch::Tensor BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code,
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec,
torch::Tensor q, torch::Tensor k, torch::Tensor v,
@@ -109,10 +107,10 @@
RaggedParamsT params(
static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
static_cast<{{ dtype_kv }}*>(v.data_ptr()),
{% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %},
/*custom_mask=*/(maybe_custom_mask ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr()) : nullptr),
static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()),
static_cast<{{ dtype_idx }}*>(kv_indptr.data_ptr()),
{% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %},
/*qk_indptr=*/(maybe_qk_indptr ? static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()) : nullptr),
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
static_cast<{{ dtype_o }}*>(o.data_ptr()),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
@@ -141,10 +139,16 @@
cudaError_t status = cudaSuccess;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithRaggedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, RaggedAttentionVariant>(
params, tmp_v, tmp_s, torch_current_stream);
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithRaggedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
params, tmp_v, tmp_s, torch_current_stream);
});
});
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
@@ -153,6 +157,7 @@
}
torch::Tensor BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code,
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec,
torch::Tensor q,
@@ -215,9 +220,9 @@
PagedParamsT params(
static_cast<{{ dtype_q }}*>(q.data_ptr()), paged_kv,
{% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %},
/*custom_mask=*/(maybe_custom_mask ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr()) : nullptr),
static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()),
{% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %},
/*qk_indptr=*/(maybe_qk_indptr ? static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()) : nullptr),
/*q_offset=*/nullptr,
static_cast<{{ dtype_o }}*>(o.data_ptr()),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
@@ -245,10 +250,16 @@
cudaError_t status = cudaSuccess;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithPagedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, PagedAttentionVariant>(
params, tmp_v, tmp_s, torch_current_stream);
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithPagedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
params, tmp_v, tmp_s, torch_current_stream);
});
});
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status));
62 changes: 35 additions & 27 deletions python/flashinfer/jit/single_prefill_templ.py
@@ -81,9 +81,9 @@
{{ variant_decl }}
std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor tmp, unsigned int layout, int32_t window_left, bool return_lse{{ additional_func_params }}) {
torch::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor tmp, unsigned int layout, int32_t window_left, std::optional<torch::Tensor> maybe_lse{{ additional_func_params }}) {
auto device = q.device();
unsigned int head_dim = q.size(2);
unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
@@ -104,9 +104,11 @@
}
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}
using ParamsT = SinglePrefillParams;
@@ -115,22 +117,22 @@
static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
static_cast<{{ dtype_kv }}*>(v.data_ptr()),
static_cast<{{ dtype_o }}*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
kv_stride_n, kv_stride_h, head_dim, window_left{{ additional_params_data }});
cudaError_t status =
SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, {{ mask_mode }}, AttentionVariant>(
params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
if (return_lse) {
return {o, lse};
} else {
return {o};
}
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
cudaError_t status =
SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, MASK_MODE, AttentionVariant>(
params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
});
return o;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -149,12 +151,11 @@
using namespace flashinfer;
{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %}
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
using AttentionVariant = ComposedAttention<ParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
torch::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code,
torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_packed_custom_mask,
torch::Tensor tmp, std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
@@ -188,20 +189,27 @@
ParamsT params(
static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
static_cast<{{ dtype_kv }}*>(v.data_ptr()),
{% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()){% else %}nullptr{% endif %},
/*custom_mask=*/(maybe_packed_custom_mask ? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()) : nullptr),
static_cast<{{ dtype_o }}*>(o.data_ptr()),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
{% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale,
rope_scale, rope_theta);
cudaError_t status =
SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, AttentionVariant>(
params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
cudaError_t status =
SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, AttentionVariant>(
params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
});
return o;
}