Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] allow loras to be loaded with low_cpu_mem_usage. #9510

Merged
merged 32 commits into from
Oct 9, 2024

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Sep 24, 2024

What does this PR do?

huggingface/peft#1961 added the ability to set low_cpu_mem_usage while loading LoRAs. This can be quite helpful in speeding up the loading of LoRAs that are large and have many layers.

#8953 is a good example where this feature could be beneficial.

Benchmarking code
from diffusers import DiffusionPipeline 
import torch 
import time 
import fire

def main(ckpt_id: str, lora_id: str, low_cpu_mem_usage: bool = False):
    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch.bfloat16).to("cuda")
    
    for _ in range(10):
        start_time = time.time()  
        pipeline.load_lora_weights(lora_id, low_cpu_mem_usage=low_cpu_mem_usage)
        end_time = time.time()  
        pipeline.unload_lora_weights()
        elapsed_time = end_time - start_time 
        
        print(f"Iteration {_ + 1}: Load Lora weights took {elapsed_time:.6f} seconds with {low_cpu_mem_usage=}")

if __name__ == "__main__":
    fire.Fire(main)
Results:
Iteration 1: Load Lora weights took 13.924374 seconds with low_cpu_mem_usage=True
Iteration 2: Load Lora weights took 1.621597 seconds with low_cpu_mem_usage=True
Iteration 3: Load Lora weights took 1.612010 seconds with low_cpu_mem_usage=True
Iteration 4: Load Lora weights took 1.670260 seconds with low_cpu_mem_usage=True
Iteration 5: Load Lora weights took 1.664858 seconds with low_cpu_mem_usage=True
Iteration 6: Load Lora weights took 1.482521 seconds with low_cpu_mem_usage=True
Iteration 7: Load Lora weights took 1.633697 seconds with low_cpu_mem_usage=True
Iteration 8: Load Lora weights took 1.593326 seconds with low_cpu_mem_usage=True
Iteration 9: Load Lora weights took 1.503672 seconds with low_cpu_mem_usage=True
Iteration 10: Load Lora weights took 1.566633 seconds with low_cpu_mem_usage=True

Iteration 1: Load Lora weights took 33.370373 seconds with low_cpu_mem_usage=False
Iteration 2: Load Lora weights took 3.937800 seconds with low_cpu_mem_usage=False
Iteration 3: Load Lora weights took 4.364943 seconds with low_cpu_mem_usage=False
Iteration 4: Load Lora weights took 4.303800 seconds with low_cpu_mem_usage=False
Iteration 5: Load Lora weights took 4.154818 seconds with low_cpu_mem_usage=False
Iteration 6: Load Lora weights took 3.869319 seconds with low_cpu_mem_usage=False
Iteration 7: Load Lora weights took 4.153911 seconds with low_cpu_mem_usage=False
Iteration 8: Load Lora weights took 4.275074 seconds with low_cpu_mem_usage=False
Iteration 9: Load Lora weights took 4.395445 seconds with low_cpu_mem_usage=False
Iteration 10: Load Lora weights took 4.071344 seconds with low_cpu_mem_usage=False

The feature currently needs the user to install peft and transformers from the source. So, I suggest we wait until both the libraries have made stable releases to merge this PR.

Once Ben reviews the PR, will request for a review from Yiyi.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Copy link
Member

I think we need to have a low_cpu_mem_usage tag in the load_adapter() method too: https://github.com/huggingface/diffusers/blob/28f9d84549c0b1d24ef00d69a4c723f3a11cffb6/src/diffusers/loaders/lora_pipeline.py#L371C30-L371C42

If so, then this PR would be contingent on that. We could, however, use a combination of inject_adapter_in_model() and set_peft_model_state_dict() to mimic the same thing, I assume. I wouldn't personally prefer that because load_adapter() has been there for a while in diffusers.

Could you please clarify that bit? On the PEFT side, we have low_cpu_mem_usage on load_adapter but that's not the method being used here (just has the same name), right? Is this method coming from transformers (i.e. here)?

@sayakpaul
Copy link
Member Author

that's not the method being used here (just has the same name), right? Is this method coming from transformers (i.e. here)?

Yes, that is correct.

