
Fix loading broken LoRAs that could give NaN #5316

Merged: 7 commits merged into main on Oct 9, 2023

Conversation

@patrickvonplaten (Contributor) commented Oct 6, 2023

What does this PR do?

This PR adds a safe_fusing=True/False flag that makes it possible to detect broken LoRAs. Fixes: #5313

It can be exercised with the following script:

#!/usr/bin/env python3
import torch
from warnings import warn
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
)
import hashlib

base = "stabilityai/stable-diffusion-xl-base-1.0"
adapter1 = 'nerijs/pixel-art-xl'
weightname1 = 'pixel-art-xl.safetensors'

adapter2 = 'Alexzyx/lora-trained-xl-colab'
weightname2 = None

inputs = "elephant"
kwargs = {}

if torch.cuda.is_available():
    kwargs["torch_dtype"] = torch.float16

#vae = AutoencoderKL.from_pretrained(
#    "madebyollin/sdxl-vae-fp16-fix",
#    torch_dtype=torch.float16,  # load fp16 fix VAE
#)
#kwargs["vae"] = vae
#kwargs["variant"] = "fp16"
#

model = DiffusionPipeline.from_pretrained(
    base, **kwargs
)

if torch.cuda.is_available():
    model.to("cuda")


def inference(adapter, weightname):
    model.load_lora_weights(adapter, weight_name=weightname)
    try:
        model.fuse_lora(safe_fusing=True)
    except ValueError:
        warn(f"LoRA {adapter} ({weightname}) is broken; it will not be fused.")
        model.unload_lora_weights()

    data = model(inputs, num_inference_steps=1).images[0]
    model.unfuse_lora()
    model.unload_lora_weights()
    filename = '/tmp/hello.jpg'
    data.save(filename, format='jpeg')
    with open(filename, 'rb') as f:
        md5 = hashlib.md5(f.read()).hexdigest()
    print("Adapter %s, md5sum %s" % (adapter, md5))
    if md5 == '40c78c9fd4daeff01c988c3532fdd51b':
        print("BLACK SCREEN IMAGE for adapter %s" % adapter)


inference(adapter1, weightname1)
inference(adapter2, weightname2)
inference(adapter1, weightname1)
inference(adapter1, weightname1)

Other design options

Alternatively, we could emit a warning and make fusing a no-op by not fusing the weights. However, this doesn't really solve the problem, since the user still needs to unload the LoRAs afterwards anyway. Personally, I think it's cleaner to solve it with a try/except statement, as shown above (and contrasted in the sketch below).
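For illustration, a minimal self-contained sketch of the two behaviors (the helper names are hypothetical and not the diffusers internals; only the NaN check mirrors the PR):

import warnings
import torch

def fuse_or_raise(w_orig, delta):
    # Chosen design: refuse to fuse and surface a ValueError the caller can catch.
    fused = w_orig + delta
    if torch.isnan(fused).any().item():
        raise ValueError("Encountered NaN while fusing LoRA weights; not fusing.")
    return fused

def fuse_or_warn(w_orig, delta):
    # Alternative design: warn and return the original weight unchanged. The broken
    # LoRA stays loaded, so the user still has to unload it afterwards.
    fused = w_orig + delta
    if torch.isnan(fused).any().item():
        warnings.warn("Fused LoRA weights contain NaN; skipping fusion.")
        return w_orig
    return fused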

TODO

  • Add tests
  • Sync with the PEFT team on how to implement this "safe" fusing (cc @younesbelkada)

@younesbelkada (Contributor) left a comment

Makes sense, will add it in #5151 together with a test!

Review comment on file "C" (outdated):

Contributor: I think this file should not be there 👀

Author (@patrickvonplaten): gracias (thanks)

Comment on lines +139 to +144
if safe_fusing and torch.isnan(fused_weight).any().item():
    raise ValueError(
        "This LoRA weight seems to be broken. "
        f"Encountered NaN values when trying to fuse LoRA weights for {self}."
        "LoRA weights will not be fused."
    )
Member: Crazy, honestly.

@@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
                LoRA parameters then it won't have any effect.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
Member: Why are we defaulting to False?

Contributor: I believe it's because torch.isnan(fused_weight).any().item() adds overhead to the fusion operation. For a large tensor this might add a considerable slowdown (one would need to benchmark it, though), so to be on the safe side I would also default it to False.

Member: Makes sense! Thanks for explaining!

@sayakpaul (Member) left a comment

Works for me, and I agree that tackling it using try/except is better. Even without the try/except, we throw a sensible error message, which makes sense to me.

@younesbelkada (Contributor):

FYI, equivalent PR in PEFT: huggingface/peft#1001

@@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
Contributor: Question: do you know if the NaN usually happens in

w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

or in

(lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

cc @BenjaminBossan - if it happens in the second case we could remove the copy in the PEFT PR
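For reference, a small sketch (hypothetical helper, standalone tensors shaped like the ones in _fuse_lora) of how one could tell the two cases apart:

import torch

def locate_nan(w_orig, w_up, w_down, lora_scale=1.0):
    # Compute the LoRA delta exactly as in _fuse_lora and check it on its own
    # before checking the fully fused weight.
    delta = lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]
    if torch.isnan(delta).any():
        return "NaN already in the LoRA delta (bmm of w_up and w_down)"
    if torch.isnan(w_orig + delta).any():
        return "NaN only after adding the delta to w_orig"
    return "no NaN"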

Author (@patrickvonplaten): I don't really know, I have not tested it. My intuition is that checking for NaN values can be quite expensive on GPU anyway, so either way we have a time overhead and can't make safe_fusing the default.

Contributor: OK, sounds great!

@BenjaminBossan (Member) left a comment

I have some of the same comments here as in huggingface/peft#1001 (GH won't let me link to the comment directly). I'll paraphrase:

The error message

This LoRA weight seems to be broken

is a bit confusing IMO, because (if I'm not mistaken) it is totally possible for the LoRA weight to be working in forward when applied separately from the original weights, and only encountering errors after fusing, since the mathematical operation is not identical. E.g. two weight parameters could be overflowing when added, but when they are both first multiplied by the activation and only then added, they might not overflow anymore.

Therefore, I wouldn't say "broken", as it may work without fusing. Instead, I would change the message to just say that this adapter cannot be fused safely. WDYT?

The next issue is that I think we should not check with torch.isnan because it doesn't detect torch.inf. Instead, torch.isfinite(x).all() should work for both torch.inf and torch.nan. WDYT?
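A quick standalone illustration of the difference between the two checks (plain PyTorch, not tied to the diffusers code):

import torch

x = torch.tensor([1.0, float("inf"), float("nan")])

print(torch.isnan(x))                 # tensor([False, False,  True])  -> misses the inf
print(~torch.isfinite(x))             # tensor([False,  True,  True])  -> flags inf and nan
print(bool(torch.isfinite(x).all()))  # False -> single flag usable for a safety check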

(Resolved review thread on src/diffusers/models/lora.py; outdated)
@HuggingFaceDocBuilderDev commented Oct 9, 2023: The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten changed the title from "Fix fuse Lora" to "Fix loading broken LoRAs that could give NaN" on Oct 9, 2023
@patrickvonplaten merged commit ed2f956 into main on Oct 9, 2023
12 of 13 checks passed
@patrickvonplaten (Contributor, Author) commented Oct 10, 2023

> I have some of the same comments here as in huggingface/peft#1001 [...] I wouldn't say "broken", as it may work without fusing. [...] we should not check with torch.isnan because it doesn't detect torch.inf. Instead, torch.isfinite(x).all() should work for both torch.inf and torch.nan.

Hmm, good comments!

1.) I think when we see NaN values (like we're seeing here), the only reason can be that the LoRA weights are already broken. If the weights were merely overflowing, I think we would only see inf values.
E.g. the operation:

w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

can only yield NaN if there is already a NaN value in either w_up or w_down. If any of the values were inf or overflowing, the result would be inf rather than NaN, because we only have "multiply" and/or "add" operations, and an inf operand there gives inf, not NaN (there are no divide operations here).

=> So I think the only explanation here is in fact that the LoRA weights are broken. WDYT?

2.) It's probably a good idea to also handle the case where weights yield inf values. In this case I think there are two options:

  • a) There is already an inf value in the LoRA weights (=> the weights are broken).
  • b) The values overflow in fp32 (note that in my experience this is really rare; fp32 has a very high upper limit). In this case there is a (small) chance that one could run the LoRA without fusing, and it does indeed make sense to give a nice error message here.

Edit: torch.bmm can yield NaN values even if the inputs only contain inf, since the summed products can produce terms like inf * 0 or inf + (-inf).
=> So both statements are valid @BenjaminBossan! Happy to then change the torch.isnan check to not torch.isfinite and have a more precise error message (that the weights are either broken or have overflowed in this operation). Thanks for the investigation! Would you like to open a PR here? :-)

Overall, I would not advise the user to run it in "non-fused" mode, as I think it's highly unlikely to yield reasonable results. But it is a possibility, so having a precise error message would be great here.
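A standalone demonstration of the bmm point above (toy tensors, not the actual LoRA weights): scaling or adding a finite nonzero constant keeps inf as inf, while the summed products inside a matmul can contain inf * 0 or inf + (-inf) terms, which evaluate to NaN:

import torch

w = torch.tensor([[float("inf"), -1.0]])

# Elementwise scale-and-add keeps inf as inf:
print(w * 2.0 + 1.0)                        # tensor([[inf, -1.]])

# A matmul sums products, so an inf * 0 term turns into NaN:
w_up = torch.tensor([[float("inf"), 1.0]])  # shape (1, 2)
w_down = torch.tensor([[0.0], [1.0]])       # shape (2, 1)
print(torch.bmm(w_up[None, :], w_down[None, :])[0])  # tensor([[nan]])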

@BenjaminBossan (Member):
I agree that infs would be more surprising to find (not sure if weights can be fp16 here; if so, it might happen more often). The main reason for suggesting the check is that torch.isfinite detects inf on top of nan, so it would cover both. However, I did not think about performance, which might be an issue:

>>> import torch
>>> x = torch.rand((1000, 1000)).to('cuda')
>>> %time _ = torch.isfinite(x)
CPU times: user 14.9 ms, sys: 4.45 ms, total: 19.4 ms
Wall time: 28.6 ms
>>> %time _ = torch.isnan(x)
CPU times: user 216 µs, sys: 33 µs, total: 249 µs
Wall time: 2.01 ms

Not sure if the tradeoff is worth it.
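One caveat worth noting: CUDA kernels launch asynchronously, so a single %time call mostly measures launch overhead rather than kernel time. A sketch of a synchronized comparison (requires a CUDA device; numbers will vary by GPU, this is not a definitive benchmark):

import time
import torch

x = torch.rand((1000, 1000), device="cuda")

def bench(fn, n=100):
    fn(x)                        # warm-up
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        fn(x)
    torch.cuda.synchronize()     # wait for all kernels before stopping the clock
    return (time.perf_counter() - t0) / n

print("isnan    :", bench(torch.isnan))
print("isfinite :", bench(torch.isfinite))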

@patrickvonplaten (Contributor, Author) commented Oct 10, 2023

> I agree that infs would be more surprising to find [...] The main reason for suggesting the check is that torch.isfinite detects inf on top of nan, so it would cover both. However, I did not think about performance, which might be an issue (see the benchmark above). Not sure if the tradeoff is worth it.

Actually I played around with it a bit more. In my experiments:

w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

can only yield NaN if any of the values of w_up or w_down is either "inf" or "NaN". E.g. the following (when tested here) does not yield NaN

sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += torch.finfo(torch.float32).max

but inf.

=> So maybe it makes sense to distinguish between "broken" and "potentially broken / overflowing".


# corrupt one LoRA weight with `inf` values
with torch.no_grad():
    sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float("inf")
Author (@patrickvonplaten): Testing the different possibilities here:
#5316 (comment)

@BenjaminBossan

@patrickvonplaten (Contributor, Author) commented Oct 10, 2023

OK, my conclusion here is the following:

  1. I'd argue that in 99% of the cases when fusing yields NaN or Inf values, the LoRA weights were already broken beforehand. E.g. see: https://huggingface.co/Alexzyx/lora-trained-xl-colab/discussions/1

  2. It very strongly looks to me like we can get NaN only if the weights were already broken (have either Inf or NaN values). I tested this for a bunch of different values now. See here.

  3. Given that checking for NaN is much faster than checking for Inf, and given that it's good for the user to know whether the weights are broken or merely overflowing, I'd advocate adding two "checks": one for NaN (=> weights are broken, don't use them), and one for Inf (=> weights are broken OR are overflowing exactly in this operation).

=> So I'd say we add a new check for Inf with a nice error message. Thoughts, @BenjaminBossan?

@sayakpaul (Member):

Really like #5316 (comment)

@BenjaminBossan (Member) commented Oct 10, 2023

> 1. I'd argue that in 99% of the cases when fusing yields NaN or Inf values, the LoRA weights were already broken beforehand.

Ideally, this should be noticed as early as possible. Is there a good entry point for such a check before fusing the weights?

> 2. It very strongly looks to me like we can get NaN only if the weights were already broken

Thanks for testing, this was not obvious to me.

> 3. I'd advocate adding two "checks"

The downside I could see from this is that for the majority of cases there is nothing wrong, so adding the inf check, which is slow, would slow down the whole fusing call despite only being relevant in very rare cases. One could argue that if users opt into checking, they are ready to wait a bit longer. In the end, it's a trade-off, and it could be worth it to check the overhead introduced by inf checking on a real model.

@patrickvonplaten (Contributor, Author):

> > 1. I'd argue that in 99% of the cases when fusing yields NaN or Inf values, the LoRA weights were already broken beforehand.
>
> Ideally, this should be noticed as early as possible. Is there a good entry point for such a check before fusing the weights?

True! We could check both lora_up and lora_down for both NaN and inf, instead of checking the fused weights for NaN, but I'm not sure which is faster in the end.

> > 2. It very strongly looks to me like we can get NaN only if the weights were already broken
>
> Thanks for testing, this was not obvious to me.

> > 3. I'd advocate adding two "checks"
>
> The downside I could see from this is that for the majority of cases there is nothing wrong, so adding the inf check, which is slow, would slow down the whole fusing call despite only being relevant in very rare cases. One could argue that if users opt into checking, they are ready to wait a bit longer. In the end, it's a trade-off, and it could be worth it to check the overhead introduced by inf checking on a real model.

True, I'm also happy to just skip this check for now if it's too slow. Another option is to change safe_fuse=True/False to something like safe_fuse=None/"finite"/"nan" to have finer granularity of checks.

@sayakpaul (Member):

> True! We could check both lora_up and lora_down for both NaN and inf, instead of checking the fused weights for NaN, but I'm not sure which is faster in the end.

I think we could check it during the fusion process. I think that is the best tradeoff.

> True, I'm also happy to just skip this check for now if it's too slow. Another option is to change safe_fuse=True/False to something like safe_fuse=None/"finite"/"nan" to have finer granularity of checks.

Granular check sounds great to me!
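A minimal sketch of the granular idea discussed just above (the parameter name and its string values are hypothetical; the merged diffusers API only has the boolean safe_fusing flag):

import torch

def check_fused_weight(fused_weight, safe_fuse=None):
    # safe_fuse=None skips checking, "nan" runs the cheap NaN-only check,
    # and "finite" also catches inf at extra cost.
    if safe_fuse == "nan" and torch.isnan(fused_weight).any().item():
        raise ValueError("Fused LoRA weights contain NaN; the LoRA weights are likely broken.")
    if safe_fuse == "finite" and not torch.isfinite(fused_weight).all().item():
        raise ValueError(
            "Fused LoRA weights contain NaN or Inf; the LoRA weights are broken "
            "or the fusion overflowed."
        )
    return fused_weight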

@kashif deleted the save_fuse_lora branch on December 5, 2023 09:00
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix fuse Lora

* improve a bit

* make style

* Update src/diffusers/models/lora.py

Co-authored-by: Benjamin Bossan <[email protected]>

* ciao C file

* ciao C file

* test & make style

---------

Co-authored-by: Benjamin Bossan <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
(Same commit message as above.)

Successfully merging this pull request may close these issues:

LoRA adapters: non revertible fuse
5 participants