diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh
index ed0b732a..c52795f8 100644
--- a/include/flashinfer/pos_enc.cuh
+++ b/include/flashinfer/pos_enc.cuh
@@ -164,6 +164,7 @@ __global__ void BatchQKApplyRotaryPosIdsKernel(
     size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
     size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, float smooth_b,
     float rope_rcp_scale, float rope_rcp_theta) {
+  // NOTE: q and q_rope may be the same ptr, so do k and k_rope
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
   const uint32_t bdy = blockDim.y;
   vec_t freq;
@@ -410,6 +411,30 @@ cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope,
   return cudaSuccess;
 }
 
+// In-place convenience wrapper around BatchQKApplyRotary.
+//
+// Forwards every argument unchanged, passing q and k a second time in the q_rope / k_rope
+// positions and repeating the q/k strides for the second group of stride arguments.
+// NOTE(review): presumably this makes the rotated values overwrite q and k; that relies on
+// the launched kernel tolerating q == q_rope and k == k_rope -- confirm against the kernel.
+//
+// Returns whatever cudaError_t BatchQKApplyRotary returns.
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k,
+                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                      uint32_t batch_size, uint32_t num_qo_heads,
+                                      uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
+                                      size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                      bool interleave, float rope_scale, float rope_theta,
+                                      cudaStream_t stream = nullptr) {
+  // q and k are intentionally passed twice: once as q/k and once as q_rope/k_rope.
+  return BatchQKApplyRotary(
+      /*q=*/q, /*k=*/k, /*q_rope=*/q, /*k_rope=*/k, indptr, offsets, batch_size, num_qo_heads,
+      num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h,
+      // Same strides again for the aliased q_rope/k_rope arguments.
+      q_stride_n, q_stride_h, k_stride_n, k_stride_h, interleave, rope_scale, rope_theta, stream);
+}
+
 template
 cudaError_t BatchQKApplyLlama31Rotary(
     DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr,