diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu
index 9ab9a86c..5cf365ca 100644
--- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu
+++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -83,29 +83,23 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
-
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
+
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
+
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
+
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);
 
 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
 
diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu
index ef046ead..c6259968 100644
--- a/python/csrc/flashinfer_rope_ops.cu
+++ b/python/csrc/flashinfer_rope_ops.cu
@@ -17,29 +17,23 @@
 
 #include <torch/extension.h>
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
 
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
 
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
 
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_rope", &apply_rope, "Apply RoPE");
diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu
index d2ca9155..8f661da0 100644
--- a/python/csrc/rope.cu
+++ b/python/csrc/rope.cu
@@ -19,10 +19,9 @@
 
 using namespace flashinfer;
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta) {
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -65,14 +64,11 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::T
                                            std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta) {
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -109,16 +105,12 @@ std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                                            std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length) {
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -162,16 +154,12 @@ std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
                                            std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length) {
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -209,6 +197,4 @@ std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Te
                                            std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py
index 29c2fcb7..60ca2c6e 100644
--- a/python/flashinfer/rope.py
+++ b/python/flashinfer/rope.py
@@ -42,7 +42,165 @@ def get_rope_module():
     return _rope_module
 
 
-@register_custom_op("flashinfer::apply_rope_inplace", mutates_args=("q", "k"))
+@register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope"))
+def _apply_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+) -> None:
+    get_rope_module().apply_rope(
+        q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta
+    )
+
+
+@register_fake_op("flashinfer::apply_rope")
+def _fake_apply_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+) -> None:
+    pass
+
+
+@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=("q_rope", "k_rope"))
+def _apply_llama31_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    old_context_len: float,
+) -> None:
+    get_rope_module().apply_llama31_rope(
+        q,
+        k,
+        q_rope,
+        k_rope,
+        indptr,
+        offsets,
+        interleave,
+        rope_scale,
+        rope_theta,
+        low_freq_factor,
+        high_freq_factor,
+        old_context_len,
+    )
+
+
+@register_fake_op("flashinfer::apply_llama31_rope")
+def _fake_apply_llama31_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    old_context_len: float,
+) -> None:
+    pass
+
+
+@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=("q_rope", "k_rope"))
+def _apply_rope_pos_ids(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    pos_ids: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+) -> None:
+    get_rope_module().apply_rope_pos_ids(
+        q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta
+    )
+
+
+@register_fake_op("flashinfer::apply_rope_pos_ids")
+def _fake_apply_rope_pos_ids(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    pos_ids: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+) -> None:
+    pass
+
+
+@register_custom_op(
+    "flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope")
+)
+def _apply_llama31_rope_pos_ids(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    pos_ids: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    old_context_len: float,
+) -> None:
+    get_rope_module().apply_llama31_rope_pos_ids(
+        q,
+        k,
+        q_rope,
+        k_rope,
+        pos_ids,
+        interleave,
+        rope_scale,
+        rope_theta,
+        low_freq_factor,
+        high_freq_factor,
+        old_context_len,
+    )
+
+
+@register_fake_op("flashinfer::apply_llama31_rope_pos_ids")
+def _fake_apply_llama31_rope_pos_ids(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    pos_ids: torch.Tensor,
+    interleave: bool,
+    rope_scale: float,
+    rope_theta: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    old_context_len: float,
+) -> None:
+    pass
+
+
 def apply_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -118,25 +276,9 @@
     --------
     apply_rope
     """
-    get_rope_module().apply_rope(
-        q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta
-    )
-
-
-@register_fake_op("flashinfer::apply_rope_inplace")
-def _fake_apply_rope_inplace(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    indptr: torch.Tensor,
-    offsets: torch.Tensor,
-    interleave: bool = False,
-    rope_scale: float = 1,
-    rope_theta: float = 1e4,
-) -> None:
-    pass
+    _apply_rope(q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta)
 
 
-@register_custom_op("flashinfer::apply_rope_pos_ids_inplace", mutates_args=("q", "k"))
 def apply_rope_pos_ids_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -183,24 +325,9 @@
     --------
     apply_rope_pos_ids
     """
-    get_rope_module().apply_rope_pos_ids(
-        q, k, q, k, pos_ids, interleave, rope_scale, rope_theta
-    )
+    _apply_rope_pos_ids(q, k, q, k, pos_ids, interleave, rope_scale, rope_theta)
 
 
-@register_fake_op("flashinfer::apply_rope_pos_ids_inplace")
-def _fake_apply_rope_pos_ids_inplace(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    pos_ids: torch.Tensor,
-    interleave: bool = False,
-    rope_scale: float = 1,
-    rope_theta: float = 1e4,
-) -> None:
-    pass
-
-
-@register_custom_op("flashinfer::apply_llama31_rope_inplace", mutates_args=("q", "k"))
 def apply_llama31_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -286,7 +413,7 @@
     --------
     apply_llama31_rope
     """
-    get_rope_module().apply_llama31_rope(
+    _apply_llama31_rope(
         q,
         k,
         q,
@@ -302,12 +429,10 @@
     )
 
 
-@register_fake_op("flashinfer::apply_llama31_rope_inplace")
-def _fake_apply_llama31_rope_inplace(
+def apply_llama31_rope_pos_ids_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
-    indptr: torch.Tensor,
-    offsets: torch.Tensor,
+    pos_ids: torch.Tensor,
     interleave: bool = True,
     rope_scale: float = 8,
     rope_theta: float = 5e5,
@@ -315,10 +440,66 @@
     high_freq_factor: float = 4,
     old_context_len: int = 8192,
 ) -> None:
-    pass
+    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
+    RaggedTensor) inplace.
+
+    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the
+    query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th
+    segment is ``k[indptr[i]:indptr[i+1]]``; the first element of :attr:`indptr` is always
+    0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch.
+    Please see :ref:`Ragged Tensor tutorial <ragged-layout>` for more details about the
+    ragged tensor.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    k : torch.Tensor
+        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    pos_ids : torch.Tensor
+        Position indices, shape: ``(nnz)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``True``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+
+    rope_scale : float
+        The scaling factor used in the rope embedding, default: ``8``.
+    rope_theta : float
+        The theta value used in the rope embedding, default: ``5e5``.
+    low_freq_factor : float
+        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
+    high_freq_factor : float
+        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
+    old_context_len : int
+        The old context length used in Llama 3.1 RoPE, default: ``8192``.
+
+    See Also
+    --------
+    apply_llama31_rope_pos_ids
+    """
+    _apply_llama31_rope_pos_ids(
+        q,
+        k,
+        q,
+        k,
+        pos_ids,
+        interleave,
+        rope_scale,
+        rope_theta,
+        low_freq_factor,
+        high_freq_factor,
+        float(old_context_len),
+    )
 
 
-@register_custom_op("flashinfer::apply_rope", mutates_args=())
 def apply_rope(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -407,25 +588,12 @@
     """
     q_rope = torch.empty_like(q)
     k_rope = torch.empty_like(k)
-    return get_rope_module().apply_rope(
+    _apply_rope(
         q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta
     )
+    return q_rope, k_rope
 
 
-@register_fake_op("flashinfer::apply_rope")
-def _fake_apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    indptr: torch.Tensor,
-    offsets: torch.Tensor,
-    interleave: bool = False,
-    rope_scale: float = 1,
-    rope_theta: float = 1e4,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    return torch.empty_like(q), torch.empty_like(k)
-
-
-@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=())
 def apply_rope_pos_ids(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -481,24 +649,12 @@
     """
     q_rope = torch.empty_like(q)
     k_rope = torch.empty_like(k)
-    return get_rope_module().apply_rope_pos_ids(
+    _apply_rope_pos_ids(
         q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta
     )
+    return q_rope, k_rope
 
 
-@register_fake_op("flashinfer::apply_rope_pos_ids")
-def _fake_apply_rope_pos_ids(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    pos_ids: torch.Tensor,
-    interleave: bool = False,
-    rope_scale: float = 1,
-    rope_theta: float = 1e4,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    return torch.empty_like(q), torch.empty_like(k)
-
-
-@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=())
 def apply_llama31_rope(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -597,7 +753,7 @@
     """
     q_rope = torch.empty_like(q)
     k_rope = torch.empty_like(k)
-    return get_rope_module().apply_llama31_rope(
+    _apply_llama31_rope(
         q,
         k,
         q_rope,
@@ -611,14 +767,13 @@
         high_freq_factor,
         float(old_context_len),
     )
+    return q_rope, k_rope
 
 
-@register_fake_op("flashinfer::apply_llama31_rope")
-def _fake_apply_llama31_rope(
+def apply_llama31_rope_pos_ids(
     q: torch.Tensor,
     k: torch.Tensor,
-    indptr: torch.Tensor,
-    offsets: torch.Tensor,
+    pos_ids: torch.Tensor,
     interleave: bool = True,
     rope_scale: float = 8,
     rope_theta: float = 5e5,
@@ -626,4 +781,70 @@
     high_freq_factor: float = 4,
     old_context_len: int = 8192,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    return torch.empty_like(q), torch.empty_like(k)
+    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
+    RaggedTensor).
+
+    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the
+    query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th
+    segment is ``k[indptr[i]:indptr[i+1]]``; the first element of :attr:`indptr` is always
+    0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch.
+    Please see :ref:`Ragged Tensor tutorial <ragged-layout>` for more details about the
+    ragged tensor.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    k : torch.Tensor
+        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    pos_ids : torch.Tensor
+        Position indices, shape: ``(nnz)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``True``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+    rope_scale : float
+        The scaling factor used in the rope embedding, default: ``8``.
+    rope_theta : float
+        The theta value used in the rope embedding, default: ``5e5``.
+    low_freq_factor : float
+        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
+    high_freq_factor : float
+        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
+    old_context_len : int
+        The old context length used in Llama 3.1 RoPE, default: ``8192``.
+
+    Returns
+    -------
+    q_rope : torch.Tensor
+        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
+    k_rope : torch.Tensor
+        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
+
+    See Also
+    --------
+    apply_llama31_rope_pos_ids_inplace
+    """
+    q_rope = torch.empty_like(q)
+    k_rope = torch.empty_like(k)
+    _apply_llama31_rope_pos_ids(
+        q,
+        k,
+        q_rope,
+        k_rope,
+        pos_ids,
+        interleave,
+        rope_scale,
+        rope_theta,
+        low_freq_factor,
+        high_freq_factor,
+        float(old_context_len),
+    )
+    return q_rope, k_rope
diff --git a/tests/conftest.py b/tests/conftest.py
index 95738ddb..08697065 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -32,8 +32,12 @@
     flashinfer.quantization.packbits,
     flashinfer.rope.apply_rope,
     flashinfer.rope.apply_rope_inplace,
+    flashinfer.rope.apply_rope_pos_ids,
+    flashinfer.rope.apply_rope_pos_ids_inplace,
     flashinfer.rope.apply_llama31_rope,
     flashinfer.rope.apply_llama31_rope_inplace,
+    flashinfer.rope.apply_llama31_rope_pos_ids,
+    flashinfer.rope.apply_llama31_rope_pos_ids_inplace,
     flashinfer.sampling.sampling_from_probs,
     flashinfer.sampling.top_p_sampling_from_probs,
     flashinfer.sampling.top_k_sampling_from_probs,
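
Usage note: the sketch below shows how the pos-id entry points touched by this patch are called from Python once the extension is built (apply_rope_pos_ids now allocates and returns q_rope/k_rope in the wrapper, and apply_llama31_rope_pos_ids / apply_llama31_rope_pos_ids_inplace are new public functions). The shapes, head counts, dtype, and prompt lengths are illustrative assumptions rather than anything mandated by the patch, and it presumes a CUDA device and a flashinfer build that contains these changes.

    # Illustrative only: two prompts of 8 tokens packed into one ragged batch
    # (nnz = 16); any supported head_dim and floating-point dtype works the same way.
    import torch
    import flashinfer

    nnz, num_qo_heads, num_kv_heads, head_dim = 16, 8, 4, 64
    q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
    # Per-token position ids for the ragged batch, shape (nnz,).
    pos_ids = torch.cat([torch.arange(8), torch.arange(8)]).int().to("cuda")

    # Out-of-place: q_rope/k_rope are freshly allocated inside the wrapper and
    # returned; q and k are left untouched.
    q_rope, k_rope = flashinfer.rope.apply_rope_pos_ids(q, k, pos_ids, interleave=False)

    # In-place Llama 3.1 variant added by this patch: rotates q/k directly and
    # returns None.
    flashinfer.rope.apply_llama31_rope_pos_ids_inplace(q, k, pos_ids)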