Mixed-precision with torch.autocast is broken for many models when using attn_implementation="flash_attention_2" #35945

Open
konstantinjdobler opened this issue Jan 28, 2025 · 6 comments

@konstantinjdobler (Contributor) commented Jan 28, 2025

System Info

transformers==4.48.1
python=3.11.11

Who can help?

@ArthurZucker

Expected behavior

Mixed-precision training via torch.autocast is broken for most models inspired by the HF Llama code (which is a lot of models) when using attn_implementation="flash_attention_2", and it is potentially not working as intended in general.

Snippet to reproduce on transformers==4.48.1:

import torch
from transformers import AutoTokenizer, AutoModel
torch.set_default_device("cuda:0")
model_name = "meta-llama/Llama-3.2-3B" # many others, e.g. "allenai/OLMo-2-1124-7B"
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
	# errors with -> RuntimeError: FlashAttention only support fp16 and bf16 data type
	outputs = model(**inputs)

Indeed after calling AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32), we get

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. ...`

but the snippet fails even though we do use torch.autocast as suggested. For mixed-precision training, we do actually want to load the weights in float32.

Concretely, the source is two different issues (a minimal sketch observing both upcasts follows this list):

  • The common Llama-inspired implementation of RMSNorm silently upcasts its output to float32 even within an autocast context; this float32 dtype then propagates from the norms (e.g. q_norm / k_norm) up to the projections passed to the attention function (which is where FA2 fails)
    • The RMSNorm issue has been discussed in these related issues: here and here and here
    • We have this comment in the FlashAttention integration about RMSNorm usually handling silent upcasting correctly but it seems at some point this broke:
      # In PEFT, usually we cast the layer norms in float32 for training stability reasons
      # therefore the input hidden states gets silently casted in float32. Hence, we need
      # cast them back in the correct dtype just to be sure everything works as expected.
      # This might slowdown training & inference so it is recommended to not cast the LayerNorms
      # in fp32. (usually our RMSNorm modules handle it correctly)
  • The cos and sin position embeddings for RoPE are in float32 even within an autocast context, which will again silently upcast the query_states/key_states to float32 before passing to the attention function:
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
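
A minimal, self-contained sketch of both upcasts (this mirrors the Llama-style RMSNorm pattern rather than using the actual HF classes; class and tensor names here are illustrative only):

import torch
import torch.nn as nn

# Minimal RMSNorm mirroring the Llama-style implementation described above (not the HF class itself).
class SketchRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # fp32 weight * bf16 hidden_states promotes the result back to fp32 -> the silent upcast
        return self.weight * hidden_states.to(input_dtype)

norm = SketchRMSNorm(16).cuda()
x = torch.randn(1, 4, 16, device="cuda", dtype=torch.bfloat16)
with torch.autocast("cuda", dtype=torch.bfloat16):
    y = norm(x)
    print(x.dtype, "->", y.dtype)  # torch.bfloat16 -> torch.float32 (issue 1)

    q = torch.nn.functional.linear(y, torch.randn(16, 16, device="cuda"))  # autocast op -> bfloat16
    cos = torch.randn(4, 16, device="cuda")  # stand-in for RoPE cos/sin kept in float32
    print(q.dtype, "->", (q * cos).dtype)    # torch.bfloat16 -> torch.float32 (issue 2)

FA2 then rejects the resulting float32 query/key/value tensors, which is the RuntimeError from the repro above.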

One fix is to remove the silent upcasting of the output to float32 in RMSNorm when the input is bfloat16, and to cast cos and sin directly to torch.get_autocast_dtype when inside an autocast context.
In this discussion it seems that this approach might come with some issues, so a different solution might be needed (I am not quite sure of the exact reasons for the potential issues, though).

It's important to note that through all this silent upcasting, we're probably (I haven't benchmarked though) using a lot of extra memory when doing mixed-precision training (regardless of whether we use attn_implementation="flash_attention_2" or not).
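
A rough way to sanity-check the memory part (just a sketch of one possible measurement, not a proper benchmark: it compares the forward-pass peak memory over the post-loading baseline for float32 vs. bfloat16 weights under the same autocast context):

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "meta-llama/Llama-3.2-3B"
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗 " * 512, return_tensors="pt").to("cuda")

for load_dtype in (torch.float32, torch.bfloat16):
    model = AutoModel.from_pretrained(model_name, torch_dtype=load_dtype).cuda()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        model(**inputs)  # activations are kept for backward (params require grad), as in training
    peak_over_baseline = torch.cuda.max_memory_allocated() - baseline
    print(f"{load_dtype}: forward-pass peak over baseline: {peak_over_baseline / 2**20:.0f} MiB")
    del model
    torch.cuda.empty_cache()

The difference is only a proxy (kernel workspaces also change with dtype), but it should show whether the float32 residual stream and upcast activations dominate the activation memory.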

@konstantinjdobler (Contributor, Author)

import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding


def patch_broken_autocast_llama():
    class FixedLlamaRMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            # fix silent upcasting of output (note that we keep the inner .to() to *only* cast the output and no inner computation)
            return (self.weight * hidden_states.to(input_dtype)).to(input_dtype)
            # return self.weight * hidden_states.to(input_dtype)

    transformers.models.llama.modeling_llama.LlamaRMSNorm = FixedLlamaRMSNorm

    class FixedLlamaRotaryEmbedding(LlamaRotaryEmbedding):
        @torch.no_grad()
        def forward(self, x, position_ids):
            if "dynamic" in self.rope_type:
                self._dynamic_frequency_update(position_ids, device=x.device)

            # Core RoPE block
            inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
            position_ids_expanded = position_ids[:, None, :].float()
            # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
            device_type = x.device.type
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos()
                sin = emb.sin()

            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
            cos = cos * self.attention_scaling
            sin = sin * self.attention_scaling

            # fix silent upcasting of output
            output_dtype = x.dtype
            if torch.is_autocast_enabled():
                output_dtype = torch.get_autocast_dtype(device_type=device_type)
            return cos.to(dtype=output_dtype), sin.to(dtype=output_dtype)
            # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding

The snippet above applies a patch for this problem for Llama (it applies 1-to-1 to many other models); running it before the error repro snippet resolves the RuntimeError. Happy to submit a PR, but this definitely needs some discussion on what you think is the correct way to solve it, as I know these precision issues have been a problem in the past as well.
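
For completeness, a usage sketch (assuming the patch function above is in scope; since it swaps the classes on the modeling module, it has to run before the model is instantiated):

import torch
from transformers import AutoModel

patch_broken_autocast_llama()  # must run before from_pretrained so the patched classes are used
model = AutoModel.from_pretrained(
    "meta-llama/Llama-3.2-3B", attn_implementation="flash_attention_2", torch_dtype=torch.float32
)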

@konstantinjdobler (Contributor, Author)

As a side note, is it intended that we always keep the residual stream in float32 even in mixed-precision training? I am asking because of these lines:

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

While the computations in self.self_attn and self.mlp are autocast and their outputs are bfloat16, adding them back onto the float32 hidden states silently upcasts the result to float32. The residual stream is therefore never autocast, since the initial hidden_states at layer 0 are still float32, as they (usually) come from an nn.Embedding (which is not autocast by torch.autocast).
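
A tiny standalone sketch of that promotion behavior (not the actual model code, just the dtype rules involved):

import torch
import torch.nn as nn

emb = nn.Embedding(10, 8, device="cuda")  # embedding weights stay float32; torch.autocast does not touch embedding lookups
with torch.autocast("cuda", dtype=torch.bfloat16):
    hidden_states = emb(torch.tensor([1, 2, 3], device="cuda"))
    print(hidden_states.dtype)  # torch.float32
    attn_out = hidden_states.to(torch.bfloat16)  # stand-in for the autocast self_attn / mlp output
    print((hidden_states + attn_out).dtype)  # torch.float32 -- the residual add promotes back to fp32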

I can also open a separate issue if you prefer.

@Rocketknight1 (Member)

gentle ping @ArthurZucker @Cyrilvallez

@ArthurZucker (Collaborator)

Hey! This is not an issue:

In [1]: import torch
   ...: from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
   ...: torch.set_default_device("cuda:0")
   ...: model_name = "meta-llama/Llama-3.2-3B" # many others, e.g. "allenai/OLMo-2-1124-7B"
   ...: inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
   ...: model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)
   ...:
   ...: with torch.autocast("cuda", dtype=torch.bfloat16):
   ...:     # errors with -> RuntimeError: FlashAttention only support fp16 and bf16 data type
   ...:     outputs = model(**inputs)
   ...:
2025-02-13 16:35:49.063903: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-13 16:35:49.063954: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-13 16:35:49.065064: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-13 16:35:49.071083: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-02-13 16:35:49.866410: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:16<00:00,  8.50s/it]
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.

This runs as expected, though you do get a warning! It says that we cast to float32, then cast back.

@konstantinjdobler (Contributor, Author)

Hey @ArthurZucker, that's interesting -- I just tried the snippet again in a fresh environment using the latest versions of torch, transformers and flash-attn and still reproduce the error. This is my environment:

- `transformers` version: 4.48.3
- Platform: Linux-5.15.0-1042-nvidia-x86_64-with-glibc2.35
- Python version: 3.11.11
- Huggingface_hub version: 0.28.1
- Safetensors version: 0.5.2
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- GPU type: NVIDIA H100 80GB HBM3
- flash-attn version: 2.7.4.post1

@konstantinjdobler (Contributor, Author) commented Feb 14, 2025

@ArthurZucker ah, this is fixed in the nightly build of transformers on GitHub -- after pip install git+https://github.com/huggingface/transformers.git the error is gone! 🎉

There is still the issue (although it doesn't lead to an error) of the residual stream being kept in float32, which leads to increased memory usage (see my previous comment). If we inspect the dtypes of the hidden states, they are all still float32:

import torch
from transformers import AutoModel, AutoTokenizer
torch.set_default_device("cuda:0")
model_name = "meta-llama/Llama-3.2-3B"
model = AutoModel.from_pretrained(
    model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32, device_map="cuda"
)
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")

with torch.autocast("cuda", dtype=torch.bfloat16):
    outputs = model(**inputs, output_hidden_states=True)

for i, hidden_state in enumerate(outputs.hidden_states):
    print(f"Hidden state {i} dtype: {hidden_state.dtype}")
