diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index a77b11ce..e12c9d06 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -400,12 +400,12 @@ def single_decode_with_kv_cache(
             q.dtype,
             head_dim,
             PosEncodingMode[pos_encoding_mode].value,
-            MaskMode.NON_CAUSAL.value,
             window_left != -1,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             False,  # allow_fp16_qk_reduction
         )
         .run(
+            MaskMode.NON_CAUSAL.value,
             q.unsqueeze(0),
             k,
             v,
@@ -418,7 +418,7 @@ def single_decode_with_kv_cache(
             sm_scale,
             rope_scale,
             rope_theta,
-            False,  # return_lse
+            None,  # maybe_lse
         )[0]
         .squeeze(0)
     )
@@ -743,8 +743,10 @@ def plan(
         indptr_host = indptr.to("cpu")
 
         if data_type is not None:
-            q_data_type = data_type
-            kv_data_type = data_type
+            if q_data_type is None:
+                q_data_type = data_type
+            if kv_data_type is None:
+                kv_data_type = data_type
 
         q_data_type = canonicalize_torch_dtype(q_data_type)
         if kv_data_type is None:
@@ -761,7 +763,6 @@ def plan(
                 indptr.dtype,
                 head_dim,
                 PosEncodingMode[pos_encoding_mode].value,
-                MaskMode.NON_CAUSAL.value,
                 window_left != -1,  # use_sliding_window
                 logits_soft_cap > 0,  # use_logits_soft_cap
                 False,  # allow_fp16_qk_reduction
@@ -938,6 +939,7 @@ def run(
 
         if self.use_tensor_cores:
             out = self._cached_module.paged_run(
+                MaskMode.NON_CAUSAL.value,
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._plan_info,
diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py
index e6498aab..c6e4b788 100644
--- a/python/flashinfer/jit/attention.py
+++ b/python/flashinfer/jit/attention.py
@@ -33,7 +33,6 @@
 from .utils import (
     dtype_map,
     filename_safe_dtype_map,
-    mask_mode_literal,
     pos_encoding_mode_literal,
     write_if_different,
 )
@@ -216,7 +215,6 @@ def get_single_prefill_cu_str(
     dtype_o: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -228,7 +226,6 @@ def get_single_prefill_cu_str(
         dtype_o=dtype_map[dtype_o],
         head_dim=head_dim,
         pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode],
-        mask_mode=mask_mode_literal[mask_mode],
         use_sliding_window="true" if use_sliding_window else "false",
         use_logits_soft_cap="true" if use_logits_soft_cap else "false",
         use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false",
@@ -241,7 +238,6 @@ def get_single_prefill_uri(
     dtype_o: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -252,7 +248,6 @@ def get_single_prefill_uri(
         f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
         f"head_dim_{head_dim}_"
         f"posenc_{pos_encoding_mode}_"
-        f"mask_{mask_mode}_"
         f"use_swa_{use_sliding_window}_"
         f"use_logits_cap_{use_logits_soft_cap}_"
         f"f16qk_{use_fp16_qk_reduction}"
@@ -280,7 +275,6 @@ def get_batch_prefill_cu_str(
     dtype_idx: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -293,7 +287,6 @@ def get_batch_prefill_cu_str(
         dtype_idx=dtype_map[dtype_idx],
         head_dim=head_dim,
         pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode],
-        mask_mode=mask_mode_literal[mask_mode],
         use_sliding_window="true" if use_sliding_window else "false",
         use_logits_soft_cap="true" if use_logits_soft_cap else "false",
         use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false",
@@ -307,7 +300,6 @@ def get_batch_prefill_uri(
     dtype_idx: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -319,7 +311,6 @@ def get_batch_prefill_uri(
         f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
         f"head_dim_{head_dim}_"
         f"posenc_{pos_encoding_mode}_"
-        f"mask_{mask_mode}_"
         f"use_swa_{use_sliding_window}_"
         f"use_logits_cap_{use_logits_soft_cap}_"
         f"f16qk_{use_fp16_qk_reduction}"
@@ -424,7 +415,6 @@ def get_customize_single_prefill_cu_str(
     dtype_kv: torch.dtype,
     dtype_o: torch.dtype,
     head_dim: int,
-    mask_mode: int,
     additional_input_tensor_var_names: List[str],
     additional_input_tensor_var_types: List[str],
     additional_input_scalar_var_names: List[str],
@@ -489,7 +479,6 @@ def get_customize_single_prefill_cu_str(
         dtype_kv=dtype_map[dtype_kv],
         dtype_o=dtype_map[dtype_o],
         head_dim=head_dim,
-        mask_mode=mask_mode_literal[mask_mode],
         additional_params_decl=additional_params_decl,
         additional_params=additional_params,
         additional_params_init=additional_params_init,
diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py
index 3855fc2b..f0c2eafa 100644
--- a/python/flashinfer/jit/batch_prefill_templ.py
+++ b/python/flashinfer/jit/batch_prefill_templ.py
@@ -25,12 +25,9 @@
 
 using namespace flashinfer;
 
-{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %}
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
-using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
 using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
-using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
 
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
@@ -68,6 +65,7 @@
 }
 
 torch::Tensor BatchPrefillWithRaggedKVCacheRun(
+    unsigned int mask_mode_code,
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec,
     torch::Tensor q, torch::Tensor k, torch::Tensor v,
@@ -109,10 +107,10 @@
   RaggedParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
       static_cast<{{ dtype_kv }}*>(v.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %},
+      /*custom_mask=*/(maybe_custom_mask ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr()) : nullptr),
       static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()),
      static_cast<{{ dtype_idx }}*>(kv_indptr.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %},
+      /*qk_indptr=*/(maybe_qk_indptr ? static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()) : nullptr),
       /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
       static_cast<{{ dtype_o }}*>(o.data_ptr()),
       /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
@@ -141,10 +139,16 @@
 
   cudaError_t status = cudaSuccess;
 
-  DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithRaggedKVCacheDispatched<
-        CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, RaggedAttentionVariant>(
-        params, tmp_v, tmp_s, torch_current_stream);
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
+    using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+    DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
+      status = BatchPrefillWithRaggedKVCacheDispatched<
+          CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
+          params, tmp_v, tmp_s, torch_current_stream);
+    });
   });
 
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
@@ -153,6 +157,7 @@
 }
 
 torch::Tensor BatchPrefillWithPagedKVCacheRun(
+    unsigned int mask_mode_code,
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec,
     torch::Tensor q,
@@ -215,9 +220,9 @@
   PagedParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), paged_kv,
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %},
+      /*custom_mask=*/(maybe_custom_mask ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr()) : nullptr),
       static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %},
+      /*qk_indptr=*/(maybe_qk_indptr ? static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()) : nullptr),
       /*q_offset=*/nullptr,
       static_cast<{{ dtype_o }}*>(o.data_ptr()),
       /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
@@ -245,10 +250,16 @@
 
   cudaError_t status = cudaSuccess;
 
-  DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithPagedKVCacheDispatched<
-        CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, PagedAttentionVariant>(
-        params, tmp_v, tmp_s, torch_current_stream);
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
+    using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+    DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
+      status = BatchPrefillWithPagedKVCacheDispatched<
+          CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
+          params, tmp_v, tmp_s, torch_current_stream);
+    });
   });
 
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status));
diff --git a/python/flashinfer/jit/single_prefill_templ.py b/python/flashinfer/jit/single_prefill_templ.py
index e741a9d0..2be8b1d3 100644
--- a/python/flashinfer/jit/single_prefill_templ.py
+++ b/python/flashinfer/jit/single_prefill_templ.py
@@ -81,9 +81,9 @@
 
 {{ variant_decl }}
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache(
-    torch::Tensor q, torch::Tensor k, torch::Tensor v,
-    torch::Tensor tmp, unsigned int layout, int32_t window_left, bool return_lse{{ additional_func_params }}) {
+torch::Tensor single_prefill_with_kv_cache(
+    unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
+    torch::Tensor tmp, unsigned int layout, int32_t window_left, std::optional<torch::Tensor> maybe_lse{{ additional_func_params }}) {
   auto device = q.device();
   unsigned int head_dim = q.size(2);
   unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
@@ -104,9 +104,11 @@
   }
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   using ParamsT = SinglePrefillParams;
@@ -115,22 +117,22 @@
       static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
       static_cast<{{ dtype_kv }}*>(v.data_ptr()), static_cast<{{ dtype_o }}*>(o.data_ptr()),
-      /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+      /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
       num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
       kv_stride_n, kv_stride_h, head_dim, window_left{{ additional_params_data }});
 
-  cudaError_t status =
-      SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, {{ mask_mode }}, AttentionVariant>(
-          params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
-  TORCH_CHECK(status == cudaSuccess,
-              "SinglePrefillWithKVCache kernel launch failed, error: " +
-                  std::string(cudaGetErrorString(status)));
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    cudaError_t status =
+        SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, MASK_MODE, AttentionVariant>(
+            params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "SinglePrefillWithKVCache kernel launch failed, error: " +
+                    std::string(cudaGetErrorString(status)));
+  });
+
+  return o;
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -149,12 +151,11 @@
 
 using namespace flashinfer;
 
-{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %}
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
-using AttentionVariant = ComposedAttention<ParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
 
 torch::Tensor single_prefill_with_kv_cache(
+    unsigned int mask_mode_code,
     torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_packed_custom_mask,
     torch::Tensor tmp, std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left,
     float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
@@ -188,7 +189,7 @@
   ParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
       static_cast<{{ dtype_kv }}*>(v.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()){% else %}nullptr{% endif %},
+      /*custom_mask=*/(maybe_packed_custom_mask ? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()) : nullptr),
       static_cast<{{ dtype_o }}*>(o.data_ptr()),
       /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
       {% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
@@ -196,12 +197,19 @@
       kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale,
       rope_scale, rope_theta);
 
-  cudaError_t status =
-      SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, AttentionVariant>(
-          params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
-  TORCH_CHECK(status == cudaSuccess,
-              "SinglePrefillWithKVCache kernel launch failed, error: " +
-                  std::string(cudaGetErrorString(status)));
+
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
+    using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+    cudaError_t status =
+        SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, AttentionVariant>(
+            params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "SinglePrefillWithKVCache kernel launch failed, error: " +
+                    std::string(cudaGetErrorString(status)));
+  });
 
   return o;
 }
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index a39dc014..ec6515be 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -84,12 +84,7 @@ def get_single_prefill_module(*args):
     if has_prebuilt_ops and uri in prebuilt_ops_uri:
         from . import _kernels
 
-        # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
-        mask_mode = args[5]
-        run_func = lambda *run_args: _kernels.single_prefill_with_kv_cache(
-            mask_mode,
-            *run_args,
-        )
+        run_func = _kernels.single_prefill_with_kv_cache
     else:
         run_func = compile_single_prefill_module(*args).run
 
@@ -97,6 +92,7 @@ def get_single_prefill_module(*args):
 
     @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "maybe_lse"))
     def run_single_prefill(
+        mask_mode: int,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
@@ -112,6 +108,7 @@ def run_single_prefill(
         maybe_lse: Optional[torch.Tensor],
     ) -> torch.Tensor:
         return run_func(
+            mask_mode,
             q,
             k,
             v,
@@ -129,6 +126,7 @@ def run_single_prefill(
 
     @register_fake_op(f"flashinfer::{uri}_run")
     def _fake_run_single_prefill(
+        mask_mode: int,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
@@ -164,19 +162,8 @@ def get_batch_prefill_module(*args):
             head_dim,
             *plan_args,
         )
-        mask_mode = args[6]
-        ragged_run_func = (
-            lambda *run_args: _kernels.batch_prefill_with_ragged_kv_cache_run(
-                mask_mode,
-                *run_args,
-            )
-        )
-        paged_run_func = (
-            lambda *run_args: _kernels.batch_prefill_with_paged_kv_cache_run(
-                mask_mode,
-                *run_args,
-            )
-        )
+        ragged_run_func = _kernels.batch_prefill_with_ragged_kv_cache_run
+        paged_run_func = _kernels.batch_prefill_with_paged_kv_cache_run
     else:
         module = compile_batch_prefill_module(*args)
         plan_func = module.plan
@@ -194,6 +181,7 @@ def get_batch_prefill_module(*args):
         ),
     )
     def ragged_run(
+        mask_mode: int,
         float_workspace_buffer: torch.Tensor,
         int_workspace_buffer: torch.Tensor,
         plan_info_vec: List[int],
@@ -214,6 +202,7 @@ def ragged_run(
         maybe_lse: Optional[torch.Tensor],
     ) -> torch.Tensor:
         return ragged_run_func(
+            mask_mode,
            float_workspace_buffer,
            int_workspace_buffer,
            plan_info_vec,
@@ -236,6 +225,7 @@ def ragged_run(
 
     @register_fake_op(f"flashinfer::{uri}_ragged_run")
     def _fake_ragged_run(
+        mask_mode: int,
         float_workspace_buffer: torch.Tensor,
         int_workspace_buffer: torch.Tensor,
         plan_info_vec: List[int],
@@ -270,6 +260,7 @@ def _fake_ragged_run(
         ),
     )
     def paged_run(
+        mask_mode: int,
         float_workspace_buffer: torch.Tensor,
         int_workspace_buffer: torch.Tensor,
         plan_info_vec: List[int],
@@ -292,6 +283,7 @@ def paged_run(
         maybe_lse: Optional[torch.Tensor],
     ) -> torch.Tensor:
         return paged_run_func(
+            mask_mode,
             float_workspace_buffer,
             int_workspace_buffer,
             plan_info_vec,
@@ -316,6 +308,7 @@ def paged_run(
 
     @register_fake_op(f"flashinfer::{uri}_paged_run")
     def _fake_paged_run(
+        mask_mode: int,
         float_workspace_buffer: torch.Tensor,
         int_workspace_buffer: torch.Tensor,
         plan_info_vec: List[int],
@@ -358,6 +351,7 @@ def single_prefill_with_kv_cache_with_jit_module(
     v: torch.Tensor,
     *args,
     kv_layout: str = "NHD",
+    mask_mode: int = MaskMode.NON_CAUSAL.value,
     window_left: int = -1,
     return_lse: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -366,7 +360,7 @@ def single_prefill_with_kv_cache_with_jit_module(
     if return_lse:
         lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device)
     out = jit_module.run(
-        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, lse, *args
+        mask_mode, q, k, v, tmp, TensorLayout[kv_layout].value, window_left, lse, *args
     )
     return (out, lse) if return_lse else out
 
@@ -568,11 +562,11 @@ def single_prefill_with_kv_cache(
         q.dtype,
         q.shape[-1],
         PosEncodingMode[pos_encoding_mode].value,
-        mask_mode,
         window_left >= 0,  # use_sliding_window
         logits_soft_cap > 0,  # use_logits_soft_cap
         allow_fp16_qk_reduction,
     ).run(
+        mask_mode,
         q,
         k,
         v,
@@ -1046,14 +1040,6 @@ def plan(
         qo_indptr_host = qo_indptr.to("cpu")
         paged_kv_indptr_host = paged_kv_indptr.to("cpu")
 
-        if packed_custom_mask is not None:
-            mask_mode = MaskMode.CUSTOM.value
-        else:
-            if causal:
-                mask_mode = MaskMode.CAUSAL.value
-            else:
-                mask_mode = MaskMode.NON_CAUSAL.value
-
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
         self._cached_module = get_batch_prefill_module(
@@ -1063,7 +1049,6 @@ def plan(
             paged_kv_indptr.dtype,
             head_dim,
             PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
             window_left >= 0,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             allow_fp16_qk_reduction,
@@ -1206,7 +1191,16 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
 
+        if self._custom_mask_buf is not None:
+            mask_mode = MaskMode.CUSTOM.value
+        else:
+            if self._causal:
+                mask_mode = MaskMode.CAUSAL.value
+            else:
+                mask_mode = MaskMode.NON_CAUSAL.value
+
         out = self._cached_module.paged_run(
+            mask_mode,
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._plan_info,
@@ -1633,14 +1627,6 @@ def plan(
         qo_indptr_host = qo_indptr.to("cpu")
         kv_indptr_host = kv_indptr.to("cpu")
 
-        if packed_custom_mask is not None:
-            mask_mode = MaskMode.CUSTOM.value
-        else:
-            if causal:
-                mask_mode = MaskMode.CAUSAL.value
-            else:
-                mask_mode = MaskMode.NON_CAUSAL.value
-
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
         self._cached_module = get_batch_prefill_module(
@@ -1650,7 +1636,6 @@ def plan(
             kv_indptr.dtype,
             head_dim,
             PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
             window_left >= 0,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             allow_fp16_qk_reduction,
@@ -1782,7 +1767,16 @@ def run(
             k = k.to(torch.float16)
             v = v.to(torch.float16)
 
+        if self._custom_mask_buf is not None:
+            mask_mode = MaskMode.CUSTOM.value
+        else:
+            if self._causal:
+                mask_mode = MaskMode.CAUSAL.value
+            else:
+                mask_mode = MaskMode.NON_CAUSAL.value
+
         out = self._cached_module.ragged_run(
+            mask_mode,
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._plan_info,
diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py
index dc5dca3a..19c4734e 100644
--- a/python/flashinfer/sparse.py
+++ b/python/flashinfer/sparse.py
@@ -294,6 +294,7 @@ def plan(
             self._packed_mask_buf = None
             self._qk_indptr_buf = None
             mask_mode = MaskMode.NON_CAUSAL.value
+        self._mask_mode = mask_mode
 
         self.M = M
         self.N = N
@@ -343,7 +344,6 @@ def plan(
             indptr.dtype,
             head_dim,
             PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
             False,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             allow_fp16_qk_reduction,
@@ -448,6 +448,7 @@ def run(
 
         if self._use_tensor_cores:
             out = self._cached_module.paged_run(
+                self._mask_mode,
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._plan_info,
diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py
index b3b608d4..921aa652 100644
--- a/tests/test_batch_decode_kernels.py
+++ b/tests/test_batch_decode_kernels.py
@@ -463,26 +463,86 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
 
 if __name__ == "__main__":
     test_batch_decode_with_paged_kv_cache(
-        256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16
+        256,
+        54,
+        8,
+        8,
+        8,
+        128,
+        "NHD",
+        "NONE",
+        0.0,
+        False,
+        torch.float16,
+        torch.float16,
+        True,
     )
     test_batch_decode_with_tuple_paged_kv_cache(
-        256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16
+        256,
+        54,
+        8,
+        8,
+        8,
+        128,
+        "NHD",
+        "NONE",
+        0.0,
+        False,
+        torch.float16,
+        torch.float16,
+        True,
     )
     test_batch_decode_with_paged_kv_cache(
-        12, 2048, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16
+        12,
+        2048,
+        8,
+        8,
+        8,
+        128,
+        "NHD",
+        "NONE",
+        0.0,
+        False,
+        torch.float16,
+        torch.float16,
+        True,
    )
     test_batch_decode_with_paged_kv_cache(
-        12, 54, 1, 8, 8, 128, "HND", "NONE", 0.0, True, torch.float16, torch.float8_e5m2
+        12,
+        54,
+        1,
+        8,
+        8,
+        128,
+        "HND",
+        "NONE",
+        0.0,
+        True,
+        torch.float16,
+        torch.float8_e5m2,
+        True,
     )
     test_cuda_graph_batch_decode_with_paged_kv_cache(
-        12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
+        12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16, True
     )
     test_cuda_graph_batch_decode_with_paged_kv_cache(
-        128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
+        128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16, True
     )
     test_batch_decode_with_paged_kv_cache(
-        12, 54, 1, 8, 8, 128, "HND", "NONE", 0.0, True, torch.float16, torch.float8_e5m2
+        12,
+        54,
+        1,
+        8,
+        8,
+        128,
+        "HND",
+        "NONE",
+        0.0,
+        True,
+        torch.float16,
+        torch.float8_e5m2,
+        True,
     )
     test_cuda_graph_batch_decode_with_paged_kv_cache(
-        12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2
+        12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2, True
     )
diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py
index e056a686..b7b0c981 100644
--- a/tests/test_batch_prefill_kernels.py
+++ b/tests/test_batch_prefill_kernels.py
@@ -688,16 +688,16 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask(
 
 if __name__ == "__main__":
     test_batch_prefill_with_paged_kv_cache(
-        12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False
+        12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True
     )
     test_batch_prefill_with_tuple_paged_kv_cache(
-        12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False
+        12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True
     )
     test_batch_prefill_with_paged_kv_cache(
-        12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False
+        12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False, True
     )
     test_batch_prefill_with_paged_kv_cache_custom_mask(
-        1, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False
+        1, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False, True
     )
     test_batch_prefill_with_ragged_kv_cache(
         12, 54, 37, 8, 8, 128, True, "NONE", 0.0, False
diff --git a/tests/test_fp8_prefill.py b/tests/test_fp8_prefill.py
index 22fe2285..414173f4 100644
--- a/tests/test_fp8_prefill.py
+++ b/tests/test_fp8_prefill.py
@@ -65,10 +65,10 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale(
     ).to(0)
 
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
-    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+    wrapper_f16 = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout
     )
-    wrapper.plan(
+    wrapper_f16.plan(
         qo_indptr,
         kv_indptr,
         kv_indices,
@@ -78,8 +78,9 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale(
         head_dim,
         page_size,
         q_data_type=torch.float16,
+        kv_data_type=torch.float16,
     )
-    o_fp16 = wrapper.run(q, kv_data)
+    o_fp16 = wrapper_f16.run(q, kv_data)
     k_data, v_data = torch.chunk(kv_data, 2, dim=1)
     k_scale = k_data.amax().item() / 256
     v_scale = v_data.amax().item() / 256
@@ -88,7 +89,10 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale(
     v_fp8 = (v_data / v_scale).to(dtype)
     kv_data_fp8 = torch.cat([k_fp8, v_fp8], dim=1)
 
-    wrapper.plan(
+    wrapper_f8 = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout
+    )
+    wrapper_f8.plan(
         qo_indptr,
         kv_indptr,
         kv_indices,
@@ -98,8 +102,9 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale(
         head_dim,
         page_size,
         q_data_type=torch.float16,
+        kv_data_type=dtype,
     )
-    o_fp8 = wrapper.run(
+    o_fp8 = wrapper_f8.run(
         q,
         kv_data_fp8.to(dtype),
         k_scale=k_scale,
@@ -163,6 +168,7 @@ def test_batch_decode_with_prefill_with_paged_kv_cache(
         head_dim,
         page_size,
         q_data_type=torch.float16,
+        kv_data_type=dtype,
     )
     o_fp8 = wrapper.run(q, kv_data)
 
@@ -177,8 +183,8 @@ def test_batch_decode_with_prefill_with_paged_kv_cache(
         num_kv_heads,
         head_dim,
         page_size,
-        data_type=dtype,
         q_data_type=torch.float16,
+        kv_data_type=dtype,
     )
     o_decode_fp8 = decode_wrapper.run(q, kv_data)
diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py
index 4593515b..241aa658 100644
--- a/tests/test_jit_example.py
+++ b/tests/test_jit_example.py
@@ -150,7 +150,6 @@ def test_flash_sigmoid():
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
         128,  # hidden_dim
-        MaskMode.NON_CAUSAL.value,
         [],  # additional_input_tensor_var_names
         [],  # additional_input_tensor_var_types
         ["logits_scale", "sigmoid_bias"],  # additional_input_scalar_var_names
@@ -167,7 +166,7 @@ def test_flash_sigmoid():
     v = torch.randn(1027, 8, 128, dtype=torch.float16, device="cuda")
     logits_scale = 1.0 / math.sqrt(128)
     sigmoid_bias = 0.25
-    o = f(q, k, v, logits_scale, sigmoid_bias)
+    o = f(q, k, v, logits_scale, sigmoid_bias, mask_mode=MaskMode.NON_CAUSAL.value)
     p = torch.sigmoid(
         torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * logits_scale + sigmoid_bias
     )
@@ -225,7 +224,6 @@ def test_dump_logits():
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
         128,  # hidden_dim
-        MaskMode.NON_CAUSAL.value,
         ["output_logits"],  # additional_input_tensor_var_names
         ["float"],  # additional_input_tensor_var_types
         ["sm_scale"],  # additional_input_scalar_var_names
@@ -242,7 +240,7 @@ def test_dump_logits():
     v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
     logits = torch.empty(32, 128, 1023, dtype=torch.float32, device="cuda")
     sm_scale = 1.0 / math.sqrt(128)
-    o = f(q, k, v, logits, sm_scale)
+    o = f(q, k, v, logits, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)
     p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
     o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
 
@@ -300,7 +298,6 @@ def test_debug_print_logits():
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
         128,  # hidden_dim
-        MaskMode.NON_CAUSAL.value,
         [],  # additional_input_tensor_var_names
         [],  # additional_input_tensor_var_types
         ["sm_scale"],  # additional_input_scalar_var_names
@@ -316,7 +313,7 @@ def test_debug_print_logits():
     k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
     v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
     sm_scale = 1.0 / math.sqrt(128)
-    o = f(q, k, v, sm_scale)
+    o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)
    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
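The diff makes the mask mode (non-causal, causal, or custom) a runtime argument to the kernels' `run`/`paged_run`/`ragged_run` entry points, dispatched in C++ via `DISPATCH_MASK_MODE`, instead of a compile-time component of the JIT module URI. A minimal sketch of the resulting user-visible behaviour, assuming the public `flashinfer.single_prefill_with_kv_cache` API and illustrative float16 NHD-layout tensors (shapes chosen purely for illustration):

```python
import torch
import flashinfer

# Illustrative shapes: qo_len=128, kv_len=1024, 8 heads, head_dim=128.
q = torch.randn(128, 8, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")

# With the mask mode resolved at run time, both calls can share a single
# compiled prefill module; only the mask-mode argument forwarded to
# .run(...) differs between them.
o_non_causal = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False)
o_causal = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)
```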