
torch 2.5 CuDNN backend for SDPA NaN error #9768

Closed
wtyuan96 opened this issue Oct 25, 2024 · 7 comments
Labels
bug Something isn't working

@wtyuan96

Describe the bug

With the recently released PyTorch 2.5, the default SDPA backend is CUDNN_ATTENTION. When running the CogVideoX LoRA training script from the examples, NaN gradients appear at the very first step. Other SDPA backends, such as FLASH_ATTENTION or EFFICIENT_ATTENTION, do not produce NaNs.

After some preliminary investigation, I suspect this is related to the transpose and reshape operations that follow the SDPA call (see L1954):

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
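To make the shapes in that snippet concrete, here is a minimal CPU sketch with arbitrary sizes; `batch_size`, `heads`, `seq_len`, and `head_dim` are illustrative stand-ins (the real code uses `attn.heads`). It shows that the transpose/reshape merges the heads into one feature dimension, and prints the output strides, which can differ between SDPA backends. Whether that layout difference is what triggers the cuDNN NaNs is only the hypothesis stated above, not something this example proves.

```python
import torch
import torch.nn.functional as F

# Arbitrary illustrative sizes (not taken from CogVideoX).
batch_size, heads, seq_len, head_dim = 2, 4, 8, 16
query, key, value = (torch.randn(batch_size, heads, seq_len, head_dim) for _ in range(3))

# SDPA returns a (batch, heads, seq_len, head_dim) tensor, same shape as `query`.
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# The memory layout (strides) of the output can vary with the backend that
# produced it, so the transpose/reshape below may receive differently laid-out memory.
print(hidden_states.shape, hidden_states.stride())

# Move heads next to head_dim, then merge them into a single feature dimension.
merged = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(merged.shape)  # torch.Size([2, 8, 64])
```

After the reshape, features `h * head_dim` through `(h + 1) * head_dim - 1` of each position hold head `h`'s output.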

Some related issues and PRs:
pytorch/pytorch#134001
pytorch/pytorch#134031
pytorch/pytorch#138354

Other attention processors in attention_processor.py, such as FluxAttnProcessor2_0, apply the same transpose and reshape after SDPA, so they could hit the same problem.

Reproduction

The issue can be reproduced by setting a breakpoint after the backward pass and printing the gradients:

loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
loss = loss.mean()
accelerator.backward(loss)
print([[name, param.grad] for name, param in transformer.named_parameters() if param.requires_grad])
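Printing every gradient tensor is very verbose for a model of this size. A small helper, hypothetically named `nan_grad_params` (it is not part of diffusers or PyTorch), can list only the trainable parameters whose gradients contain NaN. This is a sketch; the toy `nn.Linear` stands in for `transformer`.

```python
import torch
from torch import nn


def nan_grad_params(model: nn.Module) -> list[str]:
    """Return names of trainable parameters whose gradient contains a NaN."""
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is not None and torch.isnan(param.grad).any()
    ]


# Toy model standing in for `transformer`; a normal backward gives finite grads.
model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
print(nan_grad_params(model))  # []

# Simulate a corrupted gradient to show what a NaN report looks like.
with torch.no_grad():
    model.weight.grad[0, 0] = float("nan")
print(nan_grad_params(model))  # ['weight']
```

In the training script, calling `nan_grad_params(transformer)` right after `accelerator.backward(loss)` would give the same information as the print above in a much shorter form.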

Forcing SDPA onto the FLASH_ATTENTION or EFFICIENT_ATTENTION backend in attention_processor.py makes the NaN gradients go away:

from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.FLASH_ATTENTION): # or EFFICIENT_ATTENTION
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

Since PyTorch 2.5 is currently the version installed by default, this issue probably deserves attention.

Logs

No response

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.10.134-010
  • Running on Google Colab?: No
  • Python version: 3.10.15
  • PyTorch version (GPU?): 2.5.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.1
  • Transformers version: 4.46.0
  • Accelerate version: 1.0.1
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: 8x NVIDIA H20, 97871 MiB each
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @a-r-r-o-w @yiyixuxu @sayakpaul

@wtyuan96 wtyuan96 added the bug Something isn't working label Oct 25, 2024
@wangyanhui666

I think this is a PyTorch bug, and they are working on a fix:
pytorch/pytorch#138354 (comment)

Until it is fixed, we should use PyTorch <2.5.0.

@sayakpaul
Member

Related?
#9704

@wtyuan96
Author

> I think this is a bug of pytorch. they are working on this to fix this bug. pytorch/pytorch#138354 (comment)
>
> Before they fix we should use pytorch <2.5.0

Yes, or disable the cuDNN attention backend for SDPA.

@wtyuan96
Author

> Related? #9704

Yes. Let's hope the PyTorch team fixes this in a future release.

@wangyanhui666

wangyanhui666 commented Oct 26, 2024

I tried PyTorch 2.4.1 and it also has this bug, so disabling the cuDNN attention backend in the training code may be a good solution.

@sayakpaul
Member

sayakpaul commented Nov 1, 2024

Thanks so much for the detailed issue, once again.

We have pinned the torch version:

"torch>=1.4,<2.5.0",

However, I think this is likely fixed in torch 2.5.1, since the cuDNN backend is no longer selected as the default SDPA backend there. Could you give it a try?

@wtyuan96
Author

wtyuan96 commented Nov 2, 2024

> Thanks so much for the detailed issue, once again.
>
> We have pinned torch version:
>
> "torch>=1.4,<2.5.0",
>
> However, I think this is likely fixed with Torch 2.5.1 as cuDNN backend isn't selected as the SDPA backend by default. Could you give this a try?

Yes, torch 2.5.1 gives the cuDNN backend the lowest priority in the backend list. I tested torch 2.5.1, and the NaN gradients described above no longer occur.

@wtyuan96 wtyuan96 closed this as completed Nov 2, 2024