bugfix: fix the stride bug in page append #527

Merged (1 commit) on Oct 11, 2024
include/flashinfer/page.cuh (11 changes: 5 additions & 6 deletions)
@@ -80,8 +80,8 @@ struct paged_kv_t {
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param k_data The flattened key cache
* \param v_data The flattened value cache
* \param k_data The start pointer of the key cache; k_cache should be contiguous
* \param v_data The start pointer of the value cache; v_cache should be contiguous
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
@@ -107,20 +107,19 @@ struct paged_kv_t {
}

/*!
* \brief Construct a paged key-value cache
* \brief Construct a paged key-value cache with custom kv-cache strides
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param k_data The flattened key cache
* \param v_data The flattened value cache
* \param k_data The start pointer of the key cache; k_cache does not have to be contiguous
* \param v_data The start pointer of the value cache; v_cache does not have to be contiguous
* \param kv_strides The custom strides of each dimension of k_data and v_data
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType* k_data,
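The docstring change above reflects the underlying fix: the stride-aware constructor lets the kernel compute element offsets from the tensor's actual strides instead of re-deriving them from its shape, which is only valid for a contiguous cache. The helpers below are a minimal sketch of that distinction, not flashinfer's implementation; the function names, the [num_pages, num_heads, page_size, head_dim] index order, and the assumption that the last dimension is contiguous are all illustrative.

#include <cstddef>
#include <cstdint>

// Shape-derived offset: correct only for a contiguous
// [num_pages, num_heads, page_size, head_dim] cache (the old assumption).
inline size_t elem_offset_contiguous(size_t page, size_t head, size_t entry, size_t feat,
                                     size_t num_heads, size_t page_size, size_t head_dim) {
  return ((page * num_heads + head) * page_size + entry) * head_dim + feat;
}

// Stride-based offset: uses the strides reported by the tensor itself
// (in elements), so non-contiguous views are addressed correctly.
// The innermost dimension is assumed contiguous here.
inline size_t elem_offset_strided(size_t page, size_t head, size_t entry, size_t feat,
                                  const int64_t* kv_strides) {
  return page * kv_strides[0] + head * kv_strides[1] + entry * kv_strides[2] + feat;
}

For a contiguous cache the two functions agree; for a view carved out of a larger buffer only the strided version lands on the right elements.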
python/csrc/page.cu (19 changes: 13 additions & 6 deletions)
@@ -69,6 +69,13 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
num_heads = paged_k_cache.size(2);
}

// get the strides of the (possibly non-contiguous) paged KV cache
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();

CHECK_EQ(append_key.size(1), num_heads);
CHECK_EQ(append_key.size(2), head_dim);
CHECK_EQ(append_value.size(1), num_heads);
@@ -79,12 +86,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
auto kv_scalar_dtype = paged_k_cache.scalar_type();

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] {
paged_kv_t<c_type, int32_t> paged_kv(num_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_type*>(paged_k_cache.data_ptr()),
static_cast<c_type*>(paged_v_cache.data_ptr()),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
static_cast<int32_t*>(kv_last_page_len.data_ptr()));
paged_kv_t<c_type, int32_t> paged_kv(
num_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_type*>(paged_k_cache.data_ptr()),
static_cast<c_type*>(paged_v_cache.data_ptr()), kv_cache_strides,
static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
static_cast<int32_t*>(kv_last_page_len.data_ptr()));
cudaError_t status =
AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
static_cast<c_type*>(append_value.data_ptr()),
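The binding change above is the actual bugfix: append_paged_kv_cache now forwards paged_k_cache.strides() to the stride-aware paged_kv_t constructor instead of using the constructor that assumes a contiguous cache. A hedged illustration of the case this handles, written against plain libtorch rather than flashinfer's own tests, is sketched below; the combined-buffer layout is an assumption chosen to produce a non-contiguous cache.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Combined buffer holding K and V together: [num_pages, 2, num_heads, page_size, head_dim].
  auto kv = torch::zeros({8, 2, 4, 16, 128}, torch::kFloat16);
  auto paged_k_cache = kv.select(/*dim=*/1, /*index=*/0);  // view, no copy
  auto paged_v_cache = kv.select(/*dim=*/1, /*index=*/1);

  // sizes() is [8, 4, 16, 128] but strides() is [16384, 2048, 128, 1];
  // a contiguous tensor of that shape would have strides [8192, 2048, 128, 1],
  // so recomputing strides from the shape would index the wrong elements.
  std::cout << paged_k_cache.sizes() << " " << paged_k_cache.strides() << "\n";
  std::cout << std::boolalpha << paged_k_cache.is_contiguous() << "\n";  // false
  return 0;
}

Both views share the same strides, so they also satisfy the new TORCH_CHECK that the k and v strides must be identical.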