feat: support cached cos/sin in rope APIs (#585)
As requested in #530, this PR implements RoPE with cached cos/sin
embeddings, which is more flexible for some use cases.

In our previous RoPE implementations, cos/sin values were computed
on-the-fly inside the kernels in float32 rather than read from cached values.

While developing this PR we found that an fp16 cos/sin cache yields results
with a large discrepancy from our original implementation
`flashinfer.apply_rope` (which computes cos/sin in fp32), so we require
`cos_cache` and `sin_cache` to use the fp32 data type.

cc @dreaming-panda @ByronHsu
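
To illustrate the new path, here is a minimal usage sketch (not part of the diff). The Python wrapper's exact signature is not shown in this commit, so the call below assumes it mirrors the C++ entry point `apply_rope_pos_ids_cos_sin_cache` (q, k, cos_cache, sin_cache, pos_ids, interleave) and returns the rotated tensors; the cache construction follows the usual LLaMA rotate-half convention, with fp32 values of shape (max_seq_len, head_dim).

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
nnz, max_seq_len, rope_theta = 16, 4096, 1e4

# fp32 cos/sin caches of shape (max_seq_len, head_dim); fp16 caches are rejected.
inv_freq = 1.0 / (rope_theta ** (
    torch.arange(0, head_dim, 2, dtype=torch.float32, device="cuda") / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32, device="cuda")
freqs = torch.outer(t, inv_freq)         # (max_seq_len, head_dim // 2)
emb = torch.cat([freqs, freqs], dim=-1)  # duplicate halves (rotate-half layout)
cos_cache, sin_cache = emb.cos(), emb.sin()

q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda")

# Hypothetical call: the wrapper's signature is an assumption, not shown in this diff.
q_rope, k_rope = flashinfer.apply_rope_with_cos_sin_cache(
    q, k, cos_cache, sin_cache, pos_ids, interleave=False
)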
yzh119 authored Nov 5, 2024
1 parent 7557dc8 commit 83e541d
Showing 8 changed files with 465 additions and 208 deletions.
3 changes: 2 additions & 1 deletion benchmarks/bench_append_paged_kv_cache.py
@@ -2,10 +2,11 @@
import dataclasses
from typing import cast

import flashinfer
import torch
from triton.testing import do_bench

import flashinfer


@dataclasses.dataclass(kw_only=True)
class ModelConfig:
92 changes: 92 additions & 0 deletions include/flashinfer/pos_enc.cuh
@@ -18,6 +18,7 @@

#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>

#include "layout.cuh"
@@ -156,6 +157,55 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_i
return vec;
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_cache,
float* __restrict__ sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h,
size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
size_t k_rope_stride_n, size_t k_rope_stride_h) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
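// Each (bx, ty) pair processes one token; each tx lane covers a vec_size-wide
// slice of head_dim. Cached cos/sin for the token's position are loaded once
// and reused across all query and key heads.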

vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];

cos.load(cos_cache + pos * head_dim + tx * vec_size);
sin.load(sin_cache + pos * head_dim + tx * vec_size);

#pragma unroll 1
for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
DType* q_rope_ptr =
q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
vec_t<float, vec_size> q_vec;
if constexpr (interleave) {
q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin);
} else {
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
}

#pragma unroll 1
for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
DType* k_rope_ptr =
k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
vec_t<float, vec_size> k_vec;
if constexpr (interleave) {
k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin);
} else {
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin);
}
k_vec.cast_store(k_rope_ptr + tx * vec_size);
}
}
}
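
For intuition, a small torch sketch (illustrative, not part of the diff) of the per-head rotation the non-interleaved path applies, assuming the usual LLaMA rotate-half convention with math done in fp32 as in the kernel:

import torch

def rope_rotate_half_ref(x, cos, sin):
    # x, cos, sin: (head_dim,) vectors; compute in fp32, as the kernel does
    x = x.float()
    d = x.shape[-1] // 2
    rot = torch.cat([-x[d:], x[:d]])  # rotate_half(x)
    return x * cos + rot * sin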

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsKernel(
@@ -309,6 +359,48 @@ __global__ void BatchQKApplyRotaryKernel(
__VA_ARGS__ \
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_cache, float* sin_cache,
IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
bool interleave, cudaStream_t stream = nullptr) {
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks((nnz + bdy - 1) / bdy);
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&q_rope,
(void*)&k_rope,
(void*)&cos_cache,
(void*)&sin_cache,
(void*)&pos_ids,
(void*)&nnz,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&q_rope_stride_n,
(void*)&q_rope_stride_h,
(void*)&k_rope_stride_n,
(void*)&k_rope_stride_h};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}
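
For example, with fp16 inputs and head_dim = 128, this dispatch picks vec_size = max(16/2, 128/32) = 8 and bdx = 128/8 = 16 lanes per token, so each 128-thread block covers bdy = 8 tokens and the grid has ceil(nnz/8) blocks.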

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope,
IdType* __restrict__ pos_ids, uint32_t nnz,
7 changes: 7 additions & 0 deletions python/csrc/flashinfer_rope_ops.cu
@@ -35,10 +35,17 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
float rope_scale, float rope_theta, float low_freq_factor,
float high_freq_factor, float old_context_length);

void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
torch::Tensor k_rope, torch::Tensor cos_cache,
torch::Tensor sin_cache, torch::Tensor pos_ids,
bool interleave);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_rope", &apply_rope, "Apply RoPE");
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
"Apply Llama 3.1 style RoPE with positional ids");
m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
"Apply RoPE with positional ids and cosine/sine cache");
}
62 changes: 58 additions & 4 deletions python/csrc/rope.cu
@@ -22,8 +22,8 @@ using namespace flashinfer;
void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_LAST_DIM_CONTIGUOUS(q);
CHECK_LAST_DIM_CONTIGUOUS(k);
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

@@ -69,8 +69,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
float rope_scale, float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_LAST_DIM_CONTIGUOUS(q);
CHECK_LAST_DIM_CONTIGUOUS(k);
CHECK_INPUT(pos_ids);

auto device = q.device();
@@ -107,6 +107,60 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
});
}

void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
torch::Tensor k_rope, torch::Tensor cos_cache,
torch::Tensor sin_cache, torch::Tensor pos_ids,
bool interleave) {
CHECK_LAST_DIM_CONTIGUOUS(q);
CHECK_LAST_DIM_CONTIGUOUS(k);
CHECK_INPUT(cos_cache);
CHECK_INPUT(sin_cache);
CHECK_INPUT(pos_ids);
auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_EQ(cos_cache.device(), device);
CHECK_EQ(sin_cache.device(), device);
CHECK_EQ(pos_ids.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(2, cos_cache); // cos_cache: (max_seq_len, D)
CHECK_DIM(2, sin_cache); // sin_cache: (max_seq_len, D)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
CHECK_EQ(cos_cache.size(1), q.size(2));
CHECK_EQ(sin_cache.size(1), q.size(2));
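// fp32 caches are required: fp16 cos/sin caches were found to produce large
// discrepancies vs. flashinfer.apply_rope (see the commit message).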
CHECK_EQ(cos_cache.dtype(), torch::kFloat32);
CHECK_EQ(sin_cache.dtype(), torch::kFloat32);
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int nnz = q.size(0);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
size_t q_rope_stride_n = q_rope.stride(0);
size_t q_rope_stride_h = q_rope.stride(1);
size_t k_rope_stride_n = k_rope.stride(0);
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
static_cast<float*>(cos_cache.data_ptr()), static_cast<float*>(sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim,
q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h,
k_rope_stride_n, k_rope_stride_h, interleave, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}

void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
8 changes: 8 additions & 0 deletions python/flashinfer/__init__.py
@@ -60,10 +60,18 @@
from .quantization import segment_packbits as segment_packbits
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_probs as sampling_from_probs