BenjaminBossan added a commit to BenjaminBossan/transformers that referenced this pull request Sep 26, 2024
PEFT added support for low_cpu_mem_usage=True when loading adapters in
huggingface/peft#1961. This feature is now
available when installing PEFT v0.13.0. With this PR, this option is
also supported when loading PEFT adapters directly into transformers
models.

Additionally, with this PR,
huggingface/diffusers#9510 will be unblocked,
which implements this option in diffusers.
@sayakpaul
Copy link
Member Author

sayakpaul commented Sep 27, 2024

@BenjaminBossan when I used:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "TheLastBen/The_Hound", 
    weight_name="sandor_clegane_single_layer.safetensors", 
    low_cpu_mem_usage=True
)

prompt = "sandor clegane drinking in a pub"
image = pipe(
    prompt=prompt,
    num_inference_steps=30,
    width=1024,
    generator=torch.manual_seed(42),
    height=1024,
).images[0]
image.save("sandor.png")

It leads to:

Error trace
  File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 98, in forward
    hidden_states = gate * self.proj_out(hidden_states)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 585, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

I investigated this a bit and confirmed that the LoRA params are kept on CPU which causes this failure. In case of low_cpu_mem_usage=False the LoRA parameters are on the expected device ("cuda" in the above example).

I further investigated why the tests added in this PR don't fail. That is because the state dict we're supplying to set_peft_model_state_dict() (here) -- the tensors of that state dict are already on the desired device. When I forcibly changed their device to a CPU and ran the tests on a GPU, the tests failed, and they complained about the same thing.

Possible to look into this?

@BenjaminBossan
Copy link
Member

BenjaminBossan commented Sep 27, 2024

Hmm, I could not reproduce the issue :-/ I had to change the code slightly due to memory constraints:

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map="balanced", max_memory={0: "24GB", 1: "20GB"},
)
# after loading the LoRA adapter {p.device for p in pipe.transformer.parameters()} returns:
# {device(type='cuda', index=0)}

Could this be the reason why it works for me?

I also tried this with a normal PEFT model that I moved to CUDA and then loaded with low_cpu_mem_usage and it worked.

@sayakpaul
Copy link
Member Author

Yeah you need to be on the exact same setup to replicate this. We cannot assume people will do load_lora_weights() only in a specific manner.

You can perhaps use an SD LoRA:

def test_a1111(self):

@BenjaminBossan
Copy link
Member

Quick update, I couldn't reproduce with that model:

import torch
from diffusers import FluxPipeline, StableDiffusionPipeline

generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to("cuda")
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, low_cpu_mem_usage=True)
images = pipe(
    "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images

@sayakpaul
Copy link
Member Author

sayakpaul commented Sep 28, 2024

@BenjaminBossan here's a minimal reproduction:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("sayakpaul/tiny-flux-pipeline-with-lora", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "sayakpaul/tiny-flux-pipeline-with-lora", weight_name="pytorch_lora_weights.bin", low_cpu_mem_usage=True
)

prompt = "sandor clegane drinking in a pub"
image = pipe(prompt=prompt, num_inference_steps=30).images[0]
image.save("sandor.png")

I am on peft:main.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Sep 30, 2024
See: huggingface/diffusers#9510 (comment)

Right now, the low_cpu_mem_usage=True option does not consolidate the
devices. E.g. when the model is on GPU and the state_dict on CPU, the
adapter weight will be on CPU after loading, when it should be GPU. This
fix ensures that the devices are consolidated.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Sep 30, 2024
See: huggingface/diffusers#9510 (comment)

Right now, the low_cpu_mem_usage=True option does not consolidate the
devices. E.g. when the model is on GPU and the state_dict on CPU, the
adapter weight will be on CPU after loading, when it should be GPU. This
fix ensures that the devices are consolidated.
@BenjaminBossan
Copy link
Member

Okay, got it now, thanks for the memory-friendly reproducer.

Indeed, if the LoRA weights on the model are on meta device, the device will be taken from the state_dict, not the base layer. I worked on a fix: huggingface/peft#2113

For the time being, you could add this snippet and it should fix the issue:

    if low_cpu_mem_usage:
        for module in model.modules():
            if hasattr(module, "_move_adapter_to_device_of_base_layer"):
                module._move_adapter_to_device_of_base_layer(adapter_name)

BenjaminBossan added a commit to huggingface/peft that referenced this pull request Oct 2, 2024
See: huggingface/diffusers#9510 (comment)

