Improve parallelism in RoPE with pos_ids #609
The previous kernel did not expose enough parallelism at low batch sizes. As in the regular rotary kernel, all qo/kv heads are now split across separate blocks. In decode mode, the pos_ids kernel is now faster.
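The following sketch counts thread blocks under two hypothetical grid layouts to show why low batch sizes starve the GPU. In the first layout, each block owns a tile of tokens and loops over all heads. In the second, each (token tile, head) pair gets its own block. The helper names and the `tokens_per_block` value are illustrative assumptions, not FlashInfer's actual launch configuration.

```python
import math


def token_parallel_blocks(nnz: int, tokens_per_block: int) -> int:
    """Blocks launched if each block owns a tile of tokens and loops over all heads (hypothetical layout)."""
    return math.ceil(nnz / tokens_per_block)


def head_parallel_blocks(nnz: int, tokens_per_block: int,
                         num_qo_heads: int, num_kv_heads: int) -> int:
    """Blocks launched if every q head and every k head of a token tile gets its own block (hypothetical layout)."""
    return token_parallel_blocks(nnz, tokens_per_block) * (num_qo_heads + num_kv_heads)


# Decode with batch_size=1: a single new token (nnz = 1), with the head counts
# used in the benchmarks in this thread (32 qo heads, 8 kv heads).
print(token_parallel_blocks(1, tokens_per_block=4))                  # 1
print(head_parallel_blocks(1, 4, num_qo_heads=32, num_kv_heads=8))   # 40
```

Under these assumptions, a single decode token produces one block in the token-parallel layout, which leaves most of the GPU's SMs idle. The head-parallel layout launches 40 blocks instead.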
Hi @nandor, we use that parallelization scheme mainly to save sin/cos computation time (the same sin/cos values can be reused across multiple heads). Would you mind running https://github.com/flashinfer-ai/flashinfer/blob/32d9510d67187f1f3a379cce81302cdd15a557d2/benchmarks/bench_rope.py?
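Here is a rough cost model of the sin/cos work being saved. It makes three assumptions: each rotary pair needs one `cos` and one `sin` evaluation, no precomputed cos/sin cache is used (the `use_cos_sin_cache: False` case), and a head-parallel block recomputes the values for its own head. `trig_evals` is a hypothetical helper, not part of FlashInfer.

```python
def trig_evals(nnz: int, num_qo_heads: int, num_kv_heads: int,
               head_dim: int, reuse_across_heads: bool) -> int:
    """Number of cos/sin evaluations for one RoPE call (hypothetical cost model)."""
    # head_dim // 2 rotary pairs, each needing one cos and one sin -> head_dim evaluations.
    per_token_per_head = head_dim
    if reuse_across_heads:
        # Computed once per token and shared by all q and k heads.
        return nnz * per_token_per_head
    # Each head's block recomputes the same values independently.
    return nnz * (num_qo_heads + num_kv_heads) * per_token_per_head


nnz = 128 * 1024  # batch_size=128, append_len=1024
print(trig_evals(nnz, 32, 8, 128, reuse_across_heads=True))   # 16777216
print(trig_evals(nnz, 32, 8, 128, reuse_across_heads=False))  # 671088640
```

In this model the redundant work scales with `num_qo_heads + num_kv_heads` (40x here). It matters mostly for large inputs, where the GPU is already busy and the extra arithmetic is not hidden.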
You are right about the sin/cos savings. Unfortunately, this sort of batching is a bit more convoluted to implement in CUDA than in Triton, so internally we'll be relying on a Triton kernel instead.
Sounds good, thank you!

> this sort of batching is a bit more convoluted to implement in CUDA than triton and internally we'll be relying on a Triton kernel instead.

Yes, I agree. We plan to port most of the kernels (except for sampling and attention) to Triton in v0.2.1 :)
flashinfer-ai/flashinfer#609 potentially introduces correctness issues
I'm actually seeing that this change causes a correctness issue with respect to `apply_rope_pos_ids`. Here's a sample comparison script that passes prior to this commit (32d9510) but fails after the change: https://github.com/sgl-project/sglang/blob/dd0d2a3af4967880362e3bad9d95cd14572c89ea/scripts/playground/compare_flashinfer_vllm_rope.py
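Judging by its name, the linked script compares FlashInfer against vLLM. Below is a framework-free sketch of the same kind of cross-check. A reference RoPE that shares cos/sin across the heads of a token is compared against a variant that processes every (token, head) pair independently, the way a head-parallel kernel would. Both functions are hypothetical and assume the non-interleaved (rotate-half) layout with base `theta = 1e4`. They are not FlashInfer's or vLLM's implementations.

```python
import math
import random


def _cos_sin(pos: int, i: int, head_dim: int, theta: float):
    # Angle for rotary pair i at position pos: pos * theta^(-2i / head_dim).
    angle = pos * theta ** (-2.0 * i / head_dim)
    return math.cos(angle), math.sin(angle)


def rope_reference(x, pos_ids, theta=1e4):
    """x[token][head][dim]; rotate-half RoPE with cos/sin shared by all heads of a token."""
    out = []
    for tok, pos in zip(x, pos_ids):
        d = len(tok[0])
        half = d // 2
        cs = [_cos_sin(pos, i, d, theta) for i in range(half)]  # computed once per token
        out.append([[h[i] * c - h[i + half] * s for i, (c, s) in enumerate(cs)]
                    + [h[i + half] * c + h[i] * s for i, (c, s) in enumerate(cs)]
                    for h in tok])
    return out


def rope_head_parallel(x, pos_ids, theta=1e4):
    """Same math, but every (token, head) pair is processed independently."""
    out = []
    for tok, pos in zip(x, pos_ids):
        rotated = []
        for h in tok:  # in a head-parallel kernel, each iteration would be its own block
            d = len(h)
            half = d // 2
            res = [0.0] * d
            for i in range(half):
                c, s = _cos_sin(pos, i, d, theta)  # recomputed for every head
                res[i] = h[i] * c - h[i + half] * s
                res[i + half] = h[i + half] * c + h[i] * s
            rotated.append(res)
        out.append(rotated)
    return out


def max_abs_diff(a, b):
    return max(abs(u - v)
               for ta, tb in zip(a, b) for ha, hb in zip(ta, tb) for u, v in zip(ha, hb))


random.seed(0)
num_tokens, num_heads, head_dim = 3, 4, 8
x = [[[random.uniform(-1.0, 1.0) for _ in range(head_dim)] for _ in range(num_heads)]
     for _ in range(num_tokens)]
pos_ids = [0, 5, 17]
diff = max_abs_diff(rope_reference(x, pos_ids), rope_head_parallel(x, pos_ids))
assert diff < 1e-6, f"RoPE implementations disagree: max abs diff {diff}"
```

Any indexing mistake in the per-(token, head) path, such as reading the wrong head or position, would show up as a large `diff` for some input shapes.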
@james-p-xu I'll fix it, thank you!
As observed by @james-p-xu, #609 produces wrong results for some input shapes. This PR fixes the correctness issue and adds an optimization that dispatches to different parallelism modes depending on the input shape. For large inputs, the original implementation (which reuses sin/cos across heads) is faster. For small inputs, head parallelism is faster.

Some results. All runs use num_qo_heads: 32, num_kv_heads: 8, head_dim: 128. Each cell is latency / throughput.

| batch_size | append_len | use_cos_sin_cache | Before #609 (no head parallelism, reuses sin/cos) | Head parallelism only (no dispatch) | This PR (shape dispatch) |
|---|---|---|---|---|---|
| 1 | 1 | False | 27us / 0.762GB/s | 6us / 3.321GB/s | 28us / 0.728GB/s |
| 1 | 1 | True | 22us / 0.919GB/s | 6us / 3.391GB/s | 6us / 3.451GB/s |
| 1 | 128 | False | 27us / 95.699GB/s | 7us / 358.862GB/s | 7us / 359.759GB/s |
| 1 | 128 | True | 28us / 95.244GB/s | 7us / 362.361GB/s | 7us / 361.286GB/s |
| 1 | 1024 | False | 31us / 670.254GB/s | 15us / 1413.175GB/s | 15us / 1426.267GB/s |
| 1 | 1024 | True | 31us / 667.253GB/s | 15us / 1437.332GB/s | 15us / 1433.691GB/s |
| 19 | 1 | False | 27us / 14.490GB/s | 6us / 60.526GB/s | 6us / 60.390GB/s |
| 19 | 1 | True | 27us / 14.466GB/s | 6us / 60.127GB/s | 6us / 59.937GB/s |
| 19 | 128 | False | 37us / 1344.086GB/s | 26us / 1897.923GB/s | 26us / 1892.575GB/s |
| 19 | 128 | True | 37us / 1344.902GB/s | 24us / 2050.075GB/s | 24us / 2049.735GB/s |
| 19 | 1024 | False | 148us / 2699.475GB/s | 164us / 2431.650GB/s | 148us / 2698.780GB/s |
| 19 | 1024 | True | 147us / 2701.897GB/s | 147us / 2709.333GB/s | 147us / 2701.558GB/s |
| 99 | 1 | False | 27us / 74.322GB/s | 7us / 284.641GB/s | 7us / 285.335GB/s |
| 99 | 1 | True | 27us / 74.568GB/s | 7us / 302.815GB/s | 7us / 303.373GB/s |
| 99 | 128 | False | 110us / 2352.352GB/s | 109us / 2391.712GB/s | 110us / 2351.126GB/s |
| 99 | 128 | True | 110us / 2365.580GB/s | 97us / 2671.150GB/s | 110us / 2362.898GB/s |
| 99 | 1024 | False | 718us / 2893.608GB/s | 860us / 2413.211GB/s | 717us / 2893.713GB/s |
| 99 | 1024 | True | 717us / 2894.859GB/s | 828us / 2508.817GB/s | 717us / 2894.902GB/s |
| 128 | 1 | False | 27us / 95.373GB/s | 7us / 349.795GB/s | 7us / 350.720GB/s |
| 128 | 1 | True | 27us / 95.810GB/s | 7us / 376.624GB/s | 7us / 376.690GB/s |
| 128 | 128 | False | 130us / 2583.872GB/s | 139us / 2413.690GB/s | 130us / 2584.221GB/s |
| 128 | 128 | True | 129us / 2595.944GB/s | 124us / 2705.994GB/s | 129us / 2596.612GB/s |
| 128 | 1024 | False | 923us / 2907.408GB/s | 1110us / 2417.480GB/s | 924us / 2906.480GB/s |
| 128 | 1024 | True | 924us / 2905.533GB/s | 1063us / 2525.976GB/s | 924us / 2905.134GB/s |

cc @nandor @james-p-xu
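The shape dispatch can be sketched as a simple rule on the total token count. The cutoff below is a hypothetical value and is not taken from the PR. It was picked only so that it roughly separates the small shapes, where head parallelism was clearly faster in the table, from the large ones, where reusing sin/cos was generally faster. `RopeMode` and `choose_rope_mode` are illustrative names.

```python
from enum import Enum


class RopeMode(Enum):
    # One block per (token tile, head): more blocks, cos/sin recomputed per head.
    HEAD_PARALLEL = "head_parallel"
    # One block per token tile: fewer blocks, cos/sin computed once and reused across heads.
    REUSE_COS_SIN = "reuse_cos_sin"


# Hypothetical cutoff on the total number of tokens (not taken from the PR).
HYPOTHETICAL_NNZ_THRESHOLD = 4096


def choose_rope_mode(batch_size: int, append_len: int,
                     threshold: int = HYPOTHETICAL_NNZ_THRESHOLD) -> RopeMode:
    """Pick a parallelism mode from the input shape (illustrative heuristic only)."""
    nnz = batch_size * append_len  # total tokens whose q/k heads get rotated
    if nnz <= threshold:
        # Small inputs: presumably too few token tiles to fill the GPU, so
        # spreading heads across blocks wins despite recomputing cos/sin.
        return RopeMode.HEAD_PARALLEL
    # Large inputs: assumed to saturate the GPU already, so saving cos/sin work wins.
    return RopeMode.REUSE_COS_SIN


for bs, al in [(1, 1), (19, 128), (99, 128), (128, 1024)]:
    print(bs, al, choose_rope_mode(bs, al).value)
```

With this cutoff, (19, 128) (2432 tokens) maps to `HEAD_PARALLEL` and (99, 128) (12672 tokens) maps to `REUSE_COS_SIN`. A production rule would more likely be tuned per GPU and may also depend on the head counts.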