Add support of Xlabs Controlnets #9638

Closed
Anghellia wants to merge 7 commits

Anghellia
Contributor

@Anghellia Anghellia commented Oct 10, 2024

What does this PR do?

Hi!
This PR adds support for Xlabs ControlNets so that they can be used with Diffusers.
We converted the checkpoints to the Diffusers format; they can be downloaded here:

The request: #9378

Who can review?

Anyone in the community is free to review the PR once the tests have passed.
@sayakpaul

How to use

Here is example code for running the Canny ControlNet.

import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from diffusers.utils import load_image

# Fixed seed so the result is reproducible.
generator = torch.Generator(device="cuda").manual_seed(87544357)

# Load the Xlabs Canny ControlNet in the Diffusers format.
controlnet = FluxControlNetModel.from_pretrained(
    "Xlabs-AI/flux-controlnet-canny-diffusers",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
# Attach the ControlNet to the FLUX.1-dev base pipeline.
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Canny edge map used as the conditioning image.
control_image = load_image("https://huggingface.co/Xlabs-AI/flux-controlnet-canny-diffusers/resolve/main/canny_example.png")
prompt = "handsome girl with rainbow hair, anime"

image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.7,
    num_inference_steps=25,
    guidance_scale=3.5,
    height=1024,
    width=768,
    generator=generator,
    num_images_per_prompt=1,
).images[0]

image.save("output_test_controlnet.png")

Examples

Group 1
"photo of village in the winter"

Group 2
"it programmer sitting in the office"

Group 3
"couple of man and woman in the water, dancing"

Group 4
"photo of woman in the beach"

Group 5
"futuristic bulding in the spain"

Group 6
"2d art, girl in the magic city, sparkles, fantasy"

Member

@a-r-r-o-w a-r-r-o-w left a comment

Just some minor comments that you can address/ignore based on what Sayak has to say

src/diffusers/models/controlnet_flux.py (outdated, resolved)
src/diffusers/models/controlnet_flux.py (outdated, resolved)
@@ -773,6 +773,17 @@ def __call__(
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)

elif isinstance(self.controlnet, FluxControlNetModel) and self.controlnet.is_xlabs_controlnet:
Member

I think these changes will have to be propagated to other pipeline files as well? Maybe you could rewrite this as:

if isinstance(self.controlnet, FluxControlNetModel):
    control_image = self.prepare_image(...)

    if self.controlnet.is_xlabs_controlnet:
        # remaining mismatching logic
        ...

Contributor Author

Please check it now.

@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul October 10, 2024 21:58
@sayakpaul
Member

@Anghellia thanks so much for this <3

Could you supplement this PR with an example code snippet and some resultant images? Ccing @asomoza for doing a test drive, too.

Member

@sayakpaul sayakpaul left a comment

Beautiful PR.

src/diffusers/models/transformers/transformer_flux.py (outdated, resolved)
@@ -55,6 +56,7 @@ def __init__(
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
is_xlabs_controlnet: bool = False,
Member

Is it possible to determine if a ControlNet is of type xlabs? If not, then it's fine!

Contributor Author

Hmm, I am not sure. Actually, xlabs ControlNets are not so different from others. I see two main changes:

  1. We use num_layers=2 (the depth of ControlNet is 2). However, I don't think it's correct to rely solely on this, as one could also train a ControlNet with num_layers=2 using the diffusers script.

  2. We use input_hint_block, but we can specify this only if we check for a specific keyword in the model's state_dict. I think this may not apply in our case.
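To make the two signals above concrete, here is a minimal sketch of a combined check. It is a heuristic only, not code from this PR: the function name is hypothetical, and the "input_hint_block" key prefix and the example key names are assumptions based on the attribute name mentioned above.

```python
from typing import Iterable


def looks_like_xlabs_controlnet(state_dict_keys: Iterable[str], num_layers: int) -> bool:
    """Hypothetical heuristic combining both signals from the discussion."""
    # Signal 1: a depth of 2. Not sufficient alone, since anyone can train
    # a two-layer ControlNet with the diffusers script.
    has_xlabs_depth = num_layers == 2
    # Signal 2: the checkpoint contains weights under an input hint block.
    # The "input_hint_block" prefix is an assumption, not a confirmed key name.
    has_hint_block = any(key.startswith("input_hint_block") for key in state_dict_keys)
    return has_xlabs_depth and has_hint_block


# Illustrative key names only; real checkpoints may differ.
xlabs_like_keys = ["input_hint_block.conv_in.weight", "transformer_blocks.0.attn.to_q.weight"]
other_keys = ["controlnet_x_embedder.weight", "transformer_blocks.0.attn.to_q.weight"]

print(looks_like_xlabs_controlnet(xlabs_like_keys, num_layers=2))
print(looks_like_xlabs_controlnet(other_keys, num_layers=2))
```

A check like this needs the state_dict at hand, which is the limitation raised in point 2; an explicit config value avoids that dependency.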

Member

Maybe both could be combined to create a condition to determine if it's an Xlabs ControlNet? @yiyixuxu would love to know your thoughts here.

Member

Never mind, I think #9638 (comment) should seal the deal for us unless there are some differences in the forward method.

@Anghellia
Contributor Author

@Anghellia thanks so much for this <3

Could you supplement this PR with an example code snippet and some resultant images? Ccing @asomoza for doing a test drive, too.

Thank you! Updated the PR with examples 🤗

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks a lot for the PR! I left some suggestions, let me know if they would work!

@@ -55,6 +55,7 @@ def __init__(
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
is_xlabs_controlnet: bool = False,
Collaborator

Suggested change
- is_xlabs_controlnet: bool = False,
+ conditioning_embedding_channels: int = None,

Collaborator

we can add a new config conditioning_embedding_channels to the flux controlnet that defaults to None
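As a rough illustration of this suggestion, the sketch below shows how an optional config value that defaults to None can replace a dedicated boolean flag. The class and function names are hypothetical, and the default for num_layers and the value 16 are placeholders, not values taken from the PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ControlNetConfigSketch:
    """Hypothetical stand-in for the ControlNet config, not the real class."""

    num_layers: int = 4  # placeholder default, for illustration only
    # None means "no conditioning embedding"; an integer enables it.
    conditioning_embedding_channels: Optional[int] = None


def uses_input_hint_block(config: ControlNetConfigSketch) -> bool:
    # The presence of the value doubles as the switch, so a separate
    # is_xlabs_controlnet flag is no longer needed.
    return config.conditioning_embedding_channels is not None


default_config = ControlNetConfigSketch()
xlabs_style_config = ControlNetConfigSketch(num_layers=2, conditioning_embedding_channels=16)

print(uses_input_hint_block(default_config))
print(uses_input_hint_block(xlabs_style_config))
```

The design choice is that the config describes an architectural property rather than naming a vendor, so any ControlNet trained with a hint block can reuse the same path.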

Contributor Author

Thanks for your suggestion, it works!

src/diffusers/models/controlnet_flux.py (outdated, resolved)
src/diffusers/models/controlnet_flux.py (outdated, resolved)
src/diffusers/models/controlnet_flux.py (outdated, resolved)
src/diffusers/models/controlnet_flux.py (outdated, resolved)
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py (outdated, resolved)
@RimoChan

RimoChan commented Oct 14, 2024

I encountered an issue where the FluxMultiControlNetModel is not compatible with this branch. I used the following command to install diffusers:

pip install git+https://github.com/XLabs-AI/diffusers.git@xlabs_controlnet_support

After installation, I tried executing the following code:

import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, FluxMultiControlNetModel
from diffusers.pipelines import FluxControlNetPipeline



generator = torch.Generator(device="cuda").manual_seed(87544357)

controlnet = FluxMultiControlNetModel([
    FluxControlNetModel.from_pretrained(
        "Xlabs-AI/flux-controlnet-canny-diffusers",
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    ),
    FluxControlNetModel.from_pretrained(
        "Xlabs-AI/flux-controlnet-canny-diffusers",
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    ),
])
pipe = FluxControlNetPipeline.from_pretrained(
  '/mypath/to/FLUX.1-dev_official',
  controlnet=controlnet,
  torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image("https://huggingface.co/Xlabs-AI/flux-controlnet-canny-diffusers/resolve/main/canny_example.png")

image = pipe(
    "handsome girl with rainbow hair, anime",
    control_image=[control_image, control_image],
    controlnet_conditioning_scale=[0.7, 0.7],
    num_inference_steps=25,
    guidance_scale=3.5,
    height=1024,
    width=768,
    generator=generator,
    num_images_per_prompt=1,
).images[0]

image.save("output_test_controlnet.png")

However, I encountered this error:

Traceback (most recent call last):
  File "/opt/tiger/test_1/t2s.py", line 33, in <module>
    image = pipe(
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 897, in __call__
    controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/diffusers/models/controlnet_flux.py", line 503, in forward
    block_samples, single_block_samples = controlnet(
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/diffusers/models/controlnet_flux.py", line 281, in forward
    controlnet_cond = self.input_hint_block(controlnet_cond)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/diffusers/models/controlnet.py", line 99, in forward
    embedding = self.conv_in(conditioning)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/tiger/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[1, 1, 3072, 64] to have 3 channels, but got 1 channels instead

Could you please investigate this incompatibility?
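One plausible reading of the shape in the error: input[1, 1, 3072, 64] matches a control image that has already been encoded to latents and packed into tokens, instead of a 3-channel RGB image that the input hint block's convolution expects. The arithmetic below is a sketch under assumed values (an 8x VAE downscale, 16 latent channels, and 2x2 patch packing); the function name is hypothetical.

```python
def packed_latent_shape(
    height: int,
    width: int,
    vae_scale_factor: int = 8,  # assumed spatial downscale of the VAE
    latent_channels: int = 16,  # assumed number of latent channels
    patch_size: int = 2,        # assumed 2x2 packing of latent pixels into tokens
) -> tuple:
    """Return (num_tokens, token_channels) for a packed latent image."""
    latent_h = height // vae_scale_factor
    latent_w = width // vae_scale_factor
    # Each patch_size x patch_size group of latent pixels becomes one token.
    num_tokens = (latent_h // patch_size) * (latent_w // patch_size)
    token_channels = latent_channels * patch_size * patch_size
    return num_tokens, token_channels


# Same resolution as the failing call: height=1024, width=768.
print(packed_latent_shape(1024, 768))  # (3072, 64)
```

With these assumptions the result equals the trailing dimensions in the error message, which would suggest that the multi-ControlNet path packs the control image before handing it to a model that expects raw pixels.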

@sayakpaul
Member

MultiControlNet compatibility can be incorporated after this initial PR is merged.

@Anghellia Anghellia requested a review from yiyixuxu October 14, 2024 09:07
Collaborator

@yiyixuxu yiyixuxu left a comment

thanks, I left one more piece of feedback
let's merge this soon!

@@ -508,7 +508,11 @@ def custom_forward(*inputs):
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
# For Xlabs ControlNet.
if len(controlnet_block_samples) == 2:
Collaborator

I would suggest passing a flag down here, controlnet_repeat_interleave = False maybe?
This would break if someone trained a controlnet with 2 blocks but wants to use the other indexing method

Contributor Author

Agree
Updated with controlnet_blocks_repeat flag
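The two indexing schemes under discussion can be compared side by side. This is a minimal sketch, not the code from the PR: the function name is hypothetical, the default branch mirrors the interval_control arithmetic in the diff above, and the repeat branch assumes that controlnet_blocks_repeat cycles through the samples with a modulo.

```python
import math
from typing import List


def controlnet_sample_indices(
    num_transformer_blocks: int,
    num_controlnet_samples: int,
    controlnet_blocks_repeat: bool = False,
) -> List[int]:
    """For each transformer block, return which ControlNet sample is added."""
    if controlnet_blocks_repeat:
        # Assumed behavior of the flag: cycle through samples (0, 1, 0, 1, ...).
        return [i % num_controlnet_samples for i in range(num_transformer_blocks)]
    # Default behavior from the diff: consecutive blocks share one sample.
    interval_control = int(math.ceil(num_transformer_blocks / num_controlnet_samples))
    return [i // interval_control for i in range(num_transformer_blocks)]


print(controlnet_sample_indices(6, 2))                                 # [0, 0, 0, 1, 1, 1]
print(controlnet_sample_indices(6, 2, controlnet_blocks_repeat=True))  # [0, 1, 0, 1, 0, 1]
```

An explicit flag keeps the choice with the caller, so a ControlNet trained with two blocks under the default scheme is not silently switched to the other indexing.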

@Anghellia
Contributor Author

I encountered an issue where the FluxMultiControlNetModel is not compatible with this branch. [...] Could you please investigate this incompatibility?

Fixed

Member

@a-r-r-o-w a-r-r-o-w left a comment

I think this looks good to merge once @yiyixuxu gives a final review!

For the failing style tests, could you run make style and push? Thanks

@sayakpaul
Member

We could add some tests in a follow-up PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Anghellia
Contributor Author

@a-r-r-o-w please launch tests

yiyixuxu added a commit that referenced this pull request Oct 15, 2024
* Add support of Xlabs Controlnets


---------

Co-authored-by: Anzhella Pankratova <[email protected]>
@yiyixuxu
Collaborator

hey thanks for the PR!
I merged it in here #9687 since I cannot push into your PR
It is branched off your PR so all your commits are there and you're an author there:)

@yiyixuxu
Collaborator

closing the PR now since we already merged it!

@yiyixuxu yiyixuxu closed this Oct 19, 2024
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Add support of Xlabs Controlnets


---------

Co-authored-by: Anzhella Pankratova <[email protected]>