Fix loading broken LoRAs that could give NaN (huggingface#5316)
* Fix fuse Lora

* improve a bit

* make style

* Update src/diffusers/models/lora.py

Co-authored-by: Benjamin Bossan <[email protected]>

* ciao C file

* ciao C file

* test & make style

---------

Co-authored-by: Benjamin Bossan <[email protected]>
patrickvonplaten and BenjaminBossan authored Oct 9, 2023
1 parent a844065 commit ed2f956
Showing 3 changed files with 92 additions and 17 deletions.
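
As context for the diffs below, here is a brief, hypothetical usage sketch of the new flag from the pipeline API. The SDXL base checkpoint id is real; the LoRA repository id and the error handling are placeholders of mine, not part of this commit.

# Hypothetical usage of the safe_fusing flag introduced in this commit;
# the LoRA repo id below is a placeholder.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load a LoRA checkpoint that may contain inf/NaN weights.
pipe.load_lora_weights("some-user/some-lora", weight_name="pytorch_lora_weights.safetensors")

try:
    # With safe_fusing=True, fusion raises a ValueError instead of silently
    # baking NaNs into the base weights (which would render black images).
    pipe.fuse_lora(lora_scale=1.0, safe_fusing=True)
except ValueError as err:
    print(f"Refusing to fuse a broken LoRA: {err}")
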
48 changes: 33 additions & 15 deletions src/diffusers/loaders.py
@@ -121,7 +121,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):

return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_linear_layer is None:
return

@@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)

self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
@@ -672,13 +680,14 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def fuse_lora(self, lora_scale=1.0):
def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
self.lora_scale = lora_scale
self._safe_fusing = safe_fusing
self.apply(self._fuse_lora_apply)

def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale)
module._fuse_lora(self.lora_scale, self._safe_fusing)

def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
@@ -2086,7 +2095,13 @@ def unload_lora_weights(self):
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()

def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
def fuse_lora(
self,
fuse_unet: bool = True,
fuse_text_encoder: bool = True,
lora_scale: float = 1.0,
safe_fusing: bool = False,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and, if NaN values are present, not fusing them.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
@@ -2112,12 +2129,13 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
)

if fuse_unet:
self.unet.fuse_lora(lora_scale)
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)

if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer

def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
# TODO(Patrick, Younes): enable "safe" fusing
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
@@ -2129,24 +2147,24 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
if version.parse(__version__) > version.parse("0.23"):
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)

def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)

if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder, lora_scale)
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)

def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
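
The check added above guards the linear fusion path in PatchedLoraProjection (a torch.bmm of the up and down projections). A small standalone sketch, using toy tensor sizes of my choosing, of why an inf-corrupted down weight yields NaN and how the torch.isnan condition catches it:

# Toy reproduction (toy sizes, not diffusers code) of how a corrupted LoRA
# produces NaN during fusion and why the new torch.isnan check fires.
import torch

out_features, in_features, rank = 8, 8, 4
w_orig = torch.randn(out_features, in_features)

# Up projection with mixed signs; down projection corrupted with +inf,
# mirroring what the new test does to one LoRA down weight.
w_up = torch.ones(out_features, rank)
w_up[:, 0] = -1.0
w_down = torch.randn(rank, in_features) + float("inf")

# Same fusion math as _fuse_lora: summing (+inf) and (-inf) products gives NaN.
fused_weight = w_orig + 1.0 * torch.bmm(w_up[None, :], w_down[None, :])[0]

# This is exactly the condition guarded by safe_fusing=True.
assert torch.isnan(fused_weight).any().item()
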
20 changes: 18 additions & 2 deletions src/diffusers/models/lora.py
@@ -112,7 +112,7 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_layer is None:
return

@@ -128,6 +128,14 @@ def _fuse_lora(self, lora_scale=1.0):
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape((w_orig.shape))
fused_weight = w_orig + (lora_scale * fusion)

if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)

self.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
@@ -182,7 +190,7 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_layer is None:
return

@@ -196,6 +204,14 @@ def _fuse_lora(self, lora_scale=1.0):
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)

self.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
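
For the convolutional path above, the fusion flattens both LoRA factors before the matrix product and reshapes back to the conv weight layout. A shape sketch with toy sizes, assuming the up conv uses a 1x1 kernel as LoRAConv2dLayer does (the sizes themselves are mine):

# Shape sketch of the conv fusion math (toy sizes; layout assumptions noted above).
import torch

out_ch, in_ch, rank, kh, kw = 8, 4, 2, 3, 3
w_orig = torch.randn(out_ch, in_ch, kh, kw)   # base Conv2d weight
w_up = torch.randn(out_ch, rank, 1, 1)        # LoRA up conv weight (1x1 kernel)
w_down = torch.randn(rank, in_ch, kh, kw)     # LoRA down conv weight

# Same math as _fuse_lora for conv: flatten, matmul, reshape back.
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fused_weight = w_orig + 1.0 * fusion.reshape(w_orig.shape)

# The safe_fusing check would flag NaNs here before overwriting self.weight.
assert not torch.isnan(fused_weight).any().item()
print(fused_weight.shape)  # torch.Size([8, 4, 3, 3])
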
41 changes: 41 additions & 0 deletions tests/lora/test_lora_layers_old_backend.py
@@ -1028,6 +1028,47 @@ def test_load_lora_locally_safetensors(self):

sd_pipe.unload_lora_weights()

def test_lora_fuse_nan(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)

# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)

with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

# corrupt one LoRA weight with `inf` values
with torch.no_grad():
sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
"inf"
)

# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
sd_pipe.fuse_lora(safe_fusing=True)

# without we should not see an error, but every image will be black
sd_pipe.fuse_lora(safe_fusing=False)

out = sd_pipe("test", num_inference_steps=2, output_type="np").images

assert np.isnan(out).all()

def test_lora_fusion(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
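
Assuming a standard development checkout of diffusers with the test dependencies installed, the new test can presumably be run on its own with:

pytest tests/lora/test_lora_layers_old_backend.py -k test_lora_fuse_nan
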
