From 3e8b63216e502a0dd7424c0fef4134187802b93f Mon Sep 17 00:00:00 2001
From: antoine-scenario <129726301+antoine-scenario@users.noreply.github.com>
Date: Wed, 10 Jan 2024 09:02:11 +0100
Subject: [PATCH] Add IP-Adapter to StableDiffusionXLControlNetImg2ImgPipeline
 (#6293)

* add IP-Adapter to StableDiffusionXLControlNetImg2ImgPipeline

Update src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Co-authored-by: YiYi Xu

fix tests

* fix failing test

---------

Co-authored-by: Patrick von Platen
Co-authored-by: Sayak Paul
---
 .../pipeline_controlnet_sd_xl_img2img.py | 77 +++++++++++++++++--
 .../test_controlnet_sdxl_img2img.py      |  2 +
 2 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index cbe39f788518..12ff9bbbfbbb 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -20,13 +20,23 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...loaders import (
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -147,7 +157,7 @@ def retrieve_latents(
 
 
 class StableDiffusionXLControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -159,6 +169,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -197,10 +208,19 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
             watermark output images. If not defined, it will default to True if the package is installed, otherwise no
             watermarker will be used.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
@@ -216,6 +236,8 @@ def __init__(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
@@ -231,6 +253,8 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -515,6 +539,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1011,6 +1060,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1109,6 +1159,7 @@ def __call__(
             Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
             weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input
             argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1262,7 +1313,7 @@ def __call__(
         )
         guess_mode = guess_mode or global_pool_conditions
 
-        # 3. Encode input prompt
+        # 3.1. Encode input prompt
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
@@ -1287,6 +1338,15 @@ def __call__(
             clip_skip=self.clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
@@ -1449,6 +1509,9 @@ def __call__(
                 down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                 mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+            if ip_adapter_image is not None:
+                added_cond_kwargs["image_embeds"] = image_embeds
+
             # predict the noise residual
             noise_pred = self.unet(
                 latent_model_input,
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
index ee8c479b1894..ff09693d103f 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
@@ -136,6 +136,8 @@ def get_dummy_components(self, skip_first_text_encoder=False):
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
         return components
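For context, a minimal usage sketch of the new `ip_adapter_image` path (not part of the patch itself). The model IDs (`stabilityai/stable-diffusion-xl-base-1.0`, `diffusers/controlnet-canny-sdxl-1.0`, `h94/IP-Adapter`) and the local image paths are illustrative assumptions, and the sketch assumes `load_ip_adapter` pulls in a matching image encoder since the pipeline is constructed without one:

```python
# Usage sketch for the IP-Adapter support added by this patch.
# Model IDs and file paths are assumptions for illustration only.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# load_ip_adapter is available because the pipeline now inherits from IPAdapterMixin
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

init_image = load_image("init.png")        # img2img starting image
control_image = load_image("canny.png")    # ControlNet conditioning image
ip_image = load_image("style_ref.png")     # IP-Adapter reference image

result = pipe(
    prompt="a castle on a hill, best quality",
    image=init_image,
    control_image=control_image,
    ip_adapter_image=ip_image,             # new argument introduced by this patch
    strength=0.8,
    num_inference_steps=30,
).images[0]
result.save("out.png")
```

At call time the wiring is exactly the two additions in the diff: `encode_image` turns `ip_adapter_image` into (negative, positive) image embeds in step 3.2, and the denoising loop hands them to the UNet through `added_cond_kwargs["image_embeds"]`.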
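One detail worth noting in step 3.2: under classifier-free guidance the unconditional image embeds are concatenated *before* the conditional ones, matching the `[negative, positive]` batch order the pipeline already uses for its prompt embeds. A tiny self-contained sketch of that convention, with made-up shapes:

```python
# Sketch of the CFG batching order used in step 3.2 (shapes are hypothetical).
import torch

batch, dim = 2, 8
image_embeds = torch.randn(batch, dim)                  # conditional embeds from encode_image
negative_image_embeds = torch.zeros_like(image_embeds)  # unconditional embeds (zeros in the plain-embeds branch)

cfg_embeds = torch.cat([negative_image_embeds, image_embeds])
assert cfg_embeds.shape == (2 * batch, dim)             # first half unconditional, second half conditional
```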