Commit 3e31fa4: fix README.md
chengzeyi committed Nov 13, 2023 (1 parent: c04410f)
Showing 2 changed files with 13 additions and 17 deletions.
README.md: 18 changes (8 additions, 10 deletions)
@@ -13,7 +13,7 @@ __NOTE__: `stable-fast` is currently only in beta stage and is prone to be buggy
 - [Install From Source](#install-from-source)
 - [Usage](#usage)
 - [Optimize StableDiffusionPipeline](#optimize-stablediffusionpipeline)
-- [Dynamically Load LoRA](#dynamically-load-lora)
+- [Dynamically Switch LoRA](#dynamically-switch-lora)
 - [Some Common Methods To Speed Up PyTorch](#some-common-methods-to-speed-up-pytorch)
 - [Trouble Shooting](#trouble-shooting)

@@ -250,9 +250,9 @@ output_image = compiled_model(**kwarg_inputs).images[0]
 output_image = compiled_model(**kwarg_inputs).images[0]
 ```

-### Dynamically Load LoRA
+### Dynamically Switch LoRA

-Loading LoRA dynamically is supported but you need to do some extra work.
+Switching LoRA dynamically is supported, but you need to do some extra work.
 It is possible because the compiled graph and `CUDA Graph` share the same
 underlying data (pointers) with the original UNet model. So all you need to do
 is to update the original UNet model's parameters in place.
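
The in-place requirement in the paragraph above can be demonstrated with a few lines of plain PyTorch. The snippet below is an editorial aside, not part of the commit, and the tensor names are illustrative: an in-place `copy_` is visible through a previously captured reference (the kind a traced graph or `CUDA Graph` holds), while rebinding the variable to a new tensor is not.

```
import torch

weight = torch.ones(2)
captured = weight  # stands in for the reference held by a traced graph / CUDA Graph

# In-place update: the captured reference observes the new values.
weight.copy_(torch.full((2,), 5.0))
print(captured)  # tensor([5., 5.])

# Rebinding: `weight` now points at new storage, but `captured` does not change.
weight = torch.zeros(2)
print(captured)  # still tensor([5., 5.])
```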
@@ -267,23 +267,21 @@ def update_state_dict(dst, src):
         # As the traced forward function shares the same reference of the tensors,
         # this modification will be reflected in the traced forward function.
         dst[key].copy_(value)
-    return dst

-# Load "another" lora into UNet
-def load_new_lora(unet, lora):
+# Switch "another" LoRA into UNet
+def switch_lora(unet, lora):
     # Store the original UNet parameters
     state_dict = unet.state_dict()
-    # Load another lora into unet
+    # Load another LoRA into unet
     unet.load_attn_procs(lora)
     # Inplace copy current UNet parameters to the original unet parameters
-    state_dict = update_state_dict(state_dict, unet.state_dict())
+    update_state_dict(state_dict, unet.state_dict())
     # Load the original UNet parameters back.
     # We use assign=True because we still want to hold the references
     # of the original UNet parameters
     unet.load_state_dict(state_dict, assign=True)
-    return unet

-compiled_model.uent = load_new_lora(compiled_model.unet, '/path/to/lora')
+switch_lora(compiled_model.unet, lora_b_path)
 ```

 ### Some Common Methods To Speed Up PyTorch
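For readers following the updated README section above, here is a minimal end-to-end sketch of how `switch_lora` fits into a session. It is an editorial illustration, not part of the commit: the LoRA paths are placeholders, `compiled_model` is assumed to come from the earlier "Optimize StableDiffusionPipeline" step, and `kwarg_inputs` is the same argument dict used in the preceding README example.

```
# Illustrative sketch; paths are placeholders, and `compiled_model` / `kwarg_inputs`
# are assumed to exist from the earlier compilation example.
lora_a_path = '/path/to/lora_a'
lora_b_path = '/path/to/lora_b'

# Generate with the first LoRA.
compiled_model.unet.load_attn_procs(lora_a_path)
image_a = compiled_model(**kwarg_inputs).images[0]

# Swap the weights in place and reuse the same compiled graph, without recompiling.
switch_lora(compiled_model.unet, lora_b_path)
image_b = compiled_model(**kwarg_inputs).images[0]
```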
tests/compilers/test_stable_diffusion_pipeline_compiler.py: 12 changes (5 additions, 7 deletions)
@@ -281,23 +281,21 @@ def update_state_dict(dst, src):
         # As the traced forward function shares the same reference of the tensors,
         # this modification will be reflected in the traced forward function.
         dst[key].copy_(value)
-    return dst

-# Load "another" lora into UNet
-def load_new_lora(unet, lora):
+# Switch "another" LoRA into UNet
+def switch_lora(unet, lora):
     # Store the original UNet parameters
     state_dict = unet.state_dict()
-    # Load another lora into unet
+    # Load another LoRA into unet
     unet.load_attn_procs(lora)
     # Inplace copy current UNet parameters to the original unet parameters
-    state_dict = update_state_dict(state_dict, unet.state_dict())
+    update_state_dict(state_dict, unet.state_dict())
     # Load the original UNet parameters back.
     # We use assign=True because we still want to hold the references
     # of the original UNet parameters
     unet.load_state_dict(state_dict, assign=True)
-    return unet

-compiled_model.unet = load_new_lora(compiled_model.unet, lora_b_path)
+switch_lora(compiled_model.unet, lora_b_path)

 output_image = profiler.with_cProfile(call_faster_compiled_model)()
 display_image(output_image)
