
[LoRA Flux Xlabs] Error loading trained LoRA with Xlabs on Diffusers (Fix proposal) #9914

Closed
raulmosa opened this issue Nov 12, 2024 · 2 comments
Labels
bug Something isn't working

Comments

raulmosa (Contributor) commented Nov 12, 2024

Describe the bug

Proposal to update src/diffusers/loaders/lora_conversion_utils.py for XLabs Flux LoRA conversion, due to a mismatch between keys in the state dictionary. When mapping single_blocks layers, if a LoRA trained on Flux contains single_blocks, those keys are never converted and therefore never removed from old_state_dict (see lines 635-655), so the following ValueError is reached (a minimal illustration of the mismatch follows the quoted check):

if len(old_state_dict) > 0:
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
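For illustration, the conditions on lines 639-642 (assumed, based on the suffixed key names in the old XLabs format, to check for proj_lora1/proj_lora2 and qkv_lora1/qkv_lora2) never match the suffix-less single_blocks keys produced by the newer XLabs trainer:

# Minimal sketch of the mismatch. The condition list is an assumption based on
# the suffixed key names in the old XLabs format; the key below is taken from
# the logs in this issue.
old_conditions = ["proj_lora1", "proj_lora2", "qkv_lora1", "qkv_lora2"]
new_xlabs_key = "single_blocks.1.processor.qkv_lora.down.weight"
print(any(cond in new_xlabs_key for cond in old_conditions))  # False -> key is left in old_state_dict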

For example, the keys of a working Flux LoRA (XLabs-AI/flux-RealismLora), which does not contain single_blocks:
['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight', 'double_blocks.1.processor.proj_lora1.down.weight', 'double_blocks.1.processor.proj_lora1.up.weight', 'double_blocks.1.processor.proj_lora2.down.weight', 'double_blocks.1.processor.proj_lora2.up.weight', 'double_blocks.1.processor.qkv_lora1.down.weight', 'double_blocks.1.processor.qkv_lora1.up.weight', 'double_blocks.1.processor.qkv_lora2.down.weight', 'double_blocks.1.processor.qkv_lora2.up.weight', 'double_blocks.10.processor.proj_lora1.down.weight', 'double_blocks.10.processor.proj_lora1.up.weight', 'double_blocks.10.processor.proj_lora2.down.weight', 'double_blocks.10.processor.proj_lora2.up.weight', 'double_blocks.10.processor.qkv_lora1.down.weight', 'double_blocks.10.processor.qkv_lora1.up.weight', 'double_blocks.10.processor.qkv_lora2.down.weight', 'double_blocks.10.processor.qkv_lora2.up.weight', 'double_blocks.11.processor.proj_lora1.down.weight', 'double_blocks.11.processor.proj_lora1.up.weight', 'double_blocks.11.processor.proj_lora2.down.weight', 'double_blocks.11.processor.proj_lora2.up.weight', 'double_blocks.11.processor.qkv_lora1.down.weight', 'double_blocks.11.processor.qkv_lora1.up.weight', 'double_blocks.11.processor.qkv_lora2.down.weight', 'double_blocks.11.processor.qkv_lora2.up.weight', 'double_blocks.12.processor.proj_lora1.down.weight', 'double_blocks.12.processor.proj_lora1.up.weight', 'double_blocks.12.processor.proj_lora2.down.weight', 'double_blocks.12.processor.proj_lora2.up.weight', 'double_blocks.12.processor.qkv_lora1.down.weight', 'double_blocks.12.processor.qkv_lora1.up.weight', 'double_blocks.12.processor.qkv_lora2.down.weight', 'double_blocks.12.processor.qkv_lora2.up.weight', 'double_blocks.13.processor.proj_lora1.down.weight', 'double_blocks.13.processor.proj_lora1.up.weight', 'double_blocks.13.processor.proj_lora2.down.weight', 'double_blocks.13.processor.proj_lora2.up.weight', 'double_blocks.13.processor.qkv_lora1.down.weight', 'double_blocks.13.processor.qkv_lora1.up.weight', 'double_blocks.13.processor.qkv_lora2.down.weight', 'double_blocks.13.processor.qkv_lora2.up.weight', 'double_blocks.14.processor.proj_lora1.down.weight', 'double_blocks.14.processor.proj_lora1.up.weight', 'double_blocks.14.processor.proj_lora2.down.weight', 'double_blocks.14.processor.proj_lora2.up.weight', 'double_blocks.14.processor.qkv_lora1.down.weight', 'double_blocks.14.processor.qkv_lora1.up.weight', 'double_blocks.14.processor.qkv_lora2.down.weight', 'double_blocks.14.processor.qkv_lora2.up.weight', 'double_blocks.15.processor.proj_lora1.down.weight', 'double_blocks.15.processor.proj_lora1.up.weight', 'double_blocks.15.processor.proj_lora2.down.weight', 'double_blocks.15.processor.proj_lora2.up.weight', 'double_blocks.15.processor.qkv_lora1.down.weight', 'double_blocks.15.processor.qkv_lora1.up.weight', 'double_blocks.15.processor.qkv_lora2.down.weight', 'double_blocks.15.processor.qkv_lora2.up.weight', 'double_blocks.16.processor.proj_lora1.down.weight', 'double_blocks.16.processor.proj_lora1.up.weight', 'double_blocks.16.processor.proj_lora2.down.weight', 'double_blocks.16.processor.proj_lora2.up.weight', 'double_blocks.16.processor.qkv_lora1.down.weight', 
'double_blocks.16.processor.qkv_lora1.up.weight', 'double_blocks.16.processor.qkv_lora2.down.weight', 'double_blocks.16.processor.qkv_lora2.up.weight', 'double_blocks.17.processor.proj_lora1.down.weight', 'double_blocks.17.processor.proj_lora1.up.weight', 'double_blocks.17.processor.proj_lora2.down.weight', 'double_blocks.17.processor.proj_lora2.up.weight', 'double_blocks.17.processor.qkv_lora1.down.weight', 'double_blocks.17.processor.qkv_lora1.up.weight', 'double_blocks.17.processor.qkv_lora2.down.weight', 'double_blocks.17.processor.qkv_lora2.up.weight', 'double_blocks.18.processor.proj_lora1.down.weight', 'double_blocks.18.processor.proj_lora1.up.weight', 'double_blocks.18.processor.proj_lora2.down.weight', 'double_blocks.18.processor.proj_lora2.up.weight', 'double_blocks.18.processor.qkv_lora1.down.weight', 'double_blocks.18.processor.qkv_lora1.up.weight', 'double_blocks.18.processor.qkv_lora2.down.weight', 'double_blocks.18.processor.qkv_lora2.up.weight', 'double_blocks.2.processor.proj_lora1.down.weight', 'double_blocks.2.processor.proj_lora1.up.weight', 'double_blocks.2.processor.proj_lora2.down.weight', 'double_blocks.2.processor.proj_lora2.up.weight', 'double_blocks.2.processor.qkv_lora1.down.weight', 'double_blocks.2.processor.qkv_lora1.up.weight', 'double_blocks.2.processor.qkv_lora2.down.weight', 'double_blocks.2.processor.qkv_lora2.up.weight', 'double_blocks.3.processor.proj_lora1.down.weight', 'double_blocks.3.processor.proj_lora1.up.weight', 'double_blocks.3.processor.proj_lora2.down.weight', 'double_blocks.3.processor.proj_lora2.up.weight', 'double_blocks.3.processor.qkv_lora1.down.weight', 'double_blocks.3.processor.qkv_lora1.up.weight', 'double_blocks.3.processor.qkv_lora2.down.weight', 'double_blocks.3.processor.qkv_lora2.up.weight', 'double_blocks.4.processor.proj_lora1.down.weight', 'double_blocks.4.processor.proj_lora1.up.weight', 'double_blocks.4.processor.proj_lora2.down.weight', 'double_blocks.4.processor.proj_lora2.up.weight', 'double_blocks.4.processor.qkv_lora1.down.weight', 'double_blocks.4.processor.qkv_lora1.up.weight', 'double_blocks.4.processor.qkv_lora2.down.weight', 'double_blocks.4.processor.qkv_lora2.up.weight', 'double_blocks.5.processor.proj_lora1.down.weight', 'double_blocks.5.processor.proj_lora1.up.weight', 'double_blocks.5.processor.proj_lora2.down.weight', 'double_blocks.5.processor.proj_lora2.up.weight', 'double_blocks.5.processor.qkv_lora1.down.weight', 'double_blocks.5.processor.qkv_lora1.up.weight', 'double_blocks.5.processor.qkv_lora2.down.weight', 'double_blocks.5.processor.qkv_lora2.up.weight', 'double_blocks.6.processor.proj_lora1.down.weight', 'double_blocks.6.processor.proj_lora1.up.weight', 'double_blocks.6.processor.proj_lora2.down.weight', 'double_blocks.6.processor.proj_lora2.up.weight', 'double_blocks.6.processor.qkv_lora1.down.weight', 'double_blocks.6.processor.qkv_lora1.up.weight', 'double_blocks.6.processor.qkv_lora2.down.weight', 'double_blocks.6.processor.qkv_lora2.up.weight', 'double_blocks.7.processor.proj_lora1.down.weight', 'double_blocks.7.processor.proj_lora1.up.weight', 'double_blocks.7.processor.proj_lora2.down.weight', 'double_blocks.7.processor.proj_lora2.up.weight', 'double_blocks.7.processor.qkv_lora1.down.weight', 'double_blocks.7.processor.qkv_lora1.up.weight', 'double_blocks.7.processor.qkv_lora2.down.weight', 'double_blocks.7.processor.qkv_lora2.up.weight', 'double_blocks.8.processor.proj_lora1.down.weight', 'double_blocks.8.processor.proj_lora1.up.weight', 'double_blocks.8.processor.proj_lora2.down.weight', 
'double_blocks.8.processor.proj_lora2.up.weight', 'double_blocks.8.processor.qkv_lora1.down.weight', 'double_blocks.8.processor.qkv_lora1.up.weight', 'double_blocks.8.processor.qkv_lora2.down.weight', 'double_blocks.8.processor.qkv_lora2.up.weight', 'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight']

And below, the keys of a LoRA trained with the current XLabs code, which does contain single_blocks:
['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight', 'double_blocks.1.processor.proj_lora1.down.weight', 'double_blocks.1.processor.proj_lora1.up.weight', 'double_blocks.1.processor.proj_lora2.down.weight', 'double_blocks.1.processor.proj_lora2.up.weight', 'double_blocks.1.processor.qkv_lora1.down.weight', 'double_blocks.1.processor.qkv_lora1.up.weight', 'double_blocks.1.processor.qkv_lora2.down.weight', 'double_blocks.1.processor.qkv_lora2.up.weight', 'double_blocks.10.processor.proj_lora1.down.weight', 'double_blocks.10.processor.proj_lora1.up.weight', 'double_blocks.10.processor.proj_lora2.down.weight', 'double_blocks.10.processor.proj_lora2.up.weight', 'double_blocks.10.processor.qkv_lora1.down.weight', 'double_blocks.10.processor.qkv_lora1.up.weight', 'double_blocks.10.processor.qkv_lora2.down.weight', 'double_blocks.10.processor.qkv_lora2.up.weight', 'double_blocks.11.processor.proj_lora1.down.weight', 'double_blocks.11.processor.proj_lora1.up.weight', 'double_blocks.11.processor.proj_lora2.down.weight', 'double_blocks.11.processor.proj_lora2.up.weight', 'double_blocks.11.processor.qkv_lora1.down.weight', 'double_blocks.11.processor.qkv_lora1.up.weight', 'double_blocks.11.processor.qkv_lora2.down.weight', 'double_blocks.11.processor.qkv_lora2.up.weight', 'double_blocks.12.processor.proj_lora1.down.weight', 'double_blocks.12.processor.proj_lora1.up.weight', 'double_blocks.12.processor.proj_lora2.down.weight', 'double_blocks.12.processor.proj_lora2.up.weight', 'double_blocks.12.processor.qkv_lora1.down.weight', 'double_blocks.12.processor.qkv_lora1.up.weight', 'double_blocks.12.processor.qkv_lora2.down.weight', 'double_blocks.12.processor.qkv_lora2.up.weight', 'double_blocks.13.processor.proj_lora1.down.weight', 'double_blocks.13.processor.proj_lora1.up.weight', 'double_blocks.13.processor.proj_lora2.down.weight', 'double_blocks.13.processor.proj_lora2.up.weight', 'double_blocks.13.processor.qkv_lora1.down.weight', 'double_blocks.13.processor.qkv_lora1.up.weight', 'double_blocks.13.processor.qkv_lora2.down.weight', 'double_blocks.13.processor.qkv_lora2.up.weight', 'double_blocks.14.processor.proj_lora1.down.weight', 'double_blocks.14.processor.proj_lora1.up.weight', 'double_blocks.14.processor.proj_lora2.down.weight', 'double_blocks.14.processor.proj_lora2.up.weight', 'double_blocks.14.processor.qkv_lora1.down.weight', 'double_blocks.14.processor.qkv_lora1.up.weight', 'double_blocks.14.processor.qkv_lora2.down.weight', 'double_blocks.14.processor.qkv_lora2.up.weight', 'double_blocks.15.processor.proj_lora1.down.weight', 'double_blocks.15.processor.proj_lora1.up.weight', 'double_blocks.15.processor.proj_lora2.down.weight', 'double_blocks.15.processor.proj_lora2.up.weight', 'double_blocks.15.processor.qkv_lora1.down.weight', 'double_blocks.15.processor.qkv_lora1.up.weight', 'double_blocks.15.processor.qkv_lora2.down.weight', 'double_blocks.15.processor.qkv_lora2.up.weight', 'double_blocks.16.processor.proj_lora1.down.weight', 'double_blocks.16.processor.proj_lora1.up.weight', 'double_blocks.16.processor.proj_lora2.down.weight', 'double_blocks.16.processor.proj_lora2.up.weight', 'double_blocks.16.processor.qkv_lora1.down.weight', 
'double_blocks.16.processor.qkv_lora1.up.weight', 'double_blocks.16.processor.qkv_lora2.down.weight', 'double_blocks.16.processor.qkv_lora2.up.weight', 'double_blocks.17.processor.proj_lora1.down.weight', 'double_blocks.17.processor.proj_lora1.up.weight', 'double_blocks.17.processor.proj_lora2.down.weight', 'double_blocks.17.processor.proj_lora2.up.weight', 'double_blocks.17.processor.qkv_lora1.down.weight', 'double_blocks.17.processor.qkv_lora1.up.weight', 'double_blocks.17.processor.qkv_lora2.down.weight', 'double_blocks.17.processor.qkv_lora2.up.weight', 'double_blocks.18.processor.proj_lora1.down.weight', 'double_blocks.18.processor.proj_lora1.up.weight', 'double_blocks.18.processor.proj_lora2.down.weight', 'double_blocks.18.processor.proj_lora2.up.weight', 'double_blocks.18.processor.qkv_lora1.down.weight', 'double_blocks.18.processor.qkv_lora1.up.weight', 'double_blocks.18.processor.qkv_lora2.down.weight', 'double_blocks.18.processor.qkv_lora2.up.weight', 'double_blocks.2.processor.proj_lora1.down.weight', 'double_blocks.2.processor.proj_lora1.up.weight', 'double_blocks.2.processor.proj_lora2.down.weight', 'double_blocks.2.processor.proj_lora2.up.weight', 'double_blocks.2.processor.qkv_lora1.down.weight', 'double_blocks.2.processor.qkv_lora1.up.weight', 'double_blocks.2.processor.qkv_lora2.down.weight', 'double_blocks.2.processor.qkv_lora2.up.weight', 'double_blocks.3.processor.proj_lora1.down.weight', 'double_blocks.3.processor.proj_lora1.up.weight', 'double_blocks.3.processor.proj_lora2.down.weight', 'double_blocks.3.processor.proj_lora2.up.weight', 'double_blocks.3.processor.qkv_lora1.down.weight', 'double_blocks.3.processor.qkv_lora1.up.weight', 'double_blocks.3.processor.qkv_lora2.down.weight', 'double_blocks.3.processor.qkv_lora2.up.weight', 'double_blocks.4.processor.proj_lora1.down.weight', 'double_blocks.4.processor.proj_lora1.up.weight', 'double_blocks.4.processor.proj_lora2.down.weight', 'double_blocks.4.processor.proj_lora2.up.weight', 'double_blocks.4.processor.qkv_lora1.down.weight', 'double_blocks.4.processor.qkv_lora1.up.weight', 'double_blocks.4.processor.qkv_lora2.down.weight', 'double_blocks.4.processor.qkv_lora2.up.weight', 'double_blocks.5.processor.proj_lora1.down.weight', 'double_blocks.5.processor.proj_lora1.up.weight', 'double_blocks.5.processor.proj_lora2.down.weight', 'double_blocks.5.processor.proj_lora2.up.weight', 'double_blocks.5.processor.qkv_lora1.down.weight', 'double_blocks.5.processor.qkv_lora1.up.weight', 'double_blocks.5.processor.qkv_lora2.down.weight', 'double_blocks.5.processor.qkv_lora2.up.weight', 'double_blocks.6.processor.proj_lora1.down.weight', 'double_blocks.6.processor.proj_lora1.up.weight', 'double_blocks.6.processor.proj_lora2.down.weight', 'double_blocks.6.processor.proj_lora2.up.weight', 'double_blocks.6.processor.qkv_lora1.down.weight', 'double_blocks.6.processor.qkv_lora1.up.weight', 'double_blocks.6.processor.qkv_lora2.down.weight', 'double_blocks.6.processor.qkv_lora2.up.weight', 'double_blocks.7.processor.proj_lora1.down.weight', 'double_blocks.7.processor.proj_lora1.up.weight', 'double_blocks.7.processor.proj_lora2.down.weight', 'double_blocks.7.processor.proj_lora2.up.weight', 'double_blocks.7.processor.qkv_lora1.down.weight', 'double_blocks.7.processor.qkv_lora1.up.weight', 'double_blocks.7.processor.qkv_lora2.down.weight', 'double_blocks.7.processor.qkv_lora2.up.weight', 'double_blocks.8.processor.proj_lora1.down.weight', 'double_blocks.8.processor.proj_lora1.up.weight', 'double_blocks.8.processor.proj_lora2.down.weight', 
'double_blocks.8.processor.proj_lora2.up.weight', 'double_blocks.8.processor.qkv_lora1.down.weight', 'double_blocks.8.processor.qkv_lora1.up.weight', 'double_blocks.8.processor.qkv_lora2.down.weight', 'double_blocks.8.processor.qkv_lora2.up.weight', 'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight', 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight']

The script works after replacing lines 639-642 with:

if "proj_lora" in old_key:
  new_key += ".proj_out"
elif "qkv_lora" in old_key and "up" not in old_key:
  handle_qkv(old_state_dict, new_state_dict, old_key, [
    f"transformer.single_transformer_blocks.{block_num}.norm.linear"
  ])
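For anyone wanting to check the mapping in isolation, below is a minimal, self-contained sketch of the proposed conversion. handle_qkv here is a simplified stand-in for the converter's helper (the real one can also split fused qkv weights across several new keys), and the tensor shapes are arbitrary:

import re

import torch

# Simplified stand-in for the handle_qkv helper in lora_conversion_utils.py:
# pop the down/up pair for `old_key` and store it under the new name(s).
def handle_qkv(old_sd, new_sd, old_key, new_keys):
    down_weight = old_sd.pop(old_key)
    up_weight = old_sd.pop(old_key.replace(".down.weight", ".up.weight"))
    for new_key in new_keys:
        new_sd[f"{new_key}.lora_A.weight"] = down_weight
        new_sd[f"{new_key}.lora_B.weight"] = up_weight

# Toy single_blocks entries in the new XLabs format (shapes are arbitrary):
old_state_dict = {
    "single_blocks.1.processor.qkv_lora.down.weight": torch.zeros(16, 3072),
    "single_blocks.1.processor.qkv_lora.up.weight": torch.zeros(9216, 16),
}
new_state_dict = {}

for old_key in list(old_state_dict):
    if old_key not in old_state_dict:  # already consumed together with its ".down." pair
        continue
    block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
    if "qkv_lora" in old_key and "up" not in old_key:
        handle_qkv(old_state_dict, new_state_dict, old_key,
                   [f"transformer.single_transformer_blocks.{block_num}.norm.linear"])

print(list(new_state_dict))  # the converted Diffusers-style keys
print(len(old_state_dict))   # 0 -> the ValueError above is no longer reached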

Related PR: #9295 (@sayakpaul)

Reproduction

import torch
from diffusers import DiffusionPipeline


model_path = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
lora_model_path = "XLabs-AI/flslux-RealismLora"
# lora_model_path = "<PATH-LoRA-trained-Xlabs.safetensors>"
pipe.load_lora_weights(lora_model_path, adapter_name="lora_A")
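With the proposed conversion fix applied, the commented-out path would be expected to load without the ValueError shown in the logs below; a quick sanity check (hypothetical local path, arbitrary adapter name):

# Expected behavior once the conversion handles single_blocks keys:
pipe.load_lora_weights("<PATH-LoRA-trained-Xlabs.safetensors>", adapter_name="lora_B")
print(pipe.get_active_adapters())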

Logs

When a custom LoRA trained with the XLabs code that contains single_blocks is loaded:

File "/home/.pyenv/versions/xflux/lib/python3.10/site-packages/diffusers/loaders/lora_conversion_utils.py", line 658, in _convert_xlabs_flux_lora_to_diffusers
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
ValueError: `old_state_dict` should be at this point but has: ['single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'].

System Info

  • 🤗 Diffusers version: 0.31.0
  • Platform: Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.10.14
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.24.5
  • Transformers version: 4.43.3
  • Accelerate version: 0.30.1
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul

raulmosa added the bug label Nov 12, 2024
sayakpaul (Member) commented Nov 12, 2024

Thanks for the detailed issue, @raulmosa. Would you like to open a PR to fix this, since it looks like you already know the fix?

The reason we didn't have this support is that most XLabs LoRAs didn't have single blocks when we integrated them.

raulmosa (Contributor, Author) commented:

#9915
