Support vLLM-style rope #530
Comments
Hi @ByronHsu, thanks for your suggestions, I think 1 & 3 are easy to support. For 1, we can adding For 3, yes we can add another rope_dim field for partial rope. Can you give a concrete example of 2?
Okay I think I understand 2 now, for example, if Is that true?
Yes exactly! Thank you for the prompt response! All sounds good to me. One comment: can we separate the new API from flashinfer's four current rope functions and provide the exact same interface as vLLM? Several reasons:
Maybe we can call this
```python
def apply_rope_inplace_with_cache(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    ...
```
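To make the proposed semantics concrete, here is a minimal pure-Python reference sketch, not flashinfer's or vLLM's actual kernel. It assumes that each row of `cos_sin_cache` stores the cos values in its first half and the sin values in its second half, that the rotary dimension is taken from the cache width (so it may be smaller than `head_size`), and that `is_neox` selects between pairing element `i` with `i + half` and pairing adjacent elements. The helper names `build_cos_sin_cache` and `apply_rope_reference` are hypothetical and exist only for illustration.

```python
import math


def build_cos_sin_cache(max_pos, rotary_dim, base=10000.0):
    """Hypothetical helper: row p is [cos(p*f_0..), sin(p*f_0..)] (assumed layout)."""
    half = rotary_dim // 2
    inv_freq = [base ** (-2.0 * i / rotary_dim) for i in range(half)]
    return [
        [math.cos(p * f) for f in inv_freq] + [math.sin(p * f) for f in inv_freq]
        for p in range(max_pos)
    ]


def apply_rope_reference(positions, x, head_size, cos_sin_cache, is_neox):
    """Rotate x in place.

    x is a list of tokens; each token is a flat list of num_heads * head_size floats.
    positions[t] is the (arbitrary, possibly non-contiguous) position of token t.
    """
    rotary_dim = len(cos_sin_cache[0])  # assumption: cache width sets the rotary dim
    half = rotary_dim // 2
    for pos, token in zip(positions, x):
        row = cos_sin_cache[pos]
        for start in range(0, len(token), head_size):
            for i in range(half):
                # NeoX style pairs (i, i + half); interleaved style pairs (2i, 2i + 1).
                a, b = (i, i + half) if is_neox else (2 * i, 2 * i + 1)
                c, s = row[i], row[half + i]
                xa, xb = token[start + a], token[start + b]
                token[start + a] = xa * c - xb * s
                token[start + b] = xb * c + xa * s
            # Elements beyond rotary_dim within each head are left untouched.
```

In this sketch the same function would be called once for `query` and once for `key`, since both share `positions` and the cache.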
I did a global search and found
Previously, our rope APIs assumed that the position indices of each request are contiguous, which is not appropriate for applications such as speculative decoding. This PR fixes the issue by supporting the Hugging Face transformers-style API, which uses a `pos_ids` argument to specify positions. This PR implements part of the features requested in #530; other requests are coming in later PRs. cc @dreaming-panda @abcdabcd987 @ByronHsu
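The difference between the two position formats can be sketched as follows. This is an illustration under assumptions: `indptr` is taken to be a CSR-style array of token boundaries per request and `offsets` the starting position of each request; both function names are hypothetical and not part of flashinfer.

```python
def pos_ids_from_indptr_offsets(indptr, offsets):
    """Expand per-request (indptr, offsets) into one position per token.

    Request i owns tokens indptr[i]:indptr[i + 1] and is assumed to start
    at position offsets[i], with positions increasing by one per token.
    """
    pos_ids = []
    for i, offset in enumerate(offsets):
        length = indptr[i + 1] - indptr[i]
        pos_ids.extend(range(offset, offset + length))
    return pos_ids


def is_contiguous_per_request(pos_ids, indptr):
    """Return True if every request's positions increase by exactly one."""
    for i in range(len(indptr) - 1):
        segment = pos_ids[indptr[i]:indptr[i + 1]]
        if any(b - a != 1 for a, b in zip(segment, segment[1:])):
            return False
    return True


# Two requests: three tokens starting at position 10, two tokens starting at 0.
contiguous = pos_ids_from_indptr_offsets([0, 3, 5], [10, 0])

# A draft tree in speculative decoding can place sibling tokens at the same
# position, which no (indptr, offsets) pair can describe.
tree_positions = [7, 8, 8, 9, 9]
```

Here `contiguous` is `[10, 11, 12, 0, 1]`, while `tree_positions` fails the contiguity check, which is the case an explicit `pos_ids` argument covers.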
As requested in #530, this PR implements RoPE with cached cos/sin embeddings, which is more flexible in some use cases. In our previous RoPE implementations, cos/sin values are computed on the fly inside kernels in float32 instead of being read from a cache. In this PR we found that if we use an f16 cos/sin cache, the RoPE result has a large discrepancy compared to our original implementation `flashinfer.apply_rope` (which stores cos/sin in fp32). So we require `cos_cache` and `sin_cache` to use the fp32 data type. cc @dreaming-panda @ByronHsu
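A rough way to see why the cache precision matters is to measure how much rounding alone perturbs the stored cos/sin values. The sketch below uses only the standard library and measures storage rounding error, not the end-to-end discrepancy observed in the PR; the base of 10000, the rotary dimension of 128, and the sampled positions are assumptions chosen for illustration.

```python
import math
import struct


def round_to(fmt, value):
    """Round a Python float to IEEE half ('e') or single ('f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def max_cache_error(fmt, max_pos=4096, rotary_dim=128, base=10000.0):
    """Largest absolute rounding error over a sample of cached cos/sin entries."""
    worst = 0.0
    for i in range(rotary_dim // 2):
        freq = base ** (-2.0 * i / rotary_dim)
        for pos in range(0, max_pos, 37):  # sample positions to keep this fast
            for fn in (math.cos, math.sin):
                exact = fn(pos * freq)
                worst = max(worst, abs(round_to(fmt, exact) - exact))
    return worst


err_f16 = max_cache_error("e")
err_f32 = max_cache_error("f")
print(f"max rounding error: f16={err_f16:.2e}, f32={err_f32:.2e}")
```

The half-precision error is several orders of magnitude larger than the single-precision one, and it is applied to every rotated element, which is consistent with the decision to require fp32 caches.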
…599) This PR implements the final piece of #530, so that we can partially apply rotary embedding to first head dimensions instead of entire head dimensions. We also add a simple benchmark for RoPE, below is the result on H100:
```python
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 23us, throughput: 0.876GB/s
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 26us, throughput: 0.801GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.735GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.639GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 31us, throughput: 672.889GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 32us, throughput: 662.972GB/s
---
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 14.559GB/s
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 14.435GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 37us, throughput: 1339.450GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 37us, throughput: 1340.399GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 148us, throughput: 2696.563GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 148us, throughput: 2689.104GB/s
---
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 74.186GB/s
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 74.452GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 110us, throughput: 2350.830GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 110us, throughput: 2359.814GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 717us, throughput: 2895.389GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 718us, throughput: 2891.385GB/s
---
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.449GB/s
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.646GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 130us, throughput: 2576.101GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 130us, throughput: 2582.447GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.154GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 925us, throughput: 2903.484GB/s
```
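The reported throughput figures appear consistent with a simple bytes-moved estimate. The sketch below is a guess at the accounting, not the benchmark's actual code: it assumes 2-byte (fp16) elements and counts each query and key element once for the read and once for the in-place write. The function name `rope_throughput_gb_s` is hypothetical.

```python
def rope_throughput_gb_s(batch_size, append_len, num_qo_heads, num_kv_heads,
                         head_dim, latency_us, bytes_per_elem=2):
    """Estimate effective memory throughput of an in-place RoPE call in GB/s."""
    tokens = batch_size * append_len
    elems = tokens * (num_qo_heads + num_kv_heads) * head_dim
    bytes_moved = 2 * elems * bytes_per_elem  # assumption: one read plus one write
    return bytes_moved / (latency_us * 1e-6) / 1e9


# Largest configuration in the table: batch_size 128, append_len 1024, 924us.
estimate = rope_throughput_gb_s(128, 1024, 32, 8, 128, latency_us=924)
print(f"{estimate:.1f} GB/s")
```

Under these assumptions the estimate comes out at roughly 2905 GB/s, within a fraction of a percent of the reported 2906.154GB/s; the small gap is what one would expect from the latency being rounded to whole microseconds.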
As part of SGLang Issue #1487, SGLang plans to move vLLM to optional dependencies and use flashinfer as the main dependency.
I am working on moving rope to flashinfer. My plan is to reuse most of the existing vLLM rope but replace `ops.rotary_embedding` and `ops.batch_rotary_embedding` with flashinfer's kernel, which can be found here.

However, I've noticed some gaps between the vLLM and flashinfer implementations:

`positions` back to `offsets` + `indptr`. Can we support `positions` directly?

In general, we can prioritize the first three issues and consider the fourth as a stretch goal.