Fix loading broken LoRAs that could give NaN #5316
Conversation
Makes sense, will add it in #5151 together with a test!
(Review comment on the stray file `C`, marked outdated:)
I think this file should not be there 👀
Thanks!
    if safe_fusing and torch.isnan(fused_weight).any().item():
        raise ValueError(
            "This LoRA weight seems to be broken. "
            f"Encountered NaN values when trying to fuse LoRA weights for {self}."
            "LoRA weights will not be fused."
        )
Crazy, honestly.
    @@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
                LoRA parameters then it won't have any effect.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
Why are we defaulting to False?
I believe it's because the torch.isnan(fused_weight).any().item() check adds overhead to the fusion operation. For a large tensor this might cause a considerable slowdown (one would need to benchmark it though), so to be on the safe side I would also default it to False.
Makes sense! Thanks for explaining!
Works for me, and I agree that tackling it using try/except is better. Even when there is no try/except, we throw a sensible error message, which makes sense to me.
FYI, equivalent PR in PEFT: huggingface/peft#1001
    @@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
                w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

            fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
Question: do you know whether the NaN usually happens in

    w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

or in

    (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

cc @BenjaminBossan - if it happens in the second case, we could remove the copy in the PEFT PR.
I don't really know. I have not tested it, but my intuition is that checking for NaN values can be quite expensive anyway when on GPU. So no matter what, we have a time overhead and can't make safe_fusing the default.
OK sounds great!
I have some of the same comments here as in huggingface/peft#1001 (GH won't let me link to the comment directly). I'll paraphrase:
The error message "This LoRA weight seems to be broken" is a bit confusing IMO, because (if I'm not mistaken) it is entirely possible for the LoRA weight to work fine in forward when applied separately from the original weights, and to only run into trouble after fusing, since the mathematical operation is not identical. E.g. two weight parameters could overflow when added directly, but when they are both first multiplied by the activation and only then added, they might not overflow anymore.
Therefore, I wouldn't say "broken", as it may work without fusing. Instead, I would change the message to just say that this adapter cannot be fused safely. WDYT?
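A tiny toy example (my own illustration, not from the PR thread) makes this concrete: in fp16, adding the base weight and the LoRA delta before the small activation is applied can overflow, while the unfused computation stays finite.

    import torch

    # base weight entry and LoRA delta entry, both large but representable in fp16
    w  = torch.tensor([40000.0], dtype=torch.float16)
    ba = torch.tensor([40000.0], dtype=torch.float16)
    x  = torch.tensor([1e-3],    dtype=torch.float16)  # small activation

    fused   = x * (w + ba)    # w + ba exceeds the fp16 max (~65504) -> inf
    unfused = x * w + x * ba  # each term is ~40, the sum ~80 -> finite

    print(fused, unfused)     # fused is inf, unfused is finite (~80)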
The next issue is that I think we should not check with torch.isnan, because it doesn't detect torch.inf. Instead, torch.isfinite(x).all() should work for both torch.inf and torch.nan. WDYT?
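For reference, a quick illustration (my own, not from the thread) of the difference between the two checks:

    import torch

    t = torch.tensor([1.0, float("inf"), float("nan")])

    print(torch.isnan(t))               # tensor([False, False,  True])  -> misses the inf
    print(torch.isfinite(t))            # tensor([ True, False, False])  -> flags inf and nan
    print(not torch.isfinite(t).all())  # True, usable as a single "unsafe to fuse" flag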
Hmm, good comments!

1.) w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) can only yield NaN if the weights involved already contain NaN values. => So I think the only reason/explanation here is in fact that the LoRA weights are broken. Wdyt?

2.) It's probably a good idea to also handle the case where weights yield inf values. In this case I think there are two options:

Edit: torch.bmm can yield NaN values if inputs are "inf" because it doesn't use standard matrix multiplication. Overall, I would not advise the user to run it in "non-fused" mode as I think it's highly unlikely to yield reasonable results. But it is a possibility, so having a precise error message would be great here.
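As a side note on the edit above, a minimal demonstration (my own, using nothing beyond plain torch.bmm) that inf inputs can indeed come out as NaN, because opposite-signed infinite partial products cancel during accumulation:

    import torch

    up   = torch.tensor([[[float("inf"), float("inf")]]])  # shape (1, 1, 2)
    down = torch.tensor([[[1.0], [-1.0]]])                 # shape (1, 2, 1)

    # inf * 1 + inf * (-1) accumulates to inf + (-inf) = nan
    print(torch.bmm(up, down))  # tensor([[[nan]]])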
I agree that infs would be more surprising to find (not sure if weights can be fp16 here, then it might happen more often). The main reason for suggesting the isfinite check is that it covers infs as well; however, it does appear to be noticeably slower than isnan:

    >>> import torch
    >>> x = torch.rand((1000, 1000)).to('cuda')
    >>> %time _ = torch.isfinite(x)
    CPU times: user 14.9 ms, sys: 4.45 ms, total: 19.4 ms
    Wall time: 28.6 ms
    >>> %time _ = torch.isnan(x)
    CPU times: user 216 µs, sys: 33 µs, total: 249 µs
    Wall time: 2.01 ms

Not sure if the tradeoff is worth it.
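Since CUDA kernels run asynchronously, %time on its own can be misleading; a rough micro-benchmark sketch (my own, requires a CUDA device) of how one might time the two checks with explicit synchronization:

    import time
    import torch

    x = torch.rand((1000, 1000), device="cuda")

    for fn in (torch.isnan, torch.isfinite):
        fn(x)                      # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(100):
            fn(x).any()
        torch.cuda.synchronize()   # wait for all kernels before reading the clock
        print(fn.__name__, (time.perf_counter() - start) / 100, "s per call")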
Actually I played around with it a bit more. In my experiments,

    w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

can only yield NaN if any of the values involved are already NaN. I also tried overflowing a LoRA weight on purpose, e.g.

    sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += torch.finfo(torch.float32).max

but that failure mode is an overflow rather than genuinely broken weights. => So maybe it could make sense to make a difference between "broken" and "potentially broken / overflowing".
    # corrupt one LoRA weight with `inf` values
    with torch.no_grad():
        sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
            "inf"
        )
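For orientation, a sketch of how such a test might look end to end (the sd_pipe fixture is a placeholder for a LoRA-loaded pipeline, and the module path is taken from the snippet above; this is not the merged test):

    import pytest
    import torch

    def test_fuse_lora_raises_on_broken_lora(sd_pipe):
        # corrupt one LoRA weight with `inf` values
        with torch.no_grad():
            sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
                "inf"
            )

        # safe fusing should refuse to fuse instead of silently producing NaN outputs
        with pytest.raises(ValueError):
            sd_pipe.fuse_lora(safe_fusing=True)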
Testing the different possibilities here:
#5316 (comment)
Ok, my conclusion here is the following: I'd say we add a new check for inf with a nice error message. Thoughts @BenjaminBossan?
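One possible shape for that check, as a sketch only (check_fused_weight is a hypothetical helper, not the merged implementation), distinguishing the "broken" NaN case from the "overflowing" inf case with separate messages:

    import torch

    def check_fused_weight(fused_weight: torch.Tensor, name: str = "") -> None:
        if torch.isnan(fused_weight).any():
            raise ValueError(
                f"Encountered NaN values when fusing LoRA weights for {name}; "
                "the LoRA weights appear to be broken."
            )
        if torch.isinf(fused_weight).any():
            raise ValueError(
                f"Encountered inf values when fusing LoRA weights for {name}; "
                "the weights overflow when fused, so this LoRA cannot be fused safely."
            )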
Really like #5316 (comment)
Ideally, this should be noticed as early as possible. Is there a good entry point for such a check before fusing the weights?
Thanks for testing, this was not obvious to me.
The downside I could see from this is that for the majority of cases, there is nothing wrong, so adding the inf check, which is slow, would slow down the whole fusing call despite only being relevant in very rare cases. One could argue that if users opt into checking, they are ready to wait a bit longer. In the end, it's a trade-off and it could be worth it to check the overhead introduced by inf checking on a real model.
True! We could check both lora_up and lora_down for both NaN and inf instead of checking the fused weights for NaN, but I am not sure what is faster here in the end.
True, also happy to just skip this check for now if it's too slow, or another option is that we change
I think during the fusion process we could check it. I think that is the best tradeoff.
Granular check sounds great to me!
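A sketch of the granular idea (names follow the hunks quoted above; this is an illustration, not the merged code): check the low-rank LoRA factors during fusion, which scans far fewer elements than the full fused weight.

    import torch

    def lora_factors_are_finite(w_up: torch.Tensor, w_down: torch.Tensor) -> bool:
        # up/down are (out_features x rank) and (rank x in_features), so this is
        # much cheaper than scanning the fused (out_features x in_features) weight
        return bool(torch.isfinite(w_up).all() and torch.isfinite(w_down).all())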
* Fix fuse Lora
* improve a bit
* make style
* Update src/diffusers/models/lora.py
  Co-authored-by: Benjamin Bossan <[email protected]>
* ciao C file
* ciao C file
* test & make style
---------
Co-authored-by: Benjamin Bossan <[email protected]>
What does this PR do?
This PR adds a safe_fusing=False/True flag that allows detecting broken LoRAs. Fixes: #5313 by doing:
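(The original snippet in the PR description did not survive extraction; below is a minimal sketch of the intended usage with try/except. The model id and LoRA path are placeholders.)

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("path/to/possibly-broken-lora")  # placeholder path

    try:
        # with safe_fusing=True, fusing raises instead of silently corrupting the weights
        pipe.fuse_lora(safe_fusing=True)
    except ValueError:
        # the LoRA could not be fused safely -> remove it again
        pipe.unload_lora_weights()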
Other design options
Alternatively, we could throw a warning and just make it a no-op by not fusing the weights. However, this doesn't really solve the problem, since the user still needs to unload the LoRAs anyway. Personally, I think it's cleaner to solve it with a try/except statement as shown above.
TODO