[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference #115663

Closed
wants to merge 52 commits into from

Conversation

eqy
Collaborator

@eqy eqy commented Dec 12, 2023

#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with TORCH_CUDNN_MHA_ENABLED=1.
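For anyone who wants to try it early, a minimal C++ smoke-test sketch (illustrative only, not code from this PR; it assumes an SM80+ GPU and that the process was launched with TORCH_CUDNN_MHA_ENABLED=1 so the cuDNN backend is eligible for SDPA dispatch):

```cpp
#include <ATen/ATen.h>

int main() {
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kHalf);
  // (batch, heads, seq_len, head_dim): the logical layout SDPA expects.
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);
  auto out = at::scaled_dot_product_attention(q, k, v);
  return out.defined() ? 0 : 1;
}
```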

CC @drisspg @ptrblck


pytorch-bot bot commented Dec 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115663

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 3 Unrelated Failures

As of commit 4fc3337 with merge base 3ab0894:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy eqy added the topic: not user facing topic category label Dec 12, 2023
@eqy eqy marked this pull request as ready for review December 13, 2023 02:16
@albanD albanD requested a review from drisspg December 13, 2023 16:28
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 13, 2023
@eqy eqy added ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs run periodically on master (periodic.yml) on the PR labels Dec 13, 2023
@drisspg
Contributor

drisspg commented Dec 18, 2023

Also feel free to ping me when you think I should do a review

@eqy
Collaborator Author

eqy commented Jan 4, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cudnnmha3 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnnmha3 && git pull --rebase)

@eqy
Collaborator Author

eqy commented Jan 5, 2024

@drisspg sorry for the delay, I think this should be ready for review now

@@ -21,6 +21,15 @@ class CuDNNError : public c10::Error {

} // namespace c10

#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \
Contributor

Is this a result of frontend changing to 9.0?

Collaborator Author

This is kind of just the "new" way that the cuDNN frontend does error reporting, and it applies to the recent frontend versions that are supported with 8.9.x. It is a bit annoying in that there is now a mechanism for getting an error string, which differs from the previous convention of returning CUDNN_STATUS_SUCCESS.

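For context, a rough sketch of what such a check can look like with the newer frontend error reporting; this is illustrative rather than the exact macro in this PR, and it assumes the frontend call returns an error object exposing is_good()/get_message() as in cudnn-frontend 1.x:

```cpp
#include <c10/util/Exception.h>

// Illustrative only: wrap a cudnn-frontend call that returns an error object
// (not a cudnnStatus_t) and surface its message through TORCH_CHECK.
#define AT_CUDNN_FRONTEND_CHECK_SKETCH(EXPR)                                    \
  do {                                                                          \
    auto _err = (EXPR);                                                         \
    TORCH_CHECK(_err.is_good(), "cuDNN frontend error: ", _err.get_message());  \
  } while (0)
```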
// .set_dim({b, 1, s_q, s_kv})
// .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
Contributor

Another dumb Q: are these names meaningful? Like, do they need to match an API for calling the function?

Collaborator Author

@eqy eqy Jan 9, 2024

Not a dumb question! The names are potentially meaningful in the future, in that we would be able to reference I/O tensors by name rather than by holding onto an explicit reference, as we currently do to supply the "variant pack" (the map used to specify the data pointers for each I/O tensor upon invocation). In either case they do not really mean anything to cuDNN; rather, they are for our own housekeeping.

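To make the "variant pack" concrete, here is a rough sketch of the cudnn-frontend 1.x pattern; the variable names (mha_graph, b, h, s_q, d, q, handle, workspace_ptr) are placeholders standing in for the PR's actual code:

```cpp
#include <memory>
#include <unordered_map>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Graph construction (fragment): keep handles to the I/O tensor attributes.
// The names set here are only bookkeeping for us; cuDNN does not interpret them.
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("Q")
                               .set_dim({b, h, s_q, d})
                               .set_stride({h * s_q * d, s_q * d, d, 1})
                               .set_data_type(fe::DataType_t::HALF));
// ... K, V, O, Seed, Offset are declared the same way ...

// Execution: the variant pack maps each kept handle to its device pointer.
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
    {Q, q.data_ptr()},
    // {K, k.data_ptr()}, {V, v.data_ptr()}, {O, o.data_ptr()}, ...
};
mha_graph->execute(handle, variant_pack, workspace_ptr);
```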
if (cudnnGetVersion() >= 8904) {
//scaled_dot_product_flash_attention_options.set_alibi_mask(true);
}

Contributor

Does this only support certain bias types?

Collaborator Author

I would need to check with the cuDNN folks; this doc https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#flash-fused-multi-head-att-fprop doesn't seem to be very clear about this.

cudnnHandle_t handle = getCudnnHandle();
o = at::empty_strided({b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options());
if (return_softmaxstats) {
// TODO(eqy): fix strides
Contributor

I assume this stride fix is only needed for backward support, right? So it's okay for this PR?

Collaborator Author

I think this comment is out-of-date, but I should verify rather than fix the strides here

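As a side note on the empty_strided call quoted above: those strides amount to allocating the output contiguously in (b, s_q, h, d) order and exposing it with a (b, h, s_q, d) logical shape. A small illustration (the helper name is hypothetical, not from the PR):

```cpp
#include <ATen/ATen.h>

// Illustrative helper: equivalent to
// at::empty_strided({b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, opts),
// i.e. a buffer contiguous in (b, s_q, h, d) order viewed with shape (b, h, s_q, d).
at::Tensor alloc_output_bhsd(int64_t b, int64_t h, int64_t s_q, int64_t d,
                             const at::TensorOptions& opts) {
  auto buf = at::empty({b, s_q, h, d}, opts);  // physical "BSHD" layout
  return buf.permute({0, 2, 1, 3});            // logical "BHSD" view
}
```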
@drisspg
Contributor

drisspg commented Jan 9, 2024

One other random question: if I wanted to play around with this early, are there any specific build options I should be using? And any specific version of cuDNN I need to have?

@eqy
Collaborator Author

eqy commented Jan 9, 2024

One other random question: if I wanted to play around with this early, are there any specific build options I should be using? And any specific version of cuDNN I need to have?

I don't think you would need any specific build options---cuDNN >= 8.9 would probably work best. The only other requirement should be cuDNN frontend 1.0, but that tag was updated in the cudnn-frontend submodule a while ago.

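For reference, a tiny runtime guard matching that guidance (illustrative helper, not from the PR); cudnnGetVersion() encodes 8.9.x as 89xx, consistent with the >= 8904 check quoted earlier:

```cpp
#include <cudnn.h>

// Illustrative only: gate the experimental MHA path on cuDNN >= 8.9 at runtime.
inline bool cudnn_version_ok_for_mha() {
  return cudnnGetVersion() >= 8900;
}
```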
@eqy
Collaborator Author

eqy commented Jan 10, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

Successfully rebased cudnnmha3 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnnmha3 && git pull --rebase)

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch pytorch deleted a comment from pytorch-bot bot Feb 14, 2024
@izaitsevfb
Contributor

@pytorchbot merge -i

Landed internally as D53716382

@pytorchmergebot
Collaborator

@Skylion007
Collaborator

We should probably update our cuDNN binaries, as most of the recent changes have been improving this flash attention kernel.

Labels
ciflow/periodic - Trigger jobs run periodically on master (periodic.yml) on the PR
ciflow/trunk - Trigger trunk jobs on your pull request
Merged
open source
Reverted
topic: not user facing - topic category
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants