FA3 kvcache + split kv + gqa parallelization #1236
Conversation
hopper/flash_attn_interface.py (outdated)

```diff
@@ -174,7 +175,8 @@ def forward(
             causal,
             descale_q=descale_q,
             descale_k=descale_k,
             descale_v=descale_v,
+            gqa_decoding=False,
```
I wonder, does it make sense to give the user an option to enable the GQA optimization for general use cases outside of decoding?
For example, it's generally useful for small-seq_len prefill. In that case we don't really need split-KV, but we do want each threadblock to handle multiple Q heads that share the same KV head.
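To make the idea concrete, here is a NumPy reference sketch of the grouping being discussed (purely illustrative, not the FA3 kernel): the Q heads that share one KV head are stacked into a single "tall" query matrix, so the K/V tiles for that head are loaded once per group rather than once per Q head.

```python
import numpy as np

def gqa_attention(q, k, v, n_kv_heads):
    """Reference GQA attention (illustrative sketch, not the FA3 kernel).

    q: (n_q_heads, seqlen_q, d); k, v: (n_kv_heads, seqlen_k, d).
    Each KV head serves a contiguous group of n_q_heads // n_kv_heads
    query heads -- analogous to one threadblock handling all Q heads
    that share a KV head, amortizing a single K/V load across them.
    """
    n_q_heads, seqlen_q, d = q.shape
    group = n_q_heads // n_kv_heads
    out = np.empty_like(q)
    for h_kv in range(n_kv_heads):
        # Stack the group's Q heads into one (group * seqlen_q, d) matrix.
        q_grp = q[h_kv * group:(h_kv + 1) * group].reshape(-1, d)
        s = q_grp @ k[h_kv].T / np.sqrt(d)
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)
        o = p @ v[h_kv]
        out[h_kv * group:(h_kv + 1) * group] = o.reshape(group, seqlen_q, d)
    return out
```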
Furthermore, does it make sense to just enable the GQA optimization by default when the input is GQA? I feel it won't cause perf regressions even for long sequence lengths.
I feel it might slow things down a bit, but I haven't tried.
KV cache functionality not added yet.
…ded template params
…ly matters for fp8 support
…sion using smem boolean
This PR adds split-KV ("Flash decoding") and GQA parallelization improvements for FA3. Some essential parts of the KV cache API are added as well, including the `cache_seqlens` and `cache_batch_idx` arguments. Up to 15x improvement over FA2 measured on my H100 PCIe in exceptional cases, e.g.
Times given in microseconds. GB/s is measured in terms of loading the KV cache. Note that theoretical max bandwidth is 2 TB/s for H100 PCIe.
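The GB/s figures can be reproduced by dividing the bytes of KV cache loaded by the kernel time. A small helper for that arithmetic, using hypothetical shapes (the actual benchmark shapes are in the attached log, not repeated here):

```python
def kv_cache_bandwidth_gbs(batch, cache_seqlen, n_kv_heads, head_dim,
                           time_us, dtype_bytes=2):
    """Achieved bandwidth (GB/s) from loading the KV cache once.

    K and V each contribute batch * cache_seqlen * n_kv_heads * head_dim
    elements; time_us is the measured kernel time in microseconds.
    """
    kv_bytes = 2 * batch * cache_seqlen * n_kv_heads * head_dim * dtype_bytes
    return kv_bytes / (time_us * 1e-6) / 1e9

# Hypothetical example (NOT from the PR's benchmark):
# batch=4, cache_seqlen=8192, 8 KV heads, head_dim=128, fp16, 100 us
print(kv_cache_bandwidth_gbs(4, 8192, 8, 128, 100.0))  # ~1342 GB/s
```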
TODO on this PR before merge: add split kv heuristic, implement for FP8.
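For readers unfamiliar with the split-KV scheme: the KV cache is partitioned into chunks processed in parallel, each producing a chunk-normalized partial output plus its log-sum-exp, and the partials are then merged with an LSE rescale. A minimal NumPy sketch of that combine for a single query vector (illustrative only; the kernel performs this across CTAs):

```python
import numpy as np

def attn_decode_splitkv(q, k, v, n_splits):
    """Split-KV ("Flash decoding") attention for one query (NumPy sketch).

    q: (d,); k, v: (seqlen_k, d). Each of the n_splits KV chunks yields
    a softmax-normalized partial output and its log-sum-exp; the
    partials are merged with a numerically stable LSE-weighted sum.
    """
    d = q.shape[0]
    chunks = np.array_split(np.arange(k.shape[0]), n_splits)
    outs, lses = [], []
    for idx in chunks:
        s = k[idx] @ q / np.sqrt(d)      # partial attention scores
        m = s.max()
        p = np.exp(s - m)
        z = p.sum()
        outs.append(p @ v[idx] / z)      # chunk-normalized partial output
        lses.append(m + np.log(z))       # chunk log-sum-exp
    lses = np.array(lses)
    w = np.exp(lses - lses.max())        # per-chunk combine weights
    return np.sum([wi * oi for wi, oi in zip(w, outs)], axis=0) / w.sum()
```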
fa3-decoding-times-091724.log