Flux Control LoRA #9999

Merged 79 commits on Dec 10, 2024.

The diff below shows changes from 62 of the 79 commits.

Commits (79)
2829679  update (a-r-r-o-w, Nov 21, 2024)
be67dbd  Merge branch 'main' into flux-new (a-r-r-o-w, Nov 21, 2024)
f56ffb1  add (yiyixuxu, Nov 21, 2024)
7e4df06  update (a-r-r-o-w, Nov 21, 2024)
9ea52da  Merge remote-tracking branch 'origin/flux-fill-yiyi' into flux-new (a-r-r-o-w, Nov 21, 2024)
217e90c  add control-lora conversion script; make flux loader handle norms; fi… (a-r-r-o-w, Nov 21, 2024)
b4f1cbf  control lora updates (a-r-r-o-w, Nov 22, 2024)
414b30b  remove copied-from (a-r-r-o-w, Nov 22, 2024)
6b02ac2  create separate pipelines for flux control (a-r-r-o-w, Nov 22, 2024)
3169bf5  make fix-copies (a-r-r-o-w, Nov 22, 2024)
f7f006d  update docs (a-r-r-o-w, Nov 22, 2024)
8bb940e  add tests (a-r-r-o-w, Nov 22, 2024)
9e615fd  fix (a-r-r-o-w, Nov 22, 2024)
6d168db  Merge branch 'main' into flux-new (a-r-r-o-w, Nov 22, 2024)
89fd970  Apply suggestions from code review (a-r-r-o-w, Nov 22, 2024)
73cfc51  remove control lora changes (a-r-r-o-w, Nov 22, 2024)
c94966f  apply suggestions from review (a-r-r-o-w, Nov 22, 2024)
cfe13e7  Revert "remove control lora changes" (a-r-r-o-w, Nov 22, 2024)
0c959a7  update (a-r-r-o-w, Nov 23, 2024)
6ef2c8b  update (a-r-r-o-w, Nov 23, 2024)
42970ee  improve log messages (a-r-r-o-w, Nov 23, 2024)
2ec93ba  Merge branch 'main' into flux-control-lora (a-r-r-o-w, Nov 23, 2024)
993f3d3  Merge branch 'main' into flux-control-lora (sayakpaul, Nov 25, 2024)
6523fa6  updates. (sayakpaul, Nov 25, 2024)
81ab40b  updates (sayakpaul, Nov 25, 2024)
4432e73  Merge branch 'main' into flux-control-lora (sayakpaul, Nov 25, 2024)
0f747c0  Merge branch 'main' into flux-control-lora (sayakpaul, Nov 26, 2024)
6d0c6dc  Merge branch 'flux-control-lora' into sayak-flux-control-lora (sayakpaul, Nov 26, 2024)
1633619  support register_config. (sayakpaul, Nov 26, 2024)
b9039b1  fix (sayakpaul, Nov 26, 2024)
5f94d74  fix (sayakpaul, Nov 26, 2024)
bd31651  fix (sayakpaul, Nov 26, 2024)
e18b7ad  Merge branch 'main' into flux-control-lora (a-r-r-o-w, Nov 27, 2024)
f54ec56  updates (sayakpaul, Nov 28, 2024)
8032405  updates (sayakpaul, Nov 28, 2024)
6b70bf7  updates (sayakpaul, Nov 28, 2024)
3726e2d  fix-copies (sayakpaul, Nov 28, 2024)
b6ca9d9  Merge branch 'main' into flux-control-lora (sayakpaul, Nov 28, 2024)
908d151  fix (sayakpaul, Nov 29, 2024)
6af2097  Merge branch 'main' into flux-control-lora (sayakpaul, Nov 29, 2024)
07d44e7  apply suggestions from review (a-r-r-o-w, Dec 1, 2024)
b66e691  add tests (a-r-r-o-w, Dec 1, 2024)
66d7466  remove conversion script; enable on-the-fly conversion (a-r-r-o-w, Dec 2, 2024)
d827d1e  Merge branch 'main' into flux-control-lora (a-r-r-o-w, Dec 2, 2024)
64c821b  bias -> lora_bias. (sayakpaul, Dec 2, 2024)
30a89a6  fix-copies (sayakpaul, Dec 2, 2024)
bca1eaa  peft.py (sayakpaul, Dec 2, 2024)
6ce181b  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 2, 2024)
e7df197  fix lora conversion (a-r-r-o-w, Dec 2, 2024)
5fd9fda  changes (sayakpaul, Dec 3, 2024)
a8c50ba  fix-copies (sayakpaul, Dec 3, 2024)
b12f797  updates for tests (sayakpaul, Dec 3, 2024)
f9bd3eb  fix (sayakpaul, Dec 3, 2024)
6b35c92  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 3, 2024)
84c168c  alpha_pattern. (sayakpaul, Dec 4, 2024)
118ed9b  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 4, 2024)
be1d788  add a test for varied lora ranks and alphas. (sayakpaul, Dec 4, 2024)
5b1bcd8  revert changes in num_channels_latents = self.transformer.config.in_c… (sayakpaul, Dec 4, 2024)
cde01e3  revert moe (sayakpaul, Dec 4, 2024)
79af91d  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 5, 2024)
4b3efcc  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 5, 2024)
f688ecf  add a sanity check on unexpected keys when loading norm layers. (sayakpaul, Dec 5, 2024)
d6518b7  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 6, 2024)
ecbc4cb  fixes (sayakpaul, Dec 6, 2024)
55058e2  tests (sayakpaul, Dec 6, 2024)
a8bd03b  reviewer feedback (sayakpaul, Dec 6, 2024)
49c0242  fix (sayakpaul, Dec 6, 2024)
8b050ea  proper peft version for lora_bias (sayakpaul, Dec 6, 2024)
3204627  fix-copies (sayakpaul, Dec 6, 2024)
2b9bfa3  Merge branch 'main' into flux-control-lora (a-r-r-o-w, Dec 6, 2024)
130e592  remove debug code (a-r-r-o-w, Dec 6, 2024)
b20ec7d  update docs (a-r-r-o-w, Dec 6, 2024)
79d023a  Merge branch 'main' into flux-control-lora (sayakpaul, Dec 7, 2024)
d1715d3  integration tests (sayakpaul, Dec 7, 2024)
cbad4b3  nis (sayakpaul, Dec 7, 2024)
cd7c155  fuse and unload. (sayakpaul, Dec 7, 2024)
25616e2  fix (sayakpaul, Dec 7, 2024)
0b83deb  add slices. (sayakpaul, Dec 7, 2024)
60a68e2  Merge branch 'main' into flux-control-lora (a-r-r-o-w, Dec 10, 2024)
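As of commit 66d7466 the standalone conversion script is dropped in favor of converting the BFL-format Control LoRA to diffusers keys on the fly when it is loaded. The sketch below illustrates that flow; `FluxControlPipeline` comes from this PR, but the checkpoint IDs, control image URL, prompt, and generation settings are illustrative assumptions rather than part of this diff.

```python
# Sketch only: checkpoint IDs and settings below are assumptions for illustration.
# The BFL-format LoRA state dict is converted to diffusers keys at load time.
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")  # converted on the fly

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
image = pipe(
    prompt="a robot made of exotic candies",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
```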
306 changes: 306 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -663,3 +663,309 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

return new_state_dict


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
num_layers = 19
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0

def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight

for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
converted_state_dict[
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")

converted_state_dict[
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")

## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
f"vector_in.in_layer.{lora_key}.weight"
)
if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
f"vector_in.in_layer.{lora_key}.bias"
)

converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
f"vector_in.out_layer.{lora_key}.weight"
)
if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
f"vector_in.out_layer.{lora_key}.bias"
)

# guidance
has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance:
converted_state_dict[
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")

converted_state_dict[
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")

# context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
f"txt_in.{lora_key}.weight"
)
if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
f"txt_in.{lora_key}.bias"
)

# x_embedder
converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
if f"img_in.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")

# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."

for lora_key in ["lora_A", "lora_B"]:
# norms
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
)
if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
)
if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
)

# Q, K, V
if lora_key == "lora_A":
sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])

context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
else:
sample_q, sample_k, sample_v = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])

context_q, context_k, context_v = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])

if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])

if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])

# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
)
if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
)
if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
)
if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
)
if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
)

# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
)
if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
)
if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
)

# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)

# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."

for lora_key in ["lora_A", "lora_B"]:
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
)
if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
)

# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

if lora_key == "lora_A":
lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])

if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
else:
q, k, v, mlp = torch.split(
original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])

if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
q_bias, k_bias, v_bias, mlp_bias = torch.split(
original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])

# output projections.
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear2.{lora_key}.weight"
)
if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear2.{lora_key}.bias"
)

# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)

for lora_key in ["lora_A", "lora_B"]:
converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"final_layer.linear.{lora_key}.weight"
)
if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"final_layer.linear.{lora_key}.bias"
)

converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
)
if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
)

if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict
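For reference, the BFL checkpoints use fused projections (qkv in the double blocks, qkv+mlp in the single blocks), which is why the converter copies lora_A to each target module but chunks lora_B along dim 0: lora_A maps the input to the rank and is shared by every split, while lora_B maps the rank to the fused output and therefore splits row-wise. A minimal toy-sized sketch (not part of this diff) demonstrating the equivalence:

```python
# Toy-sized sketch (not part of the PR): splitting lora_B row-wise while sharing
# lora_A reproduces the fused projection's LoRA delta exactly.
import torch

inner_dim, rank = 8, 4                     # toy sizes; the real model uses inner_dim=3072
lora_A = torch.randn(rank, inner_dim)      # maps input -> rank; shared by q, k, v
lora_B = torch.randn(3 * inner_dim, rank)  # maps rank -> fused q|k|v output

fused_delta = lora_B @ lora_A              # LoRA delta of the fused qkv projection

# The converter keeps lora_A as-is for each projection and chunks lora_B along dim 0.
q_B, k_B, v_B = torch.chunk(lora_B, 3, dim=0)
per_proj_delta = torch.cat([q_B @ lora_A, k_B @ lora_A, v_B @ lora_A], dim=0)

assert torch.allclose(fused_delta, per_proj_delta)
```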