Skip to content

Commit

Permalink
Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers (#…
Browse files Browse the repository at this point in the history
…9915)

* Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers to fix bug on updating keys and old_state_dict


---------

Co-authored-by: raul_ar <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent 1235862 commit 3139d39
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,15 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.single_transformer_blocks.{block_num}"

if "proj_lora1" in old_key or "proj_lora2" in old_key:
if "proj_lora" in old_key:
new_key += ".proj_out"
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
new_key += ".norm.linear"
elif "qkv_lora" in old_key and "up" not in old_key:
handle_qkv(
old_state_dict,
new_state_dict,
old_key,
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
)

if "down" in old_key:
new_key += ".lora_A.weight"
Expand Down
25 changes: 25 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,28 @@ def test_flux_xlabs(self):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3

def test_flux_xlabs_load_lora_with_single_blocks(self):
self.pipeline.load_lora_weights(
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()

prompt = "a wizard mouse playing chess"

out = self.pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array(
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3

0 comments on commit 3139d39

Please sign in to comment.