[NVIDIA] Add a custom layer for cudnn flash attention #53

Merged: 1 commit, Jul 17, 2024

kaixih commented Mar 22, 2024

This PR adds a new layer for cuDNN flash attention so that users can configure it as the tr_atten_tpl of the Transformer layer.
It depends on this JAX PR: jax-ml/jax#20380

cc. @nluehr @zhangqiaorjc

kaixih commented Apr 8, 2024

FYI, the prerequisite JAX PR is now merged.

kaixih force-pushed the cudnn_attention_dev branch from 61beeb9 to 79f9279 on April 8, 2024 18:47

kaixih commented Apr 10, 2024

Any update? It seems the PR has been in the pull-ready state for a while. @zhangqiaorjc

kaixih force-pushed the cudnn_attention_dev branch from 79f9279 to d8d440d on July 8, 2024 20:41

kaixih commented Jul 8, 2024

SDPA (scaled dot-product attention) is now in the public JAX module (see PR), so we have modified this PR to use the public API.

cc. @zhangqiaorjc @abhinavgoel95

zhangqiaorjc (Member) commented

@kaixih could you add a config option in the MoE transformer that preserves the old behavior by default?

I don't want to break existing checkpoints.

kaixih commented Jul 10, 2024

Can you explain why this PR breaks the default behavior of the MoE transformer? I think this PR just adds a new attention layer, which is used only when users configure it into the model via USE_CUDNN_FLASH_ATTENTION from this PR.

zhangqiaorjc (Member) commented

Sorry, I meant to reply to the other PR, #80.

kaixih commented Jul 15, 2024

Gentle ping. @zhangqiaorjc

copybara-service bot merged commit 43fe564 into google:main on Jul 17, 2024
3 checks passed