Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement control_guidance_start and control_guidance_end for Flux ControlNet #9571

Merged
merged 9 commits into from
Oct 10, 2024
42 changes: 39 additions & 3 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
>>> image = pipe(
... prompt,
... control_image=control_image,
... controlnet_conditioning_scale=0.6,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=1.0,
... num_inference_steps=28,
... guidance_scale=3.5,
... ).images[0]
Expand Down Expand Up @@ -572,6 +574,8 @@ def __call__(
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
ighoshsubho marked this conversation as resolved.
Show resolved Hide resolved
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
Expand Down Expand Up @@ -614,6 +618,10 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The fraction of total steps (a value between 0.0 and 1.0) at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The fraction of total steps (a value between 0.0 and 1.0) at which the ControlNet stops applying.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
Expand Down Expand Up @@ -674,6 +682,17 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
Expand Down Expand Up @@ -839,7 +858,16 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

# 6. Denoising loop
# 6. Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
Expand All @@ -856,12 +884,20 @@ def __call__(
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@
... prompt,
... image=init_image,
... control_image=control_image,
... controlnet_conditioning_scale=0.6,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=1.0,
... strength=0.7,
... num_inference_steps=2,
... guidance_scale=3.5,
Expand Down Expand Up @@ -631,6 +633,8 @@ def __call__(
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -710,6 +714,17 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

self.check_inputs(
prompt,
prompt_2,
Expand Down Expand Up @@ -862,6 +877,14 @@ def __call__(
latents,
)

controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

Expand All @@ -877,11 +900,19 @@ def __call__(
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
... image=init_image,
... mask_image=mask_image,
... control_image=control_image,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=0.7,
... strength=0.7,
... num_inference_steps=28,
Expand Down Expand Up @@ -737,6 +739,8 @@ def __call__(
timesteps: List[int] = None,
num_inference_steps: int = 28,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -783,6 +787,10 @@ def __call__(
Custom timesteps to use for the denoising process.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The fraction of total steps (a value between 0.0 and 1.0) at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The fraction of total steps (a value between 0.0 and 1.0) at which the ControlNet stops applying.
control_mode (`int` or `List[int]`, *optional*):
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
Expand Down Expand Up @@ -826,6 +834,17 @@ def __call__(
global_height = height
global_width = width

if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

# 1. Check inputs
self.check_inputs(
prompt,
Expand Down Expand Up @@ -1031,6 +1050,14 @@ def __call__(
generator,
)

controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
Expand All @@ -1049,11 +1076,19 @@ def __call__(
else:
guidance = None

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down
Loading