Improve parallelism in RoPE with pos_ids #609

Merged 1 commit into flashinfer-ai:main on Nov 14, 2024

Conversation

@nandor (Contributor) commented Nov 13, 2024

The previous kernel was not sufficiently parallelised for low batch sizes. As in the regular rotary kernel, all qo/kv heads are now split across separate thread blocks.

In decode mode, the pos_ids kernel is now faster.
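A rough back-of-the-envelope sketch of why this helps (the block counts below are illustrative, not the actual launch configuration of either kernel):

```python
# Illustrative only: compare how many thread blocks each decomposition can launch
# for a decode-time call (batch_size = 1, one new token per request).
batch_size, append_len = 1, 1
num_qo_heads, num_kv_heads = 32, 8

# Hypothetical "before": heads are looped over inside a block, so the grid
# scales only with the number of tokens.
blocks_token_only = batch_size * append_len

# Hypothetical "after": qo/kv heads are also spread across blocks.
blocks_head_parallel = batch_size * append_len * (num_qo_heads + num_kv_heads)

print(blocks_token_only, blocks_head_parallel)  # 1 vs. 40 blocks to occupy the GPU
```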

@yzh119 (Collaborator) commented Nov 13, 2024

Hi @nandor, we use this parallelism mainly to save sin/cos computation time (the same sin/cos values can be reused across multiple heads).
I expect that using a different thread block for each head will be faster at small batch sizes.

Would you mind running https://github.com/flashinfer-ai/flashinfer/blob/32d9510d67187f1f3a379cce81302cdd15a557d2/benchmarks/bench_rope.py ?
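For reference, a minimal PyTorch sketch of the computation under discussion (non-interleaved "rotate-half" layout and theta = 1e4 are assumptions here, not necessarily flashinfer's exact convention). It makes the reuse explicit: the angles depend only on the position id and the dimension index, never on the head index, so cos/sin can be materialised once per token and broadcast across all qo and kv heads.

```python
import torch

def reference_rope_pos_ids(q, k, pos_ids, theta=1e4):
    """q: [nnz, num_qo_heads, head_dim], k: [nnz, num_kv_heads, head_dim],
    pos_ids: [nnz]. Non-interleaved (rotate-half) layout is assumed."""
    head_dim = q.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32,
                                             device=q.device) / head_dim))
    angle = pos_ids.float()[:, None] * inv_freq[None, :]          # [nnz, head_dim // 2]
    cos = torch.cat([angle.cos(), angle.cos()], dim=-1)[:, None]  # [nnz, 1, head_dim]
    sin = torch.cat([angle.sin(), angle.sin()], dim=-1)[:, None]  # broadcast over heads

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    apply = lambda x: (x.float() * cos + rotate_half(x.float()) * sin).to(x.dtype)
    return apply(q), apply(k)

# Example: 32 query heads and 8 kv heads share the same per-token cos/sin.
q = torch.randn(4, 32, 128)
k = torch.randn(4, 8, 128)
pos_ids = torch.tensor([5, 6, 7, 8])
q_rot, k_rot = reference_rope_pos_ids(q, k, pos_ids)
```

The trade-off in this thread is whether that reuse happens within one block (saving the sin/cos arithmetic) or whether sin/cos is recomputed per head in exchange for more parallelism.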

@nandor (Contributor, Author) commented Nov 13, 2024

You are right: saving sin and cos across 2-8 heads does yield a small speedup. But the finer-grained computation is already significantly faster on an H100.

Unfortunately, this sort of batching is a bit more convoluted to implement in CUDA than in Triton, and internally we'll be relying on a Triton kernel instead.

@yzh119 (Collaborator) left a comment

Sounds good, thank you!

> this sort of batching is a bit more convoluted to implement in CUDA than in Triton, and internally we'll be relying on a Triton kernel instead.

Yes, I agree. We plan to port most of the kernels (except sampling and attention) to Triton in v0.2.1 :)

@yzh119 merged commit ff05155 into flashinfer-ai:main on Nov 14, 2024
james-p-xu added a commit to james-p-xu/sglang that referenced this pull request Nov 16, 2024
flashinfer-ai/flashinfer#609 potentially introduces correctness issues
@james-p-xu commented

I'm actually seeing that this change causes a correctness issue with apply_rope_pos_ids.

Here's a sample comparison script; it passes prior to this commit (32d9510) but fails after the change: https://github.com/sgl-project/sglang/blob/dd0d2a3af4967880362e3bad9d95cd14572c89ea/scripts/playground/compare_flashinfer_vllm_rope.py

@yzh119 (Collaborator) commented Nov 16, 2024

@james-p-xu I'll fix it, thank you!

yzh119 added a commit that referenced this pull request Nov 20, 2024
As observed by @james-p-xu, #609 produces wrong results for some input shapes. This PR fixes the correctness issue and adds an optimization that dispatches to different parallelism modes depending on the input shape.

For large input shapes, the original implementation (reusing sin/cos across different heads) is better.
For small input shapes, head parallelism is better.

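A minimal sketch of that dispatch idea, assuming a simple occupancy-style threshold (the constants and the function name are illustrative, not the heuristic used in the actual launcher):

```python
def choose_rope_parallelism(nnz: int, sm_count: int = 132, waves: int = 4) -> str:
    """nnz = batch_size * append_len (total tokens). Illustrative heuristic only:
    with few tokens a token-only grid cannot fill the GPU, so also spread heads
    across blocks; with many tokens, keep heads in one block and reuse sin/cos."""
    return "head_parallel" if nnz < sm_count * waves else "reuse_sin_cos"

# Decode (batch_size=1, append_len=1) picks head parallelism; a large prefill
# (batch_size=128, append_len=1024) keeps the original sin/cos-reuse scheme.
print(choose_rope_parallelism(1 * 1), choose_rope_parallelism(128 * 1024))
```
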
Some results:
```
Before #609 (no head-parallelism, re-use sin/cos value)
-----------------
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:   0.762GB/s
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 22us, throughput:   0.919GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  95.699GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 28us, throughput:  95.244GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 31us, throughput: 670.254GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 31us, throughput: 667.253GB/s
---
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  14.490GB/s
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  14.466GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 37us, throughput: 1344.086GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 37us, throughput: 1344.902GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 148us, throughput: 2699.475GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 147us, throughput: 2701.897GB/s
---
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  74.322GB/s
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  74.568GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 110us, throughput: 2352.352GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 110us, throughput: 2365.580GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 718us, throughput: 2893.608GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 717us, throughput: 2894.859GB/s
---
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  95.373GB/s
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  95.810GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 130us, throughput: 2583.872GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 129us, throughput: 2595.944GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 923us, throughput: 2907.408GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 924us, throughput: 2905.533GB/s

Head parallelism only (no dispatch)
---------------------
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  6us, throughput:   3.321GB/s
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  6us, throughput:   3.391GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  7us, throughput: 358.862GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  7us, throughput: 362.361GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 15us, throughput: 1413.175GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 15us, throughput: 1437.332GB/s
---
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  6us, throughput:  60.526GB/s
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  6us, throughput:  60.127GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 26us, throughput: 1897.923GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 24us, throughput: 2050.075GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 164us, throughput: 2431.650GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 147us, throughput: 2709.333GB/s
---
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  7us, throughput: 284.641GB/s
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  7us, throughput: 302.815GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 109us, throughput: 2391.712GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 97us, throughput: 2671.150GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 860us, throughput: 2413.211GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 828us, throughput: 2508.817GB/s
---
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  7us, throughput: 349.795GB/s
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  7us, throughput: 376.624GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 139us, throughput: 2413.690GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 124us, throughput: 2705.994GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 1110us, throughput: 2417.480GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 1063us, throughput: 2525.976GB/s

This PR (shape dispatch)
---------------------
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 28us, throughput:   0.728GB/s
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  6us, throughput:   3.451GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  7us, throughput: 359.759GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  7us, throughput: 361.286GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 15us, throughput: 1426.267GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 15us, throughput: 1433.691GB/s
---
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  6us, throughput:  60.390GB/s
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  6us, throughput:  59.937GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 26us, throughput: 1892.575GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 24us, throughput: 2049.735GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 148us, throughput: 2698.780GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 147us, throughput: 2701.558GB/s
---
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  7us, throughput: 285.335GB/s
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  7us, throughput: 303.373GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 110us, throughput: 2351.126GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 110us, throughput: 2362.898GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 717us, throughput: 2893.713GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 717us, throughput: 2894.902GB/s
---
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency:  7us, throughput: 350.720GB/s
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency:  7us, throughput: 376.690GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 130us, throughput: 2584.221GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 129us, throughput: 2596.612GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.480GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 924us, throughput: 2905.134GB/s
```

cc @nandor @james-p-xu