add how to dynamically load LoRA in README.md

chengzeyi committed Nov 13, 2023
1 parent 958d2d0 commit c04410f

Showing 2 changed files with 38 additions and 1 deletion.
37 changes: 37 additions & 0 deletions README.md
@@ -13,6 +13,7 @@ __NOTE__: `stable-fast` is currently only in beta stage and is prone to be buggy
- [Install From Source](#install-from-source)
- [Usage](#usage)
- [Optimize StableDiffusionPipeline](#optimize-stablediffusionpipeline)
- [Dynamically Load LoRA](#dynamically-load-lora)
- [Some Common Methods To Speed Up PyTorch](#some-common-methods-to-speed-up-pytorch)
- [Trouble Shooting](#trouble-shooting)

@@ -249,6 +250,42 @@ output_image = compiled_model(**kwarg_inputs).images[0]
output_image = compiled_model(**kwarg_inputs).images[0]
```

### Dynamically Load LoRA

Loading LoRA dynamically is supported, but it requires a little extra work.
It is possible because the compiled graph and the `CUDA Graph` share the same
underlying data (pointers) with the original UNet model. So all you need to do
is update the original UNet model's parameters in place.

The following code assumes you have already loaded a LoRA and compiled the model,
and you want to switch to another LoRA.
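
For reference, a minimal setup before the switch might look like the sketch below. The model and LoRA paths are placeholders, and `compile` / `CompilationConfig` are the helpers from the [Optimize StableDiffusionPipeline](#optimize-stablediffusionpipeline) section; the exact import path may differ between versions.

```python
import torch
from diffusers import StableDiffusionPipeline
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile, CompilationConfig)

model = StableDiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16)
model.to(torch.device('cuda'))

# Load the first LoRA before compiling.
model.unet.load_attn_procs('/path/to/first/lora')

# Compile the pipeline; the compiled graph keeps pointers to the UNet's parameters.
compiled_model = compile(model, CompilationConfig.Default())
```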

```python
def update_state_dict(dst, src):
    for key, value in src.items():
        # Do an in-place copy.
        # As the traced forward function shares the same references to these tensors,
        # this modification will be reflected in the traced forward function.
        dst[key].copy_(value)
    return dst

# Load "another" LoRA into the UNet
def load_new_lora(unet, lora):
    # Store the original UNet parameters
    state_dict = unet.state_dict()
    # Load another LoRA into the UNet
    unet.load_attn_procs(lora)
    # Copy the current UNet parameters into the original tensors in place
    state_dict = update_state_dict(state_dict, unet.state_dict())
    # Load the original UNet parameters back.
    # We use assign=True because we still want to hold the references
    # of the original UNet parameters
    unet.load_state_dict(state_dict, assign=True)
    return unet

compiled_model.unet = load_new_lora(compiled_model.unet, '/path/to/lora')
```
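
Note that `load_state_dict(..., assign=True)` assigns the stored tensors back to the UNet instead of copying into the freshly created ones, so the original parameter objects captured by the compiled graph and `CUDA Graph` remain the parameters actually used for inference, now holding the new LoRA's weights. This requires a PyTorch version whose `load_state_dict` supports the `assign` argument.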

### Some Common Methods To Speed Up PyTorch

2 changes: 1 addition & 1 deletion tests/compilers/test_stable_diffusion_pipeline_compiler.py
@@ -288,7 +288,7 @@ def load_new_lora(unet, lora):
     # Store the original UNet parameters
     state_dict = unet.state_dict()
     # Load another lora into unet
-    unet.load_attn_procs(lora_b_path)
+    unet.load_attn_procs(lora)
     # Inplace copy current UNet parameters to the original unet parameters
     state_dict = update_state_dict(state_dict, unet.state_dict())
     # Load the original UNet parameters back.
