Llama inference instability in fp16 producing inf in the middle of the model #27179
Comments
Error raised: probability tensor contains either inf, nan or element < 0 (see AutoGPTQ/AutoGPTQ#295).
Related to #17937, but that one uses a dummy model. Will take a look here.
@ArthurZucker has been tracking it, and has a draft PR for it: #27114. @fxmarty Can you check if applying this change fixes it?
@ydshieh @gante Thank you! No, this PR is unrelated unfortunately, as the issue also happens when the prompt does not contain any inf. It may just be instability in the model, but it feels weird that it arises only when some attention mask rows are fully masked.
Well, I guess it needs another deep dive 😬
I haven't been able to give a final conclusion, but if we break the MLP forward down as

    h1 = self.gate_proj(x)
    h2 = self.act_fn(h1)
    h3 = self.up_proj(x)
    h4 = self.down_proj(h2 * h3)
    down_proj = h4

and print their maximal absolute values, we will see their magnitude get unusually larger than before from layer 29 (0-based), and get amplified further from there. The question is what happens in that layer. Will take a further look later when I get spare time.
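A minimal sketch of how such per-layer maxima can be inspected with forward hooks; the checkpoint name and prompt below are illustrative assumptions, not details from the thread:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

def make_hook(name):
    def hook(module, inputs, output):
        # Report the largest absolute activation this projection produced.
        print(f"{name}: max abs = {output.abs().max().item():.1f}, "
              f"finite = {torch.isfinite(output).all().item()}")
    return hook

# Watch the MLP projections of every decoder layer.
for i, layer in enumerate(model.model.layers):
    layer.mlp.gate_proj.register_forward_hook(make_hook(f"layer {i} gate_proj"))
    layer.mlp.up_proj.register_forward_hook(make_hook(f"layer {i} up_proj"))
    layer.mlp.down_proj.register_forward_hook(make_hook(f"layer {i} down_proj"))

inputs = tokenizer(["Hello"], return_tensors="pt").to(model.device)
with torch.no_grad():
    model(**inputs)
```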
After taking a further look, this doesn't seem to relate to any bug but just to the limitation of using fp16, and it also depends on the input data. One observation I found is: larger tensor values tend to appear when the prompt is (very) short. Also, when this happens, I often see that many elements in the corresponding multiplications have values with the same sign. Nothing more I can provide, I am afraid.
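As a small generic illustration of the fp16 headroom involved (toy values, not data from the model):

```python
import torch

print(torch.finfo(torch.float16).max)  # 65504.0: anything above this overflows to inf

# A single product of moderately large values already overflows in fp16,
# while the float32 result is perfectly representable.
a = torch.tensor([300.0], dtype=torch.float16)
b = torch.tensor([300.0], dtype=torch.float16)
print(a * b)                   # tensor([inf], dtype=torch.float16)
print(a.float() * b.float())   # tensor([90000.])

# Same-signed values accumulate without cancellation and hit the limit quickly.
x = torch.full((64,), 1024.0, dtype=torch.float16)
print(x.sum(dtype=torch.float16))  # tensor(inf, dtype=torch.float16)
```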
Thanks a lot @ydshieh. Did you notice any difference depending on whether rows are fully masked in the attention mask or not? We can probably close this one - at least it is good to know that (at least) llama 7b has numerical instabilities during inference in fp16.
Oh, I might have made a mistake! You have
FYI: here the issue is not even in the generation - it already appears in the first step, when just encoding the input prompt.
Same issue in layer 29/30 in AutoGPTQ/AutoGPTQ#412. Unmasking fully masked padding rows solves the issue there as well, and the nans indeed start to appear at the padding index if we do not unmask.

In layer 30 without unmasking:

In layer 30 with fully masked rows unmasked:

It is unclear to me what is happening here and how it relates to fully masked rows.
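To make the unmasking concrete, here is a small self-contained sketch; the toy shapes and the way the additive mask is built below are assumptions for illustration, not the transformers internals:

```python
import torch

# Toy example: batch of 1, two left-padding tokens, sequence length 4.
attention_mask = torch.tensor([[0, 0, 1, 1]])  # 1 = real token, 0 = padding

seq_len = attention_mask.shape[-1]
min_value = torch.finfo(torch.float16).min

# Boolean (batch, 1, q_len, kv_len) mask: True = position is masked out.
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
padding = attention_mask[:, None, None, :] == 0
masked = causal[None, None] | padding

# Additive fp16 mask of the kind used inside the model.
expanded = torch.zeros(masked.shape, dtype=torch.float16).masked_fill(masked, min_value)

# Query rows that correspond to padding tokens attend to nothing (fully masked).
fully_masked_rows = masked.all(dim=-1, keepdim=True)

# "Unmasking" such rows: let those queries attend everywhere instead; their
# outputs belong to padding positions and are discarded downstream anyway.
expanded_unmasked = expanded.masked_fill(fully_masked_rows, 0.0)
print(expanded)
print(expanded_unmasked)
```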
Great details! I am wondering whether the original training only ever saw unmasked rows, while at inference time the model now sees a different version, which leads to these large values (similar to the different behavior of SDPA between torch 2.0.1 / 2.1.0 on GPU that we saw previously).
@ydshieh At some point I want to give the original llama repo a try, to see how padding is handled there.
not stale

mark
I think computing RoPE in float32 precision should partly fix this.
I'll mark this as closed, because llama now computes RoPE in float32! 🥳 Feel free to ping me if you feel this should not be closed.
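For illustration of the idea only (the function name and shapes below are assumptions, not the transformers implementation): the rotary embedding tables can be built in float32 and only cast back to the working dtype at the end, so the angles never go through fp16.

```python
import torch

def rope_cos_sin(position_ids, dim, base=10000.0, dtype=torch.float16):
    # Compute inverse frequencies and angles in float32 to avoid fp16
    # precision loss, then cast the cos/sin tables back to the model dtype.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = position_ids.to(torch.float32)[:, :, None] * inv_freq[None, None, :]
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos, sin = rope_cos_sin(torch.arange(8)[None], dim=128)
print(cos.dtype, cos.shape)  # torch.float16 torch.Size([1, 8, 128])
```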
You can try updating optimum to the latest version to solve this.
System Info

transformers version: 4.35.0.dev0

Who can help?

@ydshieh @fxmarty @gante

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Hi, I encounter inference instability with llama running in fp16 when left padding is used, and especially when full rows are masked out in the 4D attention mask.
At some point in the forward, inf values may appear in the intermediate logits, ultimately leading to tensors filled with nan and raising the error: probability tensor contains either inf, nan or element < 0. Note that the inf values specifically appear at a padding position.

Reproduction:
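A minimal sketch of the setup described above (fp16 llama with a short, left-padded prompt); the checkpoint name, prompt, and padding length are illustrative assumptions rather than the original script:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# A short prompt padded on the left, so the expanded 4D mask has fully masked rows.
inputs = tokenizer(
    ["Hello my name is"], padding="max_length", max_length=8, return_tensors="pt"
).to("cuda")

with torch.no_grad():
    out = model(**inputs)

print(torch.isfinite(out.logits).all())
```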
Printing torch.all(torch.isfinite()) at some points in the model, it appears the inf values start to appear in the MLP at self.act_fn(self.gate_proj(x)) * self.up_proj(x), and things go crazy from there.

What's interesting is that, for example, fixing the attention mask (two left padding tokens) so that the fully masked rows are unmasked solves the issue.
It makes me think that the solution implemented for SDPA to avoid fully masked rows in the attention mask may actually be required in some other cases, such as this one (#26572) - but it is unclear why it relates to overflow here.
WDYT @gante @ydshieh? Is this something you have ever observed?
Expected behavior

No inf spawning in the middle of inference with an fp16 model.