Mixed-precision with torch.autocast is broken for many models when using attn_implementation="flash_attention_2" #35945

Comments
Snippet applying a patch solution to this problem for Llama:

import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding


def patch_broken_autocast_llama():
    class FixedLlamaRMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            # fix silent upcasting of output (note that we keep the inner .to() to *only* cast the output
            # and no inner computation)
            return (self.weight * hidden_states.to(input_dtype)).to(input_dtype)
            # return self.weight * hidden_states.to(input_dtype)

    transformers.models.llama.modeling_llama.LlamaRMSNorm = FixedLlamaRMSNorm

    class FixedLlamaRotaryEmbedding(LlamaRotaryEmbedding):
        @torch.no_grad()
        def forward(self, x, position_ids):
            if "dynamic" in self.rope_type:
                self._dynamic_frequency_update(position_ids, device=x.device)

            # Core RoPE block
            inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
            position_ids_expanded = position_ids[:, None, :].float()
            # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
            device_type = x.device.type
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos()
                sin = emb.sin()

            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
            cos = cos * self.attention_scaling
            sin = sin * self.attention_scaling

            # fix silent upcasting of output
            output_dtype = x.dtype
            if torch.is_autocast_enabled():
                output_dtype = torch.get_autocast_dtype(device_type=device_type)
            return cos.to(dtype=output_dtype), sin.to(dtype=output_dtype)
            # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding
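For reference, a minimal usage sketch of the patch above (a sketch under assumptions: the model name and autocast dtype are simply the ones used elsewhere in this thread, and the patch has to run before the model is instantiated so that the patched classes are the ones constructed; whether this resolves the FA2 error is the claim of the patch, not something verified here):

# Hypothetical usage: apply the monkey-patch before loading the model.
import torch
from transformers import AutoModel, AutoTokenizer

patch_broken_autocast_llama()

model_name = "meta-llama/Llama-3.2-3B"  # example model from this thread
model = AutoModel.from_pretrained(
    model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32
).to("cuda")
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt").to("cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    outputs = model(**inputs)  # if the patch works as intended, FA2 now receives bfloat16 projections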
As a sidenote, is it intended that we always keep the residual stream in float32 (transformers/src/transformers/models/llama/modeling_llama.py, lines 331 to 353 in ec7afad) while the computation inside the blocks runs in the autocast dtype? Can also open a separate issue if you prefer.
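To make the sidenote concrete, a tiny standalone sketch (not transformers code; shapes and dtypes are made up) of the type promotion that keeps the residual stream in float32 under autocast:

import torch

# The residual holds the float32 hidden states, while the block output under
# autocast(bfloat16) comes back as bfloat16; their sum is promoted back to float32,
# so the residual stream never actually drops to bfloat16.
residual = torch.randn(2, 8, dtype=torch.float32)
block_output = torch.randn(2, 8, dtype=torch.bfloat16)

hidden_states = residual + block_output
print(hidden_states.dtype)  # torch.float32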
gentle ping @ArthurZucker @Cyrilvallez
Hey! This is not an issue:

In [1]: import torch
   ...: from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
   ...: torch.set_default_device("cuda:0")
   ...: model_name = "meta-llama/Llama-3.2-3B"  # many others, e.g. "allenai/OLMo-2-1124-7B"
   ...: inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
   ...: model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)
   ...:
   ...: with torch.autocast("cuda", dtype=torch.bfloat16):
   ...:     # errors with -> RuntimeError: FlashAttention only support fp16 and bf16 data type
   ...:     outputs = model(**inputs)
   ...:
2025-02-13 16:35:49.063903: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-13 16:35:49.063954: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-13 16:35:49.065064: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-13 16:35:49.071083: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-02-13 16:35:49.866410: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:16<00:00, 8.50s/it]
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.

Runs as expected, though you do have a warning! It says that we cast to float32, then cast back.
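One way to verify the cast/cast-back behavior the warning talks about is to look at the dtypes that actually reach the attention projections. A small illustrative sketch (not part of the original report), reusing the model and inputs from the session above and assuming the Llama module naming (q_proj):

# Print the input/output dtype of every q_proj while running under autocast.
hooks = []
for name, module in model.named_modules():
    if name.endswith("q_proj"):
        hooks.append(
            module.register_forward_hook(
                lambda mod, inp, out, n=name: print(f"{n}: in={inp[0].dtype} out={out.dtype}")
            )
        )

with torch.autocast("cuda", dtype=torch.bfloat16):
    model(**inputs)

for hook in hooks:
    hook.remove()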
Hey @ArthurZucker, that's interesting -- I just tried the snippet again in a fresh environment using the latest versions.
@ArthurZucker ah, this is fixed in the nightly build.

There is still the issue (although it doesn't lead to an error) of the residual stream still being in float32:

import torch
from transformers import AutoModel, AutoTokenizer

torch.set_default_device("cuda:0")
model_name = "meta-llama/Llama-3.2-3B"
model = AutoModel.from_pretrained(
    model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32, device_map="cuda"
)
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")

with torch.autocast("cuda", dtype=torch.bfloat16):
    outputs = model(**inputs, output_hidden_states=True)

for i, hidden_state in enumerate(outputs.hidden_states):
    print(f"Hidden state {i} dtype: {hidden_state.dtype}")
System Info
transformers==4.48.1
python=3.11.11
Who can help?
@ArthurZucker
Expected behavior
Mixed-precision training via torch.autocast is broken for most models inspired by the HF Llama code (which is a lot of models) when using attn_implementation="flash_attention_2", and potentially not working as intended in general.

Snippet to reproduce on transformers==4.48.1:
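A plausible reconstruction of the reproduction snippet, based on the IPython session quoted earlier in the thread (model name, inputs, and autocast dtype are assumed from that session):

# Reconstructed reproduction snippet (assumed to match the session quoted earlier in the thread).
import torch
from transformers import AutoModel, AutoTokenizer

torch.set_default_device("cuda:0")
model_name = "meta-llama/Llama-3.2-3B"  # many others, e.g. "allenai/OLMo-2-1124-7B"
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # on transformers==4.48.1 this errors with: RuntimeError: FlashAttention only support fp16 and bf16 data type
    outputs = model(**inputs)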
Indeed, after calling AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32) we get the warning that Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 and that we should run under torch.autocast, but the snippet fails even though we do use torch.autocast as suggested. For mixed-precision training, we do actually want to load the weights in float32.

Concretely, the source is two different issues:
1. The RMSNorm output is silently upcast to float32 even within an autocast context, which is propagated from the q_norm/v_norm up until passing the projections to the attention function (FA2 fails here). Silent upcasting from RMSNorm used to be handled correctly, but it seems at some point this broke: transformers/src/transformers/integrations/flash_attention.py, lines 32 to 36 in ec7afad.
2. The cos and sin position embeddings for RoPE are in float32 even within an autocast context, which will again silently upcast the query_states/key_states to float32 before passing to the attention function: transformers/src/transformers/models/llama/modeling_llama.py, line 275 in ec7afad.
The reason they are in float32 is because cos and sin are created in e.g. LlamaModel (transformers/src/transformers/models/llama/modeling_llama.py, line 571 in ec7afad), where the hidden_states come from the nn.Embedding, which is never autocasted by torch.autocast. So transformers/src/transformers/models/llama/modeling_llama.py, line 141 in ec7afad does not work as intended (?) because the input at that point has not been autocasted yet.
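A minimal sketch of the two issues above, using the Llama building blocks directly (a sketch under assumptions: the config values are made up so nothing needs to be downloaded, a CUDA device is assumed, and the dtype behavior shown is the one described in this issue for transformers==4.48.1):

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding

# Tiny made-up config, purely for illustration.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4)

norm = LlamaRMSNorm(config.hidden_size).cuda()            # weight stays float32, as with torch_dtype=torch.float32
rope = LlamaRotaryEmbedding(config=config, device="cuda")

position_ids = torch.arange(6, device="cuda")[None]
embeddings = torch.randn(1, 6, config.hidden_size, device="cuda")                      # float32, like the nn.Embedding output
proj_out = torch.randn(1, 6, config.hidden_size, device="cuda", dtype=torch.bfloat16)  # bfloat16, like a projection output under autocast

with torch.autocast("cuda", dtype=torch.bfloat16):
    # Issue 1: the float32 RMSNorm weight silently upcasts a bfloat16 input back to float32.
    print(norm(proj_out).dtype)        # torch.float32

    # Issue 2: cos/sin follow the dtype of the (never-autocasted) float32 input ...
    cos, sin = rope(embeddings, position_ids)
    print(cos.dtype, sin.dtype)        # torch.float32 torch.float32

    # ... so multiplying bfloat16 query/key states by them upcasts those to float32 as well.
    print((proj_out * cos[..., :1]).dtype)  # torch.float32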
One fix is to remove the silent upcasting of the output to float32 in RMSNorm if the input is bfloat16, and to directly cast cos and sin to the torch.get_autocast_dtype if in autocast. In this discussion it seems that this might come with some issues, so there might have to be some different solution (I am not quite sure of the exact reasons for potential issues, though).
It's important to note that through all this silent upcasting, we're probably (I haven't benchmarked though) using a lot of extra memory when doing mixed-precision training (regardless of whether we use attn_implementation="flash_attention_2" or not).
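To check the memory point, one rough probe could look like the following (an illustrative sketch only: it assumes a CUDA device and a model/inputs pair like the ones in the snippets above, and it only measures a single forward pass, not training):

import torch

def peak_forward_memory_mb(model, inputs, autocast_dtype=None):
    # Rough peak-memory probe for one forward pass; not a rigorous benchmark.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        if autocast_dtype is not None:
            with torch.autocast("cuda", dtype=autocast_dtype):
                model(**inputs)
        else:
            model(**inputs)
    return torch.cuda.max_memory_allocated() / 2**20

# e.g. compare a plain float32 forward against an autocast(bfloat16) forward:
# print(peak_forward_memory_mb(model, inputs))
# print(peak_forward_memory_mb(model, inputs, autocast_dtype=torch.bfloat16))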