
[NVIDIA] Add config option to use cudnn flash attention #73

Merged: 1 commit into google:main on Jul 17, 2024

Conversation

kaixih (Contributor) commented Mar 22, 2024

This PR allows users to enable cudnn flash attention. It depends on google/praxis#53.

In preliminary results on GPT3-5B, we observe a ~30% performance improvement on 8x H100 GPUs.

With this PR, users can simply set USE_CUDNN_FLASH_ATTENTION=True in their config, and the attention part will be replaced with the cudnn flash attention implementation.
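Enabling the option from a paxml launch could look like the sketch below. This is illustrative only: the `--fdl.` override spelling and the experiment name are assumptions, not taken from this PR.

```shell
# Sketch only (assumed paxml-style fiddle override; experiment name is hypothetical).
python -m paxml.main \
  --exp=my_gpt3_5b_experiment \
  --fdl.USE_CUDNN_FLASH_ATTENTION=True \
  --job_log_dir=/tmp/logs
```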

cc. @nluehr @zhangqiaorjc

@kaixih kaixih force-pushed the cudnn_attention_dev branch from b3d08ad to e3c78c2 Compare April 1, 2024 18:34
@kaixih kaixih force-pushed the cudnn_attention_dev branch from e3c78c2 to a845a6e Compare April 8, 2024 18:47
@kaixih kaixih force-pushed the cudnn_attention_dev branch from a845a6e to 4ff92cb Compare July 8, 2024 20:37
kaixih (Contributor, Author) commented Jul 8, 2024

SDPA is now in the JAX public API (see this PR), and we can use it through the custom praxis layer in this PR.

This PR then introduces a fiddle config option, USE_CUDNN_FLASH_ATTENTION, to turn it on.

cc. @abhinavgoel95 for viz.

@zhangqiaorjc zhangqiaorjc added the pull ready Used to import PR as CL label Jul 10, 2024
kaixih (Contributor, Author) commented Jul 15, 2024

Gentle ping. @zhangqiaorjc

@copybara-service copybara-service bot merged commit 9a061ee into google:main Jul 17, 2024
3 checks passed