Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDA] Lean Attention (microsoft#22352)
### Description Add [Lean Attention](https://arxiv.org/abs/2405.10480) and the integration with MultiHeadAttention operator for LLM in GPU. LeanAttention speeds up self-attention for the token-generation phase (decode-phase) of decoder-only transformer models, especially on long context lengths. - [x] Initial implementation of Lean Attention (by Srikant Bharadwaj) - [x] Integration with MultiHeadAttention operator - [x] Add parity tests - [x] Add benchmark #### Implementation Details (1) Lean Attention is enabled in build for Linux, and disabled for Windows (2) Lean Attention is disabled by default. Need enable it through cuda provider option sdpa_kernel, or use environment variable `ORT_ENABLE_LEAN_ATTENTION=1` (3) It only works for token-generation (sequence_length==1, past_sequence_length > 0). (4) Like flash attention, it only works in Ampere or newer GPU. We can revisit #1 and #2 after comparing with DecoderMaskedMultiHeadAttention and XQA kernels. #### Benchmark ``` cd onnxruntime/test/python/transformers /bin/bash benchmark_mha.sh lean ``` Example outputs in H100: Note that past and present does not share buffer for MHA for now, so we can see low tflops. The relative ratio will change after buffer sharing is enabled. But we expect that the order (kernel A is faster than B) will remain the same after buffer sharing is enabled. Note that common settings `sequence_length=1; causal=True;attn_bias=None;cuda_graph=False` are not shown in the below table. batch_size | past_sequence_length | num_heads | head_size | average_latency | tflops | kernel -- | -- | -- | -- | -- | -- | -- 1 | 512 | 16 | 64 | 0.000059 | 0.0178 | ort:flash 1 | 512 | 16 | 64 | 0.000068 | 0.0155 | ort:efficient 1 | 512 | 16 | 64 | 0.000065 | 0.0161 | ort:math 1 | 512 | 16 | 64 | 0.000060 | 0.0176 | ort:lean 1 | 512 | 32 | 128 | 0.000062 | 0.0674 | ort:flash 1 | 512 | 32 | 128 | 0.000064 | 0.0661 | ort:efficient 1 | 512 | 32 | 128 | 0.000067 | 0.0625 | ort:math 1 | 512 | 32 | 128 | 0.000062 | 0.0678 | ort:lean 1 | 1024 | 16 | 64 | 0.000061 | 0.0345 | ort:flash 1 | 1024 | 16 | 64 | 0.000086 | 0.0244 | ort:efficient 1 | 1024 | 16 | 64 | 0.000065 | 0.0322 | ort:math 1 | 1024 | 16 | 64 | 0.000063 | 0.0332 | ort:lean 1 | 1024 | 32 | 128 | 0.000075 | 0.1125 | ort:flash 1 | 1024 | 32 | 128 | 0.000088 | 0.0951 | ort:efficient 1 | 1024 | 32 | 128 | 0.000079 | 0.1068 | ort:math 1 | 1024 | 32 | 128 | 0.000072 | 0.1171 | ort:lean 1 | 2048 | 16 | 64 | 0.000069 | 0.0606 | ort:flash 1 | 2048 | 16 | 64 | 0.000125 | 0.0336 | ort:efficient 1 | 2048 | 16 | 64 | 0.000064 | 0.0655 | ort:lean 1 | 2048 | 32 | 128 | 0.000098 | 0.1720 | ort:flash 1 | 2048 | 32 | 128 | 0.000132 | 0.1270 | ort:efficient 1 | 2048 | 32 | 128 | 0.000092 | 0.1828 | ort:lean 1 | 4096 | 16 | 64 | 0.000076 | 0.1097 | ort:flash 1 | 4096 | 16 | 64 | 0.000207 | 0.0406 | ort:efficient 1 | 4096 | 16 | 64 | 0.000069 | 0.1209 | ort:lean 1 | 4096 | 32 | 128 | 0.000140 | 0.2394 | ort:flash 1 | 4096 | 32 | 128 | 0.000213 | 0.1575 | ort:efficient 1 | 4096 | 32 | 128 | 0.000139 | 0.2419 | ort:lean 1 | 8192 | 16 | 64 | 0.000104 | 0.1609 | ort:flash 1 | 8192 | 16 | 64 | 0.000392 | 0.0428 | ort:efficient 1 | 8192 | 16 | 64 | 0.000093 | 0.1809 | ort:lean 1 | 8192 | 32 | 128 | 0.000212 | 0.3160 | ort:flash 1 | 8192 | 32 | 128 | 0.000360 | 0.1866 | ort:efficient 1 | 8192 | 32 | 128 | 0.000212 | 0.3162 | ort:lean 1 | 16384 | 16 | 64 | 0.000139 | 0.2410 | ort:flash 1 | 16384 | 16 | 64 | 0.000731 | 0.0459 | ort:efficient 1 | 16384 | 16 | 64 | 0.000136 | 0.2465 | ort:lean 1 | 16384 | 32 | 128 | 0.000361 | 0.3722 | ort:flash 1 | 16384 | 32 | 128 | 0.000667 | 0.2014 | ort:efficient 1 | 16384 | 32 | 128 | 0.000357 | 0.3765 | ort:lean 1 | 32768 | 16 | 64 | 0.000210 | 0.3194 | ort:flash 1 | 32768 | 16 | 64 | 0.001428 | 0.0470 | ort:efficient 1 | 32768 | 16 | 64 | 0.000209 | 0.3211 | ort:lean 1 | 32768 | 32 | 128 | 0.000659 | 0.4074 | ort:flash 1 | 32768 | 32 | 128 | 0.001270 | 0.2114 | ort:efficient 1 | 32768 | 32 | 128 | 0.000651 | 0.4123 | ort:lean 1 | 65536 | 16 | 64 | 0.000355 | 0.3785 | ort:flash 1 | 65536 | 16 | 64 | 0.002736 | 0.0491 | ort:efficient 1 | 65536 | 16 | 64 | 0.000349 | 0.3845 | ort:lean 1 | 65536 | 32 | 128 | 0.001251 | 0.4290 | ort:flash 1 | 65536 | 32 | 128 | 0.002480 | 0.2165 | ort:efficient 1 | 65536 | 32 | 128 | 0.001239 | 0.4333 | ort:lean 4 | 512 | 16 | 64 | 0.000063 | 0.0665 | ort:flash 4 | 512 | 16 | 64 | 0.000069 | 0.0607 | ort:efficient 4 | 512 | 16 | 64 | 0.000066 | 0.0634 | ort:math 4 | 512 | 16 | 64 | 0.000062 | 0.0674 | ort:lean 4 | 512 | 32 | 128 | 0.000100 | 0.1677 | ort:flash 4 | 512 | 32 | 128 | 0.000099 | 0.1703 | ort:efficient 4 | 512 | 32 | 128 | 0.000108 | 0.1557 | ort:math 4 | 512 | 32 | 128 | 0.000092 | 0.1818 | ort:lean 4 | 1024 | 16 | 64 | 0.000077 | 0.1094 | ort:flash 4 | 1024 | 16 | 64 | 0.000099 | 0.0850 | ort:efficient 4 | 1024 | 16 | 64 | 0.000081 | 0.1038 | ort:math 4 | 1024 | 16 | 64 | 0.000072 | 0.1161 | ort:lean 4 | 1024 | 32 | 128 | 0.000143 | 0.2343 | ort:flash 4 | 1024 | 32 | 128 | 0.000137 | 0.2447 | ort:efficient 4 | 1024 | 32 | 128 | 0.000150 | 0.2245 | ort:math 4 | 1024 | 32 | 128 | 0.000135 | 0.2496 | ort:lean 4 | 2048 | 16 | 64 | 0.000096 | 0.1757 | ort:flash 4 | 2048 | 16 | 64 | 0.000156 | 0.1078 | ort:efficient 4 | 2048 | 16 | 64 | 0.000089 | 0.1892 | ort:lean 4 | 2048 | 32 | 128 | 0.000223 | 0.3010 | ort:flash 4 | 2048 | 32 | 128 | 0.000217 | 0.3101 | ort:efficient 4 | 2048 | 32 | 128 | 0.000209 | 0.3209 | ort:lean 4 | 4096 | 16 | 64 | 0.000137 | 0.2448 | ort:flash 4 | 4096 | 16 | 64 | 0.000256 | 0.1312 | ort:efficient 4 | 4096 | 16 | 64 | 0.000133 | 0.2530 | ort:lean 4 | 4096 | 32 | 128 | 0.000389 | 0.3450 | ort:flash 4 | 4096 | 32 | 128 | 0.000376 | 0.3574 | ort:efficient 4 | 4096 | 32 | 128 | 0.000354 | 0.3794 | ort:lean 4 | 8192 | 16 | 64 | 0.000210 | 0.3198 | ort:flash 4 | 8192 | 16 | 64 | 0.000453 | 0.1480 | ort:efficient 4 | 8192 | 16 | 64 | 0.000206 | 0.3260 | ort:lean 4 | 8192 | 32 | 128 | 0.000725 | 0.3705 | ort:flash 4 | 8192 | 32 | 128 | 0.000693 | 0.3874 | ort:efficient 4 | 8192 | 32 | 128 | 0.000653 | 0.4114 | ort:lean 4 | 16384 | 16 | 64 | 0.000355 | 0.3782 | ort:flash 4 | 16384 | 16 | 64 | 0.000849 | 0.1581 | ort:efficient 4 | 16384 | 16 | 64 | 0.000346 | 0.3874 | ort:lean 4 | 16384 | 32 | 128 | 0.001395 | 0.3848 | ort:flash 4 | 16384 | 32 | 128 | 0.001337 | 0.4017 | ort:efficient 4 | 16384 | 32 | 128 | 0.001252 | 0.4288 | ort:lean 4 | 32768 | 16 | 64 | 0.000647 | 0.4146 | ort:flash 4 | 32768 | 16 | 64 | 0.001649 | 0.1628 | ort:efficient 4 | 32768 | 16 | 64 | 0.000639 | 0.4204 | ort:lean 4 | 32768 | 32 | 128 | 0.002721 | 0.3947 | ort:flash 4 | 32768 | 32 | 128 | 0.002601 | 0.4128 | ort:efficient 4 | 32768 | 32 | 128 | 0.002434 | 0.4411 | ort:lean 4 | 65536 | 16 | 64 | 0.001231 | 0.4361 | ort:flash 4 | 65536 | 16 | 64 | 0.003238 | 0.1658 | ort:efficient 4 | 65536 | 16 | 64 | 0.001217 | 0.4412 | ort:lean 4 | 65536 | 32 | 128 | 0.005357 | 0.4009 | ort:flash 4 | 65536 | 32 | 128 | 0.005118 | 0.4196 | ort:efficient 4 | 65536 | 32 | 128 | 0.004781 | 0.4492 | ort:lean 16 | 512 | 16 | 64 | 0.000098 | 0.1724 | ort:flash 16 | 512 | 16 | 64 | 0.000104 | 0.1616 | ort:efficient 16 | 512 | 16 | 64 | 0.000118 | 0.1420 | ort:math 16 | 512 | 16 | 64 | 0.000087 | 0.1926 | ort:lean 16 | 512 | 32 | 128 | 0.000220 | 0.3062 | ort:flash 16 | 512 | 32 | 128 | 0.000208 | 0.3237 | ort:efficient 16 | 512 | 32 | 128 | 0.000237 | 0.2838 | ort:math 16 | 512 | 32 | 128 | 0.000209 | 0.3216 | ort:lean 16 | 1024 | 16 | 64 | 0.000136 | 0.2465 | ort:flash 16 | 1024 | 16 | 64 | 0.000150 | 0.2235 | ort:efficient 16 | 1024 | 16 | 64 | 0.000148 | 0.2266 | ort:math 16 | 1024 | 16 | 64 | 0.000129 | 0.2611 | ort:lean 16 | 1024 | 32 | 128 | 0.000367 | 0.3663 | ort:flash 16 | 1024 | 32 | 128 | 0.000351 | 0.3829 | ort:efficient 16 | 1024 | 32 | 128 | 0.000400 | 0.3357 | ort:math 16 | 1024 | 32 | 128 | 0.000349 | 0.3853 | ort:lean 16 | 2048 | 16 | 64 | 0.000209 | 0.3206 | ort:flash 16 | 2048 | 16 | 64 | 0.000243 | 0.2762 | ort:efficient 16 | 2048 | 16 | 64 | 0.000201 | 0.3338 | ort:lean 16 | 2048 | 32 | 128 | 0.000671 | 0.4002 | ort:flash 16 | 2048 | 32 | 128 | 0.000645 | 0.4163 | ort:efficient 16 | 2048 | 32 | 128 | 0.000642 | 0.4185 | ort:lean 16 | 4096 | 16 | 64 | 0.000360 | 0.3732 | ort:flash 16 | 4096 | 16 | 64 | 0.000425 | 0.3162 | ort:efficient 16 | 4096 | 16 | 64 | 0.000341 | 0.3933 | ort:lean 16 | 4096 | 32 | 128 | 0.001292 | 0.4156 | ort:flash 16 | 4096 | 32 | 128 | 0.001251 | 0.4291 | ort:efficient 16 | 4096 | 32 | 128 | 0.001241 | 0.4327 | ort:lean 16 | 8192 | 16 | 64 | 0.000666 | 0.4030 | ort:flash 16 | 8192 | 16 | 64 | 0.000804 | 0.3339 | ort:efficient 16 | 8192 | 16 | 64 | 0.000627 | 0.4283 | ort:lean 16 | 8192 | 32 | 128 | 0.002541 | 0.4226 | ort:flash 16 | 8192 | 32 | 128 | 0.002454 | 0.4376 | ort:efficient 16 | 8192 | 32 | 128 | 0.002438 | 0.4405 | ort:lean 16 | 16384 | 16 | 64 | 0.001292 | 0.4156 | ort:flash 16 | 16384 | 16 | 64 | 0.001571 | 0.3417 | ort:efficient 16 | 16384 | 16 | 64 | 0.001217 | 0.4411 | ort:lean 16 | 16384 | 32 | 128 | 0.005042 | 0.4260 | ort:flash 16 | 16384 | 32 | 128 | 0.004859 | 0.4420 | ort:efficient 16 | 16384 | 32 | 128 | 0.004827 | 0.4449 | ort:lean 16 | 32768 | 16 | 64 | 0.002537 | 0.4233 | ort:flash 16 | 32768 | 16 | 64 | 0.003103 | 0.3461 | ort:efficient 16 | 32768 | 16 | 64 | 0.002385 | 0.4501 | ort:lean 16 | 32768 | 32 | 128 | 0.009961 | 0.4312 | ort:flash 16 | 32768 | 32 | 128 | 0.009605 | 0.4472 | ort:efficient 16 | 32768 | 32 | 128 | 0.009524 | 0.4510 | ort:lean 16 | 65536 | 16 | 64 | 0.005019 | 0.4279 | ort:flash 16 | 65536 | 16 | 64 | 0.006133 | 0.3502 | ort:efficient 16 | 65536 | 16 | 64 | 0.004703 | 0.4566 | ort:lean 16 | 65536 | 32 | 128 | 0.019746 | 0.4350 | ort:flash 16 | 65536 | 32 | 128 | 0.019027 | 0.4515 | ort:efficient 16 | 65536 | 32 | 128 | 0.018864 | 0.4554 | ort:lean ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
- Loading branch information