Convert From XLab controlnet to diffusers format script #9301

Closed
chuck-ma opened this issue Aug 28, 2024 · 6 comments

chuck-ma commented Aug 28, 2024

Is your feature request related to a problem? Please describe.
Converting XLab ControlNet checkpoints to the diffusers format.

Describe the solution you'd like.
A script that converts XLab ControlNet checkpoints to the diffusers format.

Additional context.

I use the following script:

import torch
import safetensors.torch
from huggingface_hub import hf_hub_download

def convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    converted_state_dict = {}

    ## time_text_embed.timestep_embedder <-  time_in
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_in.out_layer.bias"
    )

    ## time_text_embed.text_embedder <- vector_in
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
        "vector_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
        "vector_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
        "vector_in.out_layer.bias"
    )

    # guidance
    has_guidance = any("guidance" in k for k in original_state_dict)
    if has_guidance:
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
            "guidance_in.in_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
            "guidance_in.in_layer.bias"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
            "guidance_in.out_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
            "guidance_in.out_layer.bias"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )
        
        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict.pop(
            f"controlnet_blocks.{i}.weight"
        )
        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict.pop(
            f"controlnet_blocks.{i}.bias"
        )

    # single transformer blocks (num_single_layers is 0 for this checkpoint, so nothing to convert)
    for i in range(num_single_layers):
        pass

    return converted_state_dict
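
A quick sanity check on the conversion (a small sketch, not part of the original script, and it assumes the checkpoint contains the two double blocks the function expects): the function pops keys in place, so anything still left in the state dict afterwards was never mapped and hints at parts of the XLab checkpoint this function does not cover.

import safetensors.torch
from huggingface_hub import hf_hub_download

# Sketch: run the conversion and report any checkpoint keys it did not consume.
ckpt_path = hf_hub_download("XLabs-AI/flux-controlnet-depth-v3",
                            filename="flux-depth-controlnet-v3.safetensors")
state_dict = safetensors.torch.load_file(ckpt_path)
convert_flux_transformer_checkpoint_to_diffusers(
    state_dict, num_layers=2, num_single_layers=0, inner_dim=3072
)
print(f"{len(state_dict)} checkpoint keys were not converted:")
for key in sorted(state_dict):
    print("  ", key)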

Here is how I convert it to the diffusers format:

import torch
import safetensors.torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download


ckpt_path = hf_hub_download("XLabs-AI/flux-controlnet-depth-v3", 
                            filename="flux-depth-controlnet-v3.safetensors")

original_state_dict = safetensors.torch.load_file(ckpt_path)

num_layers = 2
num_single_layers = 0
inner_dim = 3072
mlp_ratio = 4.0
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
            original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )
print("begin init")

controlnet = FluxControlNetModel(guidance_embeds=True, num_single_layers=0, num_layers=2).to(torch.float16)
print("finish init")
controlnet.load_state_dict(converted_transformer_state_dict, strict=False)
print("finish load state dict")

Then I load it directly with the diffusers API. It works; however, the control from the depth map seems to have no effect. I wonder why.
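
One quick way to narrow this down (a minimal diagnostic sketch, reusing controlnet and converted_transformer_state_dict from the snippet above): strict=False silently skips mismatches, and any parameter listed under missing_keys keeps its initialization value, which can leave the conditioning path effectively disabled.

# Diagnostic sketch: see which parameters the conversion did not cover.
result = controlnet.load_state_dict(converted_transformer_state_dict, strict=False)
print("missing keys (left at their init values):", result.missing_keys)
print("unexpected keys (silently ignored):", result.unexpected_keys)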

base_model = "black-forest-labs/FLUX.1-dev"

pipe = FluxControlNetPipeline.from_pretrained(
    base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
)

pipe.enable_model_cpu_offload()

control_image = load_image("https://hf-mirror.com/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/depth.jpg")

prompt = "Realistic style, a photo of a girl  in the sea"

width, height = 1024, 1024
controlnet_conditioning_scale = 1.0


image = pipe(
    prompt, 
    control_image=control_image,
    width=width,
    height=height,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_inference_steps=28, 
    guidance_scale=3.5,
).images[0]

from diffusers.utils import make_image_grid

make_image_grid([control_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)

The output looks like this:

[output image]

Any advice, @sayakpaul?

sayakpaul (Member) commented Aug 28, 2024

Are you using the identical settings from the original checkpoints i.e., same guidance scale, number of inference steps, etc.? I don't see which model checkpoint you're exactly using, though.

Ccing @DN6. It might make sense to have this supported through from_single_file().

chuck-ma (Author) commented Aug 29, 2024

Are you using the identical settings from the original checkpoints i.e., same guidance scale, number of inference steps, etc.? I don't see which model checkpoint you're exactly using, though.

Ccing @DN6. It might make sense to have this supported through from_single_file().

I'm using this model: https://huggingface.co/XLabs-AI/flux-controlnet-depth-v3
I will check if the settings are identical.

Besides, I'm converting as follows:

import torch
import safetensors.torch
from huggingface_hub import hf_hub_download


ckpt_path = hf_hub_download("XLabs-AI/flux-controlnet-depth-v3", 
                            filename="flux-depth-controlnet-v3.safetensors")

original_state_dict = safetensors.torch.load_file(ckpt_path)

num_layers = 2
num_single_layers = 0
inner_dim = 3072
mlp_ratio = 4.0
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
            original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )

chenbinghui1 commented Aug 29, 2024

@chuck-ma Please check the forward code; the processing in the XLab ControlNet is different from diffusers, especially in the "controlnet residual" part.
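
For reference, the "controlnet residual" difference is roughly this: the XLab forward pass cycles its (here two) ControlNet block outputs over all of the base model's double blocks with a modulo index, while the diffusers Flux transformer at the time spread them across the blocks with an interval index. The sketch below only illustrates the two indexing schemes; names and shapes are made up, and it is not the exact code of either repository.

import math
import torch

# Illustrative sketch only: tiny dummy tensors stand in for the real hidden states.
num_blocks = 19                                        # double blocks in the base Flux transformer
residuals = [torch.randn(1, 8, 16) for _ in range(2)]  # two ControlNet block outputs
hidden_states = torch.zeros(1, 8, 16)

# XLab-style application: alternate the two residuals block by block.
for index_block in range(num_blocks):
    hidden_states = hidden_states + residuals[index_block % len(residuals)]

# diffusers-style application (at the time of this issue): spread them at intervals,
# so roughly the first half of the blocks sees residuals[0] and the rest sees residuals[1].
interval = math.ceil(num_blocks / len(residuals))  # 10 here
for index_block in range(num_blocks):
    hidden_states = hidden_states + residuals[index_block // interval]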

sayakpaul (Member) commented:

@chenbinghui1 thanks for your insights. Would you like to take a stab at opening a PR to mitigate the differences? We'd, of course, be more than happy to provide guidance in the process.

github-actions bot commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label Sep 27, 2024
a-r-r-o-w removed the stale label Oct 15, 2024
a-r-r-o-w (Member) commented:

Supported in #9687, thanks to @Anghellia! Marking the issue as closed, but feel free to reopen if something's missing.
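
For reference, after that PR the XLabs checkpoint should load without a hand-written conversion, along these lines (a sketch; assumes a diffusers release that includes the single-file support added in #9687):

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline

# Sketch: load the XLabs depth ControlNet directly from its single .safetensors file.
controlnet = FluxControlNetModel.from_single_file(
    "https://huggingface.co/XLabs-AI/flux-controlnet-depth-v3/blob/main/flux-depth-controlnet-v3.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)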
