Skip to content

Commit

Permalink
[Single File] Support loading Comfy UI Flux checkpoints (#9243)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
DN6 authored Aug 23, 2024
1 parent 2d9ccf3 commit 255ac59
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -258,7 +261,7 @@
"timestep_spacing": "leading",
}

LDM_VAE_KEY = "first_stage_model."
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
Expand All @@ -267,7 +270,6 @@
"cond_stage_model.transformer.",
"conditioner.embedders.0.transformer.",
]
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]

Expand Down Expand Up @@ -523,8 +525,10 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "animatediff_v3"

elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
if "guidance_in.in_layer.bias" in checkpoint:
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
model_type = "flux-dev"
else:
model_type = "flux-schnell"
Expand Down Expand Up @@ -1183,7 +1187,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
vae_key = ""
for ldm_vae_key in LDM_VAE_KEYS:
if any(k.startswith(ldm_vae_key) for k in keys):
vae_key = ldm_vae_key

for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
Expand Down Expand Up @@ -1896,6 +1904,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):

def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
Expand Down

0 comments on commit 255ac59

Please sign in to comment.