This PR implements JIT compilation (#170) for flashinfer. After this PR, flashinfer compiles kernels just-in-time for different input data types and shapes and caches the compiled kernels on disk, instead of pre-compiling a fixed set of kernels into the wheel.

# Motivation

The pip wheel size is exploding as we add support for more data types, more head dimensions, more attention variants, and more kernel implementations. Pre-compiling everything is not sustainable and impedes development speed. This PR refactors the codebase to use torch's [JIT Compiling Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions) feature instead of pre-compiling kernels into the wheel (a minimal sketch of this mechanism is shown after the roadmap below).

## Attention Variants

We learned from [FlexAttention](https://pytorch.org/blog/flexattention/) and describe every attention variant as a template class; each instance of the struct can carry closure variables defined in local memory or shared memory. Below are two examples (logits soft cap and ALiBi attention; the programming interface is tentative and will be updated as we improve the programmability of the JIT template):

```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx,
                                    uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx,
                                     uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};
```

Users can define their own `ParamsT` class and variant class to describe custom attention variants. We hope this refactor makes the codebase more concise and extensible.

# Roadmap

After this PR, we will add support for:

1. PyPI wheels #153
2. fp8 tensor-core attention: #502
3. different head dimensions: #142 #454 #455
4. flashattention3 #369
5. multi-head latent attention #237
6. generating `ParamsT` and attention variant descriptions from a Python DSL

The development of these features has been blocked by the wheel size limit (binary sizes >= 2GB trigger linking issues); I hope this PR will make development easier in the future.
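As a rough illustration of the torch mechanism referenced above (not flashinfer's actual build code): `torch.utils.cpp_extension.load` compiles the given C++/CUDA sources on first use and reuses the cached build on later calls. The module name, source path, flags, and the commented-out `plan`/`run` calls below are hypothetical placeholders.

```python
# Minimal sketch of torch's JIT extension workflow, under assumed file/function names.
# load() builds the listed sources with Ninja on the first call and caches the resulting
# shared library (by default under ~/.cache/torch_extensions), so subsequent calls with
# the same name and unchanged sources skip recompilation.
from torch.utils.cpp_extension import load

batch_decode = load(
    name="batch_decode_f16_hd128",        # cache key for this (dtype, head_dim) specialization
    sources=["csrc/batch_decode.cu"],     # hypothetical source path
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=True,
)

# The returned module exposes whatever functions the extension binds, e.g. a
# plan/run pair; these names are placeholders for illustration only.
# plan_info = batch_decode.plan(...)
# out = batch_decode.run(plan_info, q, kv_cache, ...)
```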
Showing 137 changed files with 6,986 additions and 6,122 deletions.
@@ -0,0 +1 @@
../3rdparty
@@ -0,0 +1,12 @@
# sdist & wheel
include version.txt
recursive-include include *
recursive-include csrc *
recursive-include 3rdparty/cutlass *

# wheel-only
exclude flashinfer/_build_meta.py

# Unneeded files
prune */__pycache__
global-exclude *.so
@@ -0,0 +1 @@
../python/csrc
@@ -0,0 +1,205 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

#include <flashinfer/attention/decode_params.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/variants.cuh>
#include <optional>

#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
                                                  typename AttentionVariant::DTypeO* tmp_v,
                                                  float* tmp_s, cudaStream_t stream);

}  // namespace flashinfer

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data,
    torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer,
    torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer,
    torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
  auto device = float_workspace_buffer.device();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  indptr = indptr.to(torch::kCPU);

  DecodePlanInfo plan_info;

  using IdType = int32_t;
  // check indptr has idtype int32
  TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32");
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  auto q_scalar_type = empty_q_data.scalar_type();
  auto kv_scalar_type = empty_kv_data.scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        cudaError_t status = DecodePlan<HEAD_DIM, POS_ENCODING_MODE, AttentionVariant>(
            static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
            static_cast<void*>(int_workspace_buffer.data_ptr()),
            static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
            int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
            batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph,
            /*stream=*/torch_current_stream);

        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, torch::Tensor q,
    std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
    std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
    std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);
  QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
  bool paged_kv_defined = paged_kv_cache.has_value();
  auto device = q.device();
  int64_t batch_size = q.size(0);
  int64_t num_qo_heads = q.size(1);
  int64_t num_kv_heads, page_size;
  if (paged_kv_defined) {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_kv_cache->size(2);
      page_size = paged_kv_cache->size(3);
    } else {
      page_size = paged_kv_cache->size(2);
      num_kv_heads = paged_kv_cache->size(3);
    }
  } else {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_k_cache->size(1);
      page_size = paged_k_cache->size(2);
    } else {
      page_size = paged_k_cache->size(1);
      num_kv_heads = paged_k_cache->size(2);
    }
  }
  uint32_t head_dim = q.size(2);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  torch::Tensor o = torch::empty_like(q);
  torch::Tensor lse;
  if (return_lse) {
    lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
  }

  TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

  void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
  void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

  using IdType = int32_t;
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  // get q_scalar_type and kv_scalar_type
  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type =
      paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        paged_kv_t<DTypeKV, IdType> paged_kv(
            num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
            static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
                                                             : nullptr),
            static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
            static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
            static_cast<IdType*>(paged_kv_indices.data_ptr()),
            static_cast<IdType*>(paged_kv_indptr.data_ptr()),
            static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
        ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                       /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                       /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
                       /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
                       sm_scale, rope_scale, rope_theta);

        DTypeO* tmp_v = nullptr;
        float* tmp_s = nullptr;
        params.request_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
        params.kv_tile_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
        params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
        params.kv_chunk_size_ptr =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
        if (plan_info.split_kv) {
          tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
          tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
          if (plan_info.enable_cuda_graph) {
            params.block_valid_mask =
                GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
          }
        }
        params.padded_batch_size = plan_info.padded_batch_size;

        cudaError_t status =
            flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE,
                                                              AttentionVariant>(
                params, tmp_v, tmp_s, /*stream=*/torch_current_stream);
        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  if (return_lse) {
    return {o, lse};
  } else {
    return {o};
  }
}