diff --git a/README.md b/README.md
index 8159de0..f3f56f6 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ __NOTE__: `stable-fast` is currently only in beta stage and is prone to be buggy
   - [Install From Source](#install-from-source)
 - [Usage](#usage)
   - [Optimize StableDiffusionPipeline](#optimize-stablediffusionpipeline)
-  - [Dynamically Load LoRA](#dynamically-load-lora)
+  - [Dynamically Switch LoRA](#dynamically-switch-lora)
   - [Some Common Methods To Speed Up PyTorch](#some-common-methods-to-speed-up-pytorch)
 - [Trouble Shooting](#trouble-shooting)
 
@@ -250,9 +250,9 @@ output_image = compiled_model(**kwarg_inputs).images[0]
 output_image = compiled_model(**kwarg_inputs).images[0]
 ```
 
-### Dynamically Load LoRA
+### Dynamically Switch LoRA
 
-Loading LoRA dynamically is supported but you need to do some extra work.
+Switching LoRA dynamically is supported but you need to do some extra work.
 It is possible because the compiled graph and `CUDA Graph` shares the same
 underlaying data (pointers) with the original UNet model. So all you need to do
 is to update the original UNet model's parameters inplace.
@@ -267,23 +267,21 @@ def update_state_dict(dst, src):
         # As the traced forward function shares the same reference of the tensors,
         # this modification will be reflected in the traced forward function.
         dst[key].copy_(value)
-    return dst
 
-# Load "another" lora into UNet
-def load_new_lora(unet, lora):
+# Switch "another" LoRA into UNet
+def switch_lora(unet, lora):
     # Store the original UNet parameters
     state_dict = unet.state_dict()
-    # Load another lora into unet
+    # Load another LoRA into unet
     unet.load_attn_procs(lora)
     # Inplace copy current UNet parameters to the original unet parameters
-    state_dict = update_state_dict(state_dict, unet.state_dict())
+    update_state_dict(state_dict, unet.state_dict())
     # Load the original UNet parameters back.
     # We use assign=True because we still want to hold the references
     # of the original UNet parameters
     unet.load_state_dict(state_dict, assign=True)
-    return unet
 
-compiled_model.uent = load_new_lora(compiled_model.unet, '/path/to/lora')
+switch_lora(compiled_model.unet, lora_b_path)
 ```
 
 ### Some Common Methods To Speed Up PyTorch
diff --git a/tests/compilers/test_stable_diffusion_pipeline_compiler.py b/tests/compilers/test_stable_diffusion_pipeline_compiler.py
index 14da3c2..fd0c167 100644
--- a/tests/compilers/test_stable_diffusion_pipeline_compiler.py
+++ b/tests/compilers/test_stable_diffusion_pipeline_compiler.py
@@ -281,23 +281,21 @@ def update_state_dict(dst, src):
             # As the traced forward function shares the same reference of the tensors,
             # this modification will be reflected in the traced forward function.
             dst[key].copy_(value)
-        return dst
 
-    # Load "another" lora into UNet
-    def load_new_lora(unet, lora):
+    # Switch "another" LoRA into UNet
+    def switch_lora(unet, lora):
         # Store the original UNet parameters
         state_dict = unet.state_dict()
-        # Load another lora into unet
+        # Load another LoRA into unet
         unet.load_attn_procs(lora)
         # Inplace copy current UNet parameters to the original unet parameters
-        state_dict = update_state_dict(state_dict, unet.state_dict())
+        update_state_dict(state_dict, unet.state_dict())
         # Load the original UNet parameters back.
         # We use assign=True because we still want to hold the references
         # of the original UNet parameters
         unet.load_state_dict(state_dict, assign=True)
-        return unet
 
-    compiled_model.unet = load_new_lora(compiled_model.unet, lora_b_path)
+    switch_lora(compiled_model.unet, lora_b_path)
 
     output_image = profiler.with_cProfile(call_faster_compiled_model)()
     display_image(output_image)
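
The trick this patch documents (and that the updated test exercises) can be reproduced with plain PyTorch. The sketch below is not part of the patch: it stands a toy `torch.nn.Linear` in for the UNet and `torch.jit.trace` in for stable-fast's compiled graph, and only illustrates why an in-place `copy_` followed by `load_state_dict(..., assign=True)` (available since PyTorch 2.1) keeps a captured graph in sync with newly loaded weights. It is also why the `return dst` / `return unet` lines are dropped: both helpers now work purely through in-place mutation of tensors the compiled graph already references.

```python
import torch


def update_state_dict(dst, src):
    for key, value in src.items():
        # In-place copy: the tensors the traced graph references get the new values.
        dst[key].copy_(value)


model = torch.nn.Linear(4, 4).eval()
traced = torch.jit.trace(model, torch.randn(1, 4))

# Remember the original parameter tensors (state_dict() returns views of them,
# sharing storage with the tensors captured by the traced graph).
original_state = model.state_dict()

# Simulate "loading another LoRA": the eager module ends up with brand-new
# tensors that the traced graph knows nothing about.
model.weight = torch.nn.Parameter(torch.randn(4, 4))
model.bias = torch.nn.Parameter(torch.randn(4))

# Copy the new values into the original tensors, then re-attach those tensors
# to the module; assign=True keeps their identity instead of allocating copies,
# so the eager module and the traced graph stay in sync.
update_state_dict(original_state, model.state_dict())
model.load_state_dict(original_state, assign=True)

x = torch.randn(1, 4)
assert torch.allclose(traced(x), model(x))
```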