Skip to content

Commit

Permalink
[authored by @Anghellia) Add support of Xlabs Controlnets #9638 (#9687)
Browse files Browse the repository at this point in the history
* Add support of Xlabs Controlnets


---------

Co-authored-by: Anzhella Pankratova <[email protected]>
  • Loading branch information
yiyixuxu and Anghellia authored Oct 15, 2024
1 parent 2ffbb88 commit 3e9a28a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 32 deletions.
22 changes: 20 additions & 2 deletions src/diffusers/models/controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
super().__init__()
self.out_channels = in_channels
Expand Down Expand Up @@ -106,7 +107,14 @@ def __init__(
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
if conditioning_embedding_channels is not None:
self.input_hint_block = ControlNetConditioningEmbedding(
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
)
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
else:
self.input_hint_block = None
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

self.gradient_checkpointing = False

Expand Down Expand Up @@ -269,6 +277,16 @@ def forward(
)
hidden_states = self.x_embedder(hidden_states)

if self.input_hint_block is not None:
controlnet_cond = self.input_hint_block(controlnet_cond)
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
height = height_pw // self.config.patch_size
width = width_pw // self.config.patch_size
controlnet_cond = controlnet_cond.reshape(
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
)
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def forward(
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Expand Down Expand Up @@ -508,7 +509,13 @@ def custom_forward(*inputs):
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

Expand Down
63 changes: 34 additions & 29 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,19 +754,22 @@ def __call__(
)
height, width = control_image.shape[-2:]

# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
if self.controlnet.input_hint_block is None:
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

# Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None:
Expand All @@ -777,8 +780,9 @@ def __call__(

elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []

for control_image_ in control_image:
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
Expand All @@ -790,20 +794,20 @@ def __call__(
)
height, width = control_image_.shape[-2:]

# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

if self.controlnet.nets[0].input_hint_block is None:
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)

control_image = control_images
Expand Down Expand Up @@ -927,6 +931,7 @@ def __call__(
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]

# compute the previous noisy sample x_t -> x_t-1
Expand Down

0 comments on commit 3e9a28a

Please sign in to comment.