Right now, the low_cpu_mem_usage=True option does not consolidate the
devices. E.g. when the model is on GPU and the state_dict on CPU, the
adapter weight will be on CPU after loading, when it should be GPU. This
fix ensures that the devices are consolidated.
ArthurZucker pushed a commit to BenjaminBossan/transformers that referenced this pull request Oct 3, 2024
PEFT added support for low_cpu_mem_usage=True when loading adapters in
huggingface/peft#1961. This feature is now
available when installing PEFT v0.13.0. With this PR, this option is
also supported when loading PEFT adapters directly into transformers
models.

Additionally, with this PR,
huggingface/diffusers#9510 will be unblocked,
which implements this option in diffusers.
ArthurZucker pushed a commit to huggingface/transformers that referenced this pull request Oct 3, 2024
…3725)

* [PEFT] Support low_cpu_mem_usage for PEFT loading

PEFT added support for low_cpu_mem_usage=True when loading adapters in
huggingface/peft#1961. This feature is now
available when installing PEFT v0.13.0. With this PR, this option is
also supported when loading PEFT adapters directly into transformers
models.

Additionally, with this PR,
huggingface/diffusers#9510 will be unblocked,
which implements this option in diffusers.

* Fix typo
@sayakpaul
Copy link
Member Author

@BenjaminBossan could you give this a review?

@sayakpaul
Copy link
Member Author

sayakpaul commented Oct 8, 2024

@BenjaminBossan thanks!

Docstrings still contain TODO

Yeah will be resolved after Yiyi's approval.

Since PEFT version v0.13.1 is now released, the min PEFT version should be updated accordingly.

That's done.

In case you plan on no longer supporting older PEFT and transformer versions in the future: I would add a TODO comment to all those version checks that they can be removed once support for those older versions is dropped. If you plan to support them indefinitely, ignore this comment.

For now, we can ignore it.

An entry to the diffusers PEFT docs would be nice to have, especially since the name of the argument is not really intuitive.

Done.

@sayakpaul
Copy link
Member Author

@yiyixuxu could you review this PR?

After the approval, will add docs and request for a review from Steven.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! I left one comment, otherwise PR looks good to me

tests/lora/utils.py Outdated Show resolved Hide resolved
@sayakpaul
Copy link
Member Author

Thanks, @yiyixuxu!

I have taken care of the TODOs in the docs too. @stevhliu could you review the related changes?

Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just one minor change that needs to be propagated :)

docs/source/en/tutorials/using_peft_for_inference.md Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/lora_pipeline.py Outdated Show resolved Hide resolved
src/diffusers/loaders/unet.py Outdated Show resolved Hide resolved
@sayakpaul sayakpaul merged commit 31058cd into main Oct 9, 2024
18 checks passed
@sayakpaul sayakpaul deleted the low-cpu-mem-usage-lora branch October 9, 2024 05:27
leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
…9510)

* allow loras to be loaded with low_cpu_mem_usage.

* add flux support but note https://github.com/huggingface/diffusers/pull/9510\#issuecomment-2378316687

* low_cpu_mem_usage.

* fix-copies

* fix-copies again

* tests

* _LOW_CPU_MEM_USAGE_DEFAULT_LORA

* _peft_version default.

* version checks.

* version check.

* version check.

* version check.

* require peft 0.13.1.

* explicitly specify low_cpu_mem_usage=False.

* docs.

* transformers version 4.45.2.

* update

* fix

* empty

* better name initialize_dummy_state_dict.

* doc todos.

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* style

* fix-copies

---------

Co-authored-by: Steven Liu <[email protected]>
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
…ggingface#33725)

* [PEFT] Support low_cpu_mem_usage for PEFT loading

PEFT added support for low_cpu_mem_usage=True when loading adapters in
huggingface/peft#1961. This feature is now
available when installing PEFT v0.13.0. With this PR, this option is
also supported when loading PEFT adapters directly into transformers
models.

Additionally, with this PR,
huggingface/diffusers#9510 will be unblocked,
which implements this option in diffusers.

* Fix typo
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 22, 2024
See: huggingface/diffusers#9510 (comment)

Right now, the low_cpu_mem_usage=True option does not consolidate the
devices. E.g. when the model is on GPU and the state_dict on CPU, the
adapter weight will be on CPU after loading, when it should be GPU. This
fix ensures that the devices are consolidated.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants