torch 2.5 CuDNN backend for SDPA NaN error #9768
Comments
I think this is a bug in PyTorch. Until they fix it, we should use PyTorch <2.5.0.
Related?
Yes, or disable the CuDNN attention backend for SDPA.
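For reference, a minimal sketch (not from this thread) of disabling the cuDNN SDPA backend globally; `torch.backends.cuda.enable_cudnn_sdp` is available in PyTorch >= 2.4:

```python
import torch

# Turn off the cuDNN SDPA backend globally, so SDPA falls back to the
# flash or memory-efficient kernels.
torch.backends.cuda.enable_cudnn_sdp(False)

# The other fused backends remain enabled:
print(torch.backends.cuda.flash_sdp_enabled())          # True
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True
```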
Yes, let's look forward to the PyTorch team fixing this issue in future versions of PyTorch.
I tried PyTorch 2.4.1; it also has this bug.
Thanks so much for the detailed issue, once again. We have pinned this (Line 133 in c754318).
However, I think this is likely fixed with Torch 2.5.1, as the cuDNN backend isn't selected as the SDPA backend by default. Could you give this a try?
Yes, Torch 2.5.1 gives the CuDNN backend the lowest precedence in the backend list. I have tested Torch 2.5.1, and it no longer produces the NaN gradients mentioned above.
Describe the bug
When using the recently released PyTorch 2.5, the default SDPA backend is CUDNN_ATTENTION. In the example CogVideoX LoRA training script, NaN gradients appear right at the first step. However, using other SDPA backends, such as FLASH_ATTENTION or EFFICIENT_ATTENTION, does not lead to NaN issues.
After some preliminary investigation, I found that this might be related to the transpose and reshape operations following the SDPA computation (see L1954).
Some related issues and PRs:
pytorch/pytorch#134001
pytorch/pytorch#134031
pytorch/pytorch#138354
Furthermore, I discovered that other attention processors in attention_processor.py also utilize the same transpose and reshape operations, such as FluxAttnProcessor2_0, which could potentially lead to similar problems.
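For context, a rough sketch of the shared pattern (shapes are illustrative; in the diffusers processors the head count comes from `attn.heads` rather than a local variable):

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 16, 64
query = key = value = torch.randn(batch, heads, seq, head_dim)

# SDPA returns (batch, heads, seq, head_dim); the processors then move the
# sequence dim forward and flatten heads * head_dim into one channel dim.
hidden_states = F.scaled_dot_product_attention(query, key, value)
hidden_states = hidden_states.transpose(1, 2).reshape(batch, -1, heads * head_dim)
```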
Reproduction
This issue can be reproduced by setting a breakpoint after the gradient backward pass and then printing the gradients.
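A minimal sketch of such a check (the original snippet isn't reproduced here; `model` and `loss` are hypothetical stand-ins for the training script's transformer and loss):

```python
import torch

# `model` and `loss` are stand-ins for the objects in the training script.
loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None and torch.isnan(param.grad).any():
        print(f"NaN gradient in: {name}")
```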
Change the default backend for SDPA to FLASH_ATTENTION or EFFICIENT_ATTENTION in attention_processor.py, and the NaN issue will not occur.
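Alternatively, the backend can be pinned at the call site without editing the processor, using PyTorch's `sdpa_kernel` context manager (a sketch; tensor shapes and dtypes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

query = key = value = torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.bfloat16)

# Restrict SDPA to the flash-attention kernel for this call only.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value)
```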
Considering that PyTorch 2.5 is currently the default version available for installation, this issue may require some attention.
Logs
No response
System Info
NVIDIA H20, 97871 MiB
NVIDIA H20, 97871 MiB
NVIDIA H20, 97871 MiB
NVIDIA H20, 97871 MiB
NVIDIA H20, 97871 MiB
NVIDIA H20, 97871 MiB
NVIDIA H20, 97871 MiB
Who can help?
@DN6 @a-r-r-o-w @yiyixuxu @sayakpaul