forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDA] Add SparseAttention kernel for sm=75 (microsoft#20531)
### Description Follow up of microsoft#20216 to add kernel for sm=75 (GPU like T4, Geforce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000 etc) - [x] Add kernel for sm=75 - [x] Update dispatch code to use sm to call different kernel. - [x] Update compile script to use num_stages=2 instead of 3 for sm=75 - [x] Refactor test script and add tests for bfloat16. - [x] Fix performance test of token generation (previously we did not concatenate past_key) - [x] Fix debug build - [x] Run performance test and update numbers. For sm=70, the v1 kernel can be compiled but there is error in compiling v2 kernel. So it is skipped in this pull request. Performance Test on T4 GPU (using Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8` We compare sparse attention to corresponding GQA with dense causal. Note that GQA with dense need more computation since no sparsity is used. The TORCH-GQA use naive implementation (using cuSPARSE Block-SpMM could be faster). ``` prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16: sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.184173 2.994347 0.089064 2 64.0 0.303300 3.023986 0.107418 3 128.0 0.887795 3.073728 0.174213 4 256.0 2.797654 3.246899 0.357869 5 512.0 10.055048 3.814039 0.893903 6 1024.0 37.849937 5.818439 2.658720 7 2048.0 148.641785 13.638480 7.202690 8 4096.0 OOM 43.556847 17.680954 9 8192.0 OOM 161.628540 44.336670 token-sm75-batch4-head32-d128-local16-vert8-torch.float16: past_sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.110353 2.996305 0.137509 2 64.0 0.145088 3.006860 0.165424 3 128.0 0.219500 3.036448 0.192001 4 256.0 0.347496 3.071341 0.249125 5 512.0 0.595842 3.135225 0.398726 6 1024.0 1.081216 3.261110 0.612744 7 2048.0 2.060307 3.515578 0.685670 8 4096.0 OOM 4.022986 0.819707 9 8191.0 OOM 5.024528 1.072912 ``` ### Motivation and Context To inference Phi-3-small in T4 GPU
- Loading branch information
Showing
24 changed files
with
1,374 additions
and
300 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.