-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CUDA] Add SparseAttention kernel for sm=75 #20531
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
wangyems
approved these changes
May 2, 2024
onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/compile_sparse_attention_v2.py
Show resolved
Hide resolved
kunal-vaishnavi
approved these changes
May 2, 2024
yihonglyu
pushed a commit
that referenced
this pull request
May 4, 2024
### Description Follow up of #20216 to add kernel for sm=75 (GPU like T4, Geforce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000 etc) - [x] Add kernel for sm=75 - [x] Update dispatch code to use sm to call different kernel. - [x] Update compile script to use num_stages=2 instead of 3 for sm=75 - [x] Refactor test script and add tests for bfloat16. - [x] Fix performance test of token generation (previously we did not concatenate past_key) - [x] Fix debug build - [x] Run performance test and update numbers. For sm=70, the v1 kernel can be compiled but there is error in compiling v2 kernel. So it is skipped in this pull request. Performance Test on T4 GPU (using Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8` We compare sparse attention to corresponding GQA with dense causal. Note that GQA with dense need more computation since no sparsity is used. The TORCH-GQA use naive implementation (using cuSPARSE Block-SpMM could be faster). ``` prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16: sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.184173 2.994347 0.089064 2 64.0 0.303300 3.023986 0.107418 3 128.0 0.887795 3.073728 0.174213 4 256.0 2.797654 3.246899 0.357869 5 512.0 10.055048 3.814039 0.893903 6 1024.0 37.849937 5.818439 2.658720 7 2048.0 148.641785 13.638480 7.202690 8 4096.0 OOM 43.556847 17.680954 9 8192.0 OOM 161.628540 44.336670 token-sm75-batch4-head32-d128-local16-vert8-torch.float16: past_sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.110353 2.996305 0.137509 2 64.0 0.145088 3.006860 0.165424 3 128.0 0.219500 3.036448 0.192001 4 256.0 0.347496 3.071341 0.249125 5 512.0 0.595842 3.135225 0.398726 6 1024.0 1.081216 3.261110 0.612744 7 2048.0 2.060307 3.515578 0.685670 8 4096.0 OOM 4.022986 0.819707 9 8191.0 OOM 5.024528 1.072912 ``` ### Motivation and Context To inference Phi-3-small in T4 GPU
TedThemistokleous
pushed a commit
to TedThemistokleous/onnxruntime
that referenced
this pull request
May 7, 2024
### Description Follow up of microsoft#20216 to add kernel for sm=75 (GPU like T4, Geforce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000 etc) - [x] Add kernel for sm=75 - [x] Update dispatch code to use sm to call different kernel. - [x] Update compile script to use num_stages=2 instead of 3 for sm=75 - [x] Refactor test script and add tests for bfloat16. - [x] Fix performance test of token generation (previously we did not concatenate past_key) - [x] Fix debug build - [x] Run performance test and update numbers. For sm=70, the v1 kernel can be compiled but there is error in compiling v2 kernel. So it is skipped in this pull request. Performance Test on T4 GPU (using Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8` We compare sparse attention to corresponding GQA with dense causal. Note that GQA with dense need more computation since no sparsity is used. The TORCH-GQA use naive implementation (using cuSPARSE Block-SpMM could be faster). ``` prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16: sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.184173 2.994347 0.089064 2 64.0 0.303300 3.023986 0.107418 3 128.0 0.887795 3.073728 0.174213 4 256.0 2.797654 3.246899 0.357869 5 512.0 10.055048 3.814039 0.893903 6 1024.0 37.849937 5.818439 2.658720 7 2048.0 148.641785 13.638480 7.202690 8 4096.0 OOM 43.556847 17.680954 9 8192.0 OOM 161.628540 44.336670 token-sm75-batch4-head32-d128-local16-vert8-torch.float16: past_sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.110353 2.996305 0.137509 2 64.0 0.145088 3.006860 0.165424 3 128.0 0.219500 3.036448 0.192001 4 256.0 0.347496 3.071341 0.249125 5 512.0 0.595842 3.135225 0.398726 6 1024.0 1.081216 3.261110 0.612744 7 2048.0 2.060307 3.515578 0.685670 8 4096.0 OOM 4.022986 0.819707 9 8191.0 OOM 5.024528 1.072912 ``` ### Motivation and Context To inference Phi-3-small in T4 GPU
poweiw
pushed a commit
to poweiw/onnxruntime
that referenced
this pull request
Jun 25, 2024
### Description Follow up of microsoft#20216 to add kernel for sm=75 (GPU like T4, Geforce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000 etc) - [x] Add kernel for sm=75 - [x] Update dispatch code to use sm to call different kernel. - [x] Update compile script to use num_stages=2 instead of 3 for sm=75 - [x] Refactor test script and add tests for bfloat16. - [x] Fix performance test of token generation (previously we did not concatenate past_key) - [x] Fix debug build - [x] Run performance test and update numbers. For sm=70, the v1 kernel can be compiled but there is error in compiling v2 kernel. So it is skipped in this pull request. Performance Test on T4 GPU (using Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8` We compare sparse attention to corresponding GQA with dense causal. Note that GQA with dense need more computation since no sparsity is used. The TORCH-GQA use naive implementation (using cuSPARSE Block-SpMM could be faster). ``` prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16: sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.184173 2.994347 0.089064 2 64.0 0.303300 3.023986 0.107418 3 128.0 0.887795 3.073728 0.174213 4 256.0 2.797654 3.246899 0.357869 5 512.0 10.055048 3.814039 0.893903 6 1024.0 37.849937 5.818439 2.658720 7 2048.0 148.641785 13.638480 7.202690 8 4096.0 OOM 43.556847 17.680954 9 8192.0 OOM 161.628540 44.336670 token-sm75-batch4-head32-d128-local16-vert8-torch.float16: past_sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.110353 2.996305 0.137509 2 64.0 0.145088 3.006860 0.165424 3 128.0 0.219500 3.036448 0.192001 4 256.0 0.347496 3.071341 0.249125 5 512.0 0.595842 3.135225 0.398726 6 1024.0 1.081216 3.261110 0.612744 7 2048.0 2.060307 3.515578 0.685670 8 4096.0 OOM 4.022986 0.819707 9 8191.0 OOM 5.024528 1.072912 ``` ### Motivation and Context To inference Phi-3-small in T4 GPU
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
cherry-picked
Cherry-picked for a cherrypicks branch
rel-merged
Cherrypicks merged into release
release:1.18.0
triage:approved
Approved for cherrypicks for release
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Follow up of #20216 to add kernel for sm=75 (GPU like T4, Geforce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000 etc)
For sm=70, the v1 kernel can be compiled but there is error in compiling v2 kernel. So it is skipped in this pull request.
Performance Test on T4 GPU (using Standard_NC4as_T4_v3 Azure VM) with
batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8
We compare sparse attention to corresponding GQA with dense causal. Note that GQA with dense need more computation since no sparsity is used. The TORCH-GQA use naive implementation (using cuSPARSE Block-SpMM could be faster).
Motivation and Context
To inference Phi-3-small in T4 GPU