From d0a81ae604c567ad2119cd578a60a21561779958 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 Aug 2024 16:21:29 +0200 Subject: [PATCH 01/20] update --- .../animatediff/pipeline_animatediff.py | 10 ++- .../pipeline_animatediff_video2video.py | 64 +++++++++++++------ 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index a1f0374e318a..e407b06837c5 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -557,11 +557,15 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, num_frames: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, @@ -701,6 +705,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -783,6 +788,9 @@ def __call__( # 8. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 70a4201ca05c..38c0d5098447 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -628,23 +628,20 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device): def prepare_latents( self, - video, - height, - width, - num_channels_latents, - batch_size, - timestep, - dtype, - device, - generator, - latents=None, + video: Optional[torch.Tensor] = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, decode_chunk_size: int = 16, - ): - if latents is None: - num_frames = video.shape[1] - else: - num_frames = latents.shape[2] - + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] shape = ( batch_size, num_channels_latents, @@ -708,8 +705,13 @@ def prepare_latents( if shape != latents.shape: # [B, C, F, H, W] raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + latents = latents.to(device, dtype=dtype) + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + return latents @property @@ -735,6 +737,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + 
return self._interrupt + @torch.no_grad() def __call__( self, @@ -743,6 +749,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + enforce_inference_steps: bool = False, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, guidance_scale: float = 7.5, @@ -874,6 +881,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -917,11 +925,20 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, device, timesteps, sigmas + ) + num_inference_steps += 1 + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) # 5. Prepare latent variables if latents is None: @@ -942,6 +959,7 @@ def __call__( generator=generator, latents=latents, decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -970,6 +988,10 @@ def __call__( # 8. 
Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + + print("here:", t, self.scheduler.sigmas[-num_inference_steps:]) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From d55903d0b244f0a00d6c055ad4a257da82d0984b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 Aug 2024 17:20:05 +0200 Subject: [PATCH 02/20] implement prompt interpolation --- .../models/unets/unet_motion_model.py | 1 - .../animatediff/pipeline_animatediff.py | 56 ++++--- .../pipeline_animatediff_video2video.py | 110 +++++++------ src/diffusers/pipelines/free_noise_utils.py | 147 +++++++++++++++++- 4 files changed, 244 insertions(+), 70 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 73c9c70c4a11..7ac61821edd6 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2178,7 +2178,6 @@ def forward( emb = emb if aug_emb is None else emb + aug_emb emb = emb.repeat_interleave(repeats=num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index e407b06837c5..1f5ff3dd731e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -432,7 +432,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, @@ -470,8 +469,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -708,7 +707,7 @@ def __call__( self._interrupt = False # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -721,22 +720,39 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 38c0d5098447..6fcb2e30206c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -246,7 +246,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, @@ -299,7 +298,7 @@ def encode_prompt( else: scale_lora_layers(self.text_encoder, lora_scale) - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -582,8 +581,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -884,7 +883,7 @@ def __call__( self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -892,39 +891,9 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + dtype = self.dtype - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, - ip_adapter_image_embeds, - device, - batch_size * num_videos_per_prompt, - self.do_classifier_free_guidance, - ) - - # 4. Prepare timesteps + # 3. Prepare timesteps if not enforce_inference_steps: timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas @@ -936,16 +905,15 @@ def __call__( timesteps, denoising_inference_steps = retrieve_timesteps( self.scheduler, denoising_inference_steps, device, timesteps, sigmas ) - num_inference_steps += 1 timesteps = timesteps[-num_inference_steps:] latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - # 5. Prepare latent variables + # 4. Prepare latent variables if latents is None: video = self.video_processor.preprocess_video(video, height=height, width=width) # Move the number of frames before the number of channels. video = video.permute(0, 2, 1, 3, 4) - video = video.to(device=device, dtype=prompt_embeds.dtype) + video = video.to(device=device, dtype=dtype) num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( video=video, @@ -954,7 +922,7 @@ def __call__( num_channels_latents=num_channels_latents, batch_size=batch_size * num_videos_per_prompt, timestep=latent_timestep, - dtype=prompt_embeds.dtype, + dtype=dtype, device=device, generator=generator, latents=latents, @@ -962,10 +930,59 @@ def __call__( add_noise=enforce_inference_steps, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 5. 
Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Add image embeds for IP-Adapter + # 8. Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None @@ -985,13 +1002,12 @@ def __call__( self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - # 8. Denoising loop + # 9. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - print("here:", t, self.scheduler.sigmas[-num_inference_steps:]) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1027,14 +1043,14 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - # 9. Post-processing + # 10. Post-processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) - # 10. Offload all models + # 11. Offload all models self.maybe_free_model_hooks() if not return_dict: diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f8128abb9b58..17c52381d920 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Union +from typing import Callable, Dict, Optional, Union import torch @@ -22,6 +22,7 @@ DownBlockMotion, UpBlockMotion, ) +from ..pipelines.pipeline_utils import DiffusionPipeline from ..utils import logging from ..utils.torch_utils import randn_tensor @@ -97,7 +98,135 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) - + + def _check_inputs_free_noise( + self, + prompt, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + num_frames, + ) -> None: + if not isinstance(prompt, (str, dict)): + raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") + + if negative_prompt is not None: + if not isinstance(negative_prompt, (str, dict)): + raise ValueError(f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}") + + if prompt_embeds is not None or negative_prompt_embeds is not None: + raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") + + frame_indices = [isinstance(x, int) for x in prompt.keys()] + frame_prompts = [isinstance(x, str) for x in prompt.values()] + min_frame = min(list(prompt.keys())) + max_frame = max(list(prompt.keys())) + + if not all(frame_indices): + raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.") + if not all(frame_prompts): + raise ValueError("Expected str values in `prompt` dict for FreeNoise.") + if min_frame != 0: + raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") + if max_frame >= num_frames: + raise ValueError(f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing.") + + def _encode_prompt_free_noise( + self, + prompt: Union[str, Dict[int, str]], + num_frames: int, + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, Dict[int, str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ) -> torch.Tensor: + if negative_prompt is None: + negative_prompt = "" + + # Ensure that we have a dictionary of prompts + if isinstance(prompt, str): + prompt = {0: prompt} + if isinstance(negative_prompt, str): + negative_prompt = {0: negative_prompt} + + self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) + + # Sort the prompts based on frame indices + prompt = dict(sorted(prompt.items())) + negative_prompt = dict(sorted(negative_prompt.items())) + + # Ensure that we have a prompt for the last frame index + prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] + negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] + + frame_indices = list(prompt.keys()) + frame_prompts = list(prompt.values()) + frame_negative_indices = list(negative_prompt.keys()) + frame_negative_prompts = list(negative_prompt.values()) + + # Generate and interpolate positive prompts + prompt_embeds, _ = self.encode_prompt( + prompt=frame_prompts, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=False, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + shape = 
(num_frames, *prompt_embeds.shape[1:]) + prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) + + for i in range(len(frame_indices) - 1): + start_frame = frame_indices[i] + end_frame = frame_indices[i + 1] + start_tensor = prompt_embeds[i].unsqueeze(0) + end_tensor = prompt_embeds[i + 1].unsqueeze(0) + + prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + + # Generate and interpolate negative prompts + negative_prompt_embeds = None + negative_prompt_interpolation_embeds = None + + if do_classifier_free_guidance: + _, negative_prompt_embeds = self.encode_prompt( + prompt=[""] * len(frame_negative_prompts), + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=True, + negative_prompt=frame_negative_prompts, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) + + for i in range(len(frame_negative_indices) - 1): + start_frame = frame_negative_indices[i] + end_frame = frame_negative_indices[i + 1] + start_tensor = negative_prompt_embeds[i].unsqueeze(0) + end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) + + negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + + prompt_embeds = prompt_interpolation_embeds + negative_prompt_embeds = negative_prompt_interpolation_embeds + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds, negative_prompt_embeds + def _prepare_latents_free_noise( self, batch_size: int, @@ -171,6 +300,18 @@ def _prepare_latents_free_noise( latents = latents[:, :, :num_frames] return latents + + def _lerp(self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor: + num_indices = end_index - start_index + 1 + interpolated_tensors = [] + + for i in range(num_indices): + alpha = i / (num_indices - 1) + interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor + interpolated_tensors.append(interpolated_tensor) + + interpolated_tensors = torch.cat(interpolated_tensors) + return interpolated_tensors def enable_free_noise( self, @@ -178,6 +319,7 @@ def enable_free_noise( context_stride: int = 4, weighting_scheme: str = "pyramid", noise_type: str = "shuffle_context", + prompt_interpolation_callback: Optional[Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> None: r""" Enable long video generation using FreeNoise. 
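For context, the per-frame prompt dictionary handled by `_encode_prompt_free_noise` above is sorted by frame index, the last keyframe prompt is copied onto the final frame, and embeddings between neighbouring keyframes are interpolated with `_lerp` (or the `prompt_interpolation_callback` just added to `enable_free_noise`). A minimal usage sketch of this interface follows; the checkpoint names, frame counts, and sampler settings are illustrative assumptions and are not part of the diff.

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Checkpoint names below are placeholders chosen for illustration.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# FreeNoise: sliding-window denoising over long videos (context_length frames per window).
pipe.enable_free_noise(context_length=16, context_stride=4)

# With FreeNoise enabled, `prompt` may be a dict mapping frame indices to prompts.
# Embeddings between the keyframes (0 and 80 here) are interpolated by `_lerp` by default.
prompt = {
    0: "a panda surfing on a wave, ocean, high quality",
    80: "a panda standing on a beach at sunset, high quality",
}
video = pipe(
    prompt=prompt,
    negative_prompt="bad quality, worst quality",
    num_frames=128,
    guidance_scale=7.5,
    num_inference_steps=25,
    generator=torch.Generator("cpu").manual_seed(42),
).frames[0]

Because the interpolation happens per frame in embedding space, the transition between the two keyframe prompts is gradual rather than an abrupt switch.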
@@ -219,6 +361,7 @@ def enable_free_noise( self._free_noise_context_stride = context_stride self._free_noise_weighting_scheme = weighting_scheme self._free_noise_noise_type = noise_type + self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: From a86eabe0bd27dfa5ba57938e57790961f2c2cf6a Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 Aug 2024 17:20:32 +0200 Subject: [PATCH 03/20] make style --- .../animatediff/pipeline_animatediff.py | 2 +- .../pipeline_animatediff_video2video.py | 2 +- src/diffusers/pipelines/free_noise_utils.py | 56 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 1f5ff3dd731e..cb6f50f43c4f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -745,7 +745,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) - + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 6fcb2e30206c..1ebe2b9b60dd 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -960,7 +960,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) - + # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 17c52381d920..f710206d7730 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -98,7 +98,7 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) - + def _check_inputs_free_noise( self, prompt, @@ -109,11 +109,13 @@ def _check_inputs_free_noise( ) -> None: if not isinstance(prompt, (str, dict)): raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") - + if negative_prompt is not None: if not isinstance(negative_prompt, (str, dict)): - raise ValueError(f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}") - + raise ValueError( + f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}" + ) + if prompt_embeds is not None or negative_prompt_embeds is not None: raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") @@ -129,7 +131,9 @@ def _check_inputs_free_noise( if min_frame != 0: raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") if max_frame >= num_frames: - raise ValueError(f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing.") + raise ValueError( + f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." + ) def _encode_prompt_free_noise( self, @@ -146,23 +150,23 @@ def _encode_prompt_free_noise( ) -> torch.Tensor: if negative_prompt is None: negative_prompt = "" - + # Ensure that we have a dictionary of prompts if isinstance(prompt, str): prompt = {0: prompt} if isinstance(negative_prompt, str): negative_prompt = {0: negative_prompt} - + self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) - + # Sort the prompts based on frame indices prompt = dict(sorted(prompt.items())) negative_prompt = dict(sorted(negative_prompt.items())) - + # Ensure that we have a prompt for the last frame index prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] - + frame_indices = list(prompt.keys()) frame_prompts = list(prompt.values()) frame_negative_indices = list(negative_prompt.keys()) @@ -180,7 +184,7 @@ def _encode_prompt_free_noise( lora_scale=lora_scale, clip_skip=clip_skip, ) - + shape = (num_frames, *prompt_embeds.shape[1:]) prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) @@ -190,7 +194,9 @@ def _encode_prompt_free_noise( start_tensor = prompt_embeds[i].unsqueeze(0) end_tensor = prompt_embeds[i + 1].unsqueeze(0) - prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback( + start_frame, end_frame, start_tensor, end_tensor + ) # Generate and interpolate negative prompts negative_prompt_embeds = None @@ -208,7 +214,7 @@ def _encode_prompt_free_noise( lora_scale=lora_scale, clip_skip=clip_skip, ) - + 
negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) for i in range(len(frame_negative_indices) - 1): @@ -217,16 +223,18 @@ def _encode_prompt_free_noise( start_tensor = negative_prompt_embeds[i].unsqueeze(0) end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) - negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) - + negative_prompt_interpolation_embeds[ + start_frame : end_frame + 1 + ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + prompt_embeds = prompt_interpolation_embeds negative_prompt_embeds = negative_prompt_interpolation_embeds - + if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds, negative_prompt_embeds - + def _prepare_latents_free_noise( self, batch_size: int, @@ -300,16 +308,18 @@ def _prepare_latents_free_noise( latents = latents[:, :, :num_frames] return latents - - def _lerp(self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor: + + def _lerp( + self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor + ) -> torch.Tensor: num_indices = end_index - start_index + 1 interpolated_tensors = [] - + for i in range(num_indices): alpha = i / (num_indices - 1) interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor interpolated_tensors.append(interpolated_tensor) - + interpolated_tensors = torch.cat(interpolated_tensors) return interpolated_tensors @@ -319,7 +329,9 @@ def enable_free_noise( context_stride: int = 4, weighting_scheme: str = "pyramid", noise_type: str = "shuffle_context", - prompt_interpolation_callback: Optional[Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None, + prompt_interpolation_callback: Optional[ + Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] + ] = None, ) -> None: r""" Enable long video generation using FreeNoise. 
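The `prompt_interpolation_callback` hook shown below defaults to the bound `_lerp` method and, as the call sites in `_encode_prompt_free_noise` show, is invoked with four arguments: `(start_frame, end_frame, start_tensor, end_tensor)`. The following sketch is a hypothetical drop-in replacement that uses spherical instead of linear interpolation; it follows the four-argument call site in this patch (the annotated type also lists a pipeline argument), and the parallel-vector threshold is an arbitrary illustrative value.

import torch

def slerp_prompt_interpolation(
    start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
) -> torch.Tensor:
    # Spherical linear interpolation between two prompt embeddings, producing one row per
    # frame in [start_index, end_index]. Falls back to lerp when the embeddings are nearly parallel.
    num_indices = end_index - start_index + 1
    flat_start = start_tensor.flatten()
    flat_end = end_tensor.flatten()
    dot = torch.clamp(
        torch.dot(flat_start / flat_start.norm(), flat_end / flat_end.norm()), -1.0, 1.0
    )
    theta = torch.acos(dot)

    interpolated = []
    for i in range(num_indices):
        alpha = i / (num_indices - 1) if num_indices > 1 else 0.0
        if theta.abs() < 1e-4:  # nearly parallel -> plain lerp is numerically safer
            frame = (1 - alpha) * start_tensor + alpha * end_tensor
        else:
            frame = (
                torch.sin((1 - alpha) * theta) * start_tensor + torch.sin(alpha * theta) * end_tensor
            ) / torch.sin(theta)
        interpolated.append(frame)
    return torch.cat(interpolated)

# Registered through the hook added in this patch:
# pipe.enable_free_noise(prompt_interpolation_callback=slerp_prompt_interpolation)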
From 94438e1439994900dcf69f39b7cd327acf61c20f Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 02:05:32 +0200 Subject: [PATCH 04/20] resnet memory optimizations --- src/diffusers/models/attention.py | 16 ++- src/diffusers/models/attention_processor.py | 2 + .../models/unets/unet_motion_model.py | 102 ++++++++++++++++-- src/diffusers/pipelines/free_noise_utils.py | 7 ++ 4 files changed, 118 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e6858d842cbb..edccbc990fae 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -43,6 +43,12 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output +def _experimental_split_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, split_size: int, split_dim: int +) -> torch.Tensor: + return torch.cat([ff(hs_split) for hs_split in hidden_states.split(split_size, dim=split_dim)], dim=split_dim) + + @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" @@ -525,7 +531,10 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + ff_output = _experimental_split_feed_forward( + self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim + ) else: ff_output = self.ff(norm_hidden_states) @@ -1095,7 +1104,10 @@ def forward( norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + ff_output = _experimental_split_feed_forward( + self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim + ) else: ff_output = self.ff(norm_hidden_states) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2ab1606b345..e3ebf1077dc2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2221,6 +2221,8 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj + # TODO: figure out a better way to do this + # hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0) hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 7ac61821edd6..5d0bc7b810bb 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -49,6 +49,18 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def _chunked_resnet_forward( + resnet: ResnetBlock2D, hidden_states: torch.Tensor, temb: torch.Tensor, chunk_size: int, chunk_dim: int +) -> torch.Tensor: + return torch.cat( + [ + resnet(hs_split, t_split) + for hs_split, t_split in zip(hidden_states.split(chunk_size, chunk_dim), temb.split(chunk_size, chunk_dim)) + ], + dim=chunk_dim, + ) + + @dataclass class UNetMotionOutput(BaseOutput): """ @@ -116,7 +128,7 @@ def __init__( self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, 
affine=True) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks @@ -306,6 +318,12 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -344,7 +362,12 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -493,6 +516,12 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -540,7 +569,12 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -695,6 +729,12 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -766,7 +806,12 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -866,6 +911,12 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -929,7 +980,12 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -1065,6 +1121,12 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -1080,7 +1142,12 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - hidden_states = self.resnets[0](hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + self.resnets[0], hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: @@ -1125,11 +1192,18 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] + hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(hidden_states, temb) + + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1970,6 +2044,20 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) + def enable_resnet_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + chunk_size = chunk_size or 1 + + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_resnet"): + logger.debug(f"Enabling chunked resnet inference in: {name}") + module.set_chunk_resnet(chunk_size, dim) + + def disable_resnet_chunking(self) -> None: + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_resnet"): + logger.debug(f"Disabling chunked resnet inference in: {name}") + module.set_chunk_resnet(None) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f710206d7730..ff5af5af01da 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -332,6 +332,8 @@ def enable_free_noise( prompt_interpolation_callback: Optional[ Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, + _chunk_size_resnet: Optional[int] = None, + _chunk_size_feed_forward: Optional[int] = None, ) -> None: r""" Enable long video generation using FreeNoise. 
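The two underscore-prefixed arguments added to `enable_free_noise` in this commit forward to `unet.enable_resnet_chunking` and `unet.enable_forward_chunking` (as the next hunk shows), splitting the ResNet and feed-forward computation along dim 0 so that the large frame-stacked batches produced by FreeNoise are processed in smaller slices, trading extra kernel launches for lower peak memory. A hedged sketch of how these experimental knobs might be exercised follows; the chunk sizes are illustrative, and the leading underscore marks the API as private in this patch.

# Assumes `pipe` is an AnimateDiff pipeline whose UNet is a UNetMotionModel.
pipe.enable_free_noise(
    context_length=16,
    context_stride=4,
    _chunk_size_resnet=16,        # forwarded to pipe.unet.enable_resnet_chunking(16, dim=0)
    _chunk_size_feed_forward=16,  # forwarded to pipe.unet.enable_forward_chunking(16, dim=0)
)

# The same toggles are also available directly on the UNet:
pipe.unet.enable_resnet_chunking(chunk_size=16, dim=0)
pipe.unet.disable_resnet_chunking()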
@@ -379,6 +381,11 @@ def enable_free_noise( for block in blocks: self._enable_free_noise_in_block(block) + if _chunk_size_resnet is not None: + self.unet.enable_resnet_chunking(_chunk_size_resnet, dim=0) + if _chunk_size_feed_forward is not None: + self.unet.enable_forward_chunking(_chunk_size_feed_forward, dim=0) + def disable_free_noise(self) -> None: self._free_noise_context_length = None From 74e3ab088cb300bc6e3a053753ff4b6eb3a95781 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 06:14:44 +0200 Subject: [PATCH 05/20] more memory optimizations; todo: refactor --- src/diffusers/models/attention.py | 23 +- .../models/unets/unet_motion_model.py | 309 ++++++++++++++---- src/diffusers/pipelines/free_noise_utils.py | 9 + 3 files changed, 271 insertions(+), 70 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index edccbc990fae..5005ad118894 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1096,19 +1096,38 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - hidden_states = torch.where( - num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, ).to(dtype) + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ).to(dtype) + # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: + # norm_hidden_states = torch.cat([ + # self.norm3(hs_split) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) + # ], dim=self._chunk_dim) + # ff_output = torch.cat([ + # self.ff(self.norm3(hs_split)) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) + # ], dim=self._chunk_dim) + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) ff_output = _experimental_split_feed_forward( self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim ) else: + norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) hidden_states = ff_output + hidden_states diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 5d0bc7b810bb..e01f5dbfe628 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -61,6 +61,29 @@ def _chunked_resnet_forward( ) +def _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + chunk_size: int, + chunk_dim: int, +) -> torch.Tensor: + return torch.cat( + [ + attn( + hs_split, ehs_split, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict=False + )[0] + for hs_split, ehs_split in zip( + hidden_states.split(chunk_size, chunk_dim), encoder_hidden_states.split(chunk_size, chunk_dim) + ) + ], + dim=chunk_dim, + ) + + @dataclass class UNetMotionOutput(BaseOutput): """ @@ -152,6 +175,12 @@ def __init__( ) self.proj_out = nn.Linear(inner_dim, in_channels) + self._chunk_size_motion_module = None + 
self._chunk_dim_motion_module = 0 + + def set_chunk_motion_module(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_motion_module = chunk_size + self._chunk_dim_motion_module = dim def forward( self, @@ -203,13 +232,37 @@ def forward( # 2. Blocks for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) + if encoder_hidden_states is None: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split in hidden_states.split(self._chunk_size_motion_module) + ], + dim=self._chunk_dim_motion_module, + ) + else: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=ehs_split, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split, ehs_split in zip( + hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), + encoder_hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), + ) + ], + dim=self._chunk_dim_motion_module, + ) # 3. Output hidden_states = self.proj_out(hidden_states) @@ -318,12 +371,12 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_dim_resnet = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim def forward( self, @@ -362,9 +415,9 @@ def custom_forward(*inputs): ) else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) @@ -375,7 +428,16 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + downsampler(hs_split) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -516,12 +578,18 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_size_attn = None + self._chunk_dim_resnet = 0 + self._chunk_dim_attn = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim + + def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_attn = chunk_size + self._chunk_dim_attn = dim def forward( self, @@ -569,21 +637,34 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, 
self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self._chunk_size_attn is not None: + hidden_states = _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + self._chunk_size_resnet, + self._chunk_dim_resnet, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -597,7 +678,16 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + downsampler(hs_split) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -729,12 +819,18 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_size_attn = None + self._chunk_dim_resnet = 0 + self._chunk_dim_attn = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim + + def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_attn = chunk_size + self._chunk_dim_attn = dim def forward( self, @@ -806,21 +902,34 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self._chunk_size_attn is not None: + hidden_states = _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + self._chunk_size_resnet, + self._chunk_dim_resnet, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -828,7 +937,16 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + if 
self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + upsampler(hs_split, upsample_size) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -911,12 +1029,12 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_dim_resnet = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim def forward( self, @@ -980,9 +1098,9 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) @@ -991,7 +1109,16 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + if self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + upsampler(hs_split, upsample_size) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1121,12 +1248,18 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_size_attn = None + self._chunk_dim_resnet = 0 + self._chunk_dim_attn = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim + + def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_attn = chunk_size + self._chunk_dim_attn = dim def forward( self, @@ -1142,9 +1275,9 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - self.resnets[0], hidden_states, temb, self._chunk_size, self._chunk_dim + self.resnets[0], hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = self.resnets[0](hidden_states, temb) @@ -1184,23 +1317,35 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self._chunk_size_attn is not None: + hidden_states = _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + self._chunk_size_resnet, + self._chunk_dim_resnet, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) @@ -2045,7 +2190,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int fn_recursive_feed_forward(module, None, 0) def enable_resnet_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 1 + chunk_size = chunk_size or 16 for name, module in self.named_modules(): if hasattr(module, "set_chunk_resnet"): @@ -2058,6 +2203,34 @@ def disable_resnet_chunking(self) -> None: logger.debug(f"Disabling chunked resnet inference in: {name}") module.set_chunk_resnet(None) + def enable_attn_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + chunk_size = chunk_size or 16 + + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_attn"): + logger.debug(f"Enabling chunked attn inference in: {name}") + module.set_chunk_attn(chunk_size, dim) + + def disable_attn_chunking(self) -> None: + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_attn"): + logger.debug(f"Disabling chunked attn inference in: {name}") + module.set_chunk_attn(None) + + def enable_motion_module_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + chunk_size = chunk_size or 256 + + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_motion_module"): + logger.debug(f"Enabling chunked motion module inference in: {name}") + module.set_chunk_motion_module(chunk_size, dim) + + def disable_motion_module_chunking(self) -> None: + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_motion_module"): + logger.debug(f"Disabling chunked motion module inference in: {name}") + module.set_chunk_motion_module(None) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 
ff5af5af01da..e76ae2358b8d 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -70,6 +70,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow motion_module.transformer_blocks[i].load_state_dict( basic_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim + ) def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -98,6 +101,9 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim + ) def _check_inputs_free_noise( self, @@ -332,6 +338,7 @@ def enable_free_noise( prompt_interpolation_callback: Optional[ Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, + _chunk_size_attn: Optional[int] = None, _chunk_size_resnet: Optional[int] = None, _chunk_size_feed_forward: Optional[int] = None, ) -> None: @@ -381,6 +388,8 @@ def enable_free_noise( for block in blocks: self._enable_free_noise_in_block(block) + if _chunk_size_attn is not None: + self.unet.enable_attn_chunking(_chunk_size_attn, dim=0) if _chunk_size_resnet is not None: self.unet.enable_resnet_chunking(_chunk_size_resnet, dim=0) if _chunk_size_feed_forward is not None: From ec91064966ac944c7ae0cb0cd2d89345e33c796e Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 18:42:15 +0200 Subject: [PATCH 06/20] update --- src/diffusers/models/attention.py | 25 +++++-- .../models/unets/unet_motion_model.py | 69 +++++++++++-------- src/diffusers/pipelines/free_noise_utils.py | 2 +- 3 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5005ad118894..83d7ae6c448c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -981,15 +981,32 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: return frame_indices def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: - if weighting_scheme == "pyramid": + if weighting_scheme == "flat": + weights = [1.0] * num_frames + + elif weighting_scheme == "pyramid": if num_frames % 2 == 0: # num_frames = 4 => [1, 2, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) + mid = num_frames // 2 + weights = list(range(1, mid + 1)) weights = weights + weights[::-1] else: # num_frames = 5 => [1, 2, 3, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) - weights = weights + [num_frames // 2 + 1] + weights[::-1] + mid = (num_frames + 1) // 2 + weights = list(range(1, mid)) + weights = weights + [mid] + weights[::-1] + + elif weighting_scheme == "delayed_reverse_sawtooth": + if num_frames % 2 == 0: + # num_frames = 4 => [0.01, 2, 2, 1] + mid = num_frames // 2 + weights = [0.01] * (mid - 1) + [mid] + weights = weights + list(range(mid, 0, -1)) + else: + # num_frames = 5 => [0.01, 0.01, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = [0.01] * mid + weights = weights + list(range(mid, 0, -1)) else: raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") diff --git 
a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index e01f5dbfe628..26a29e5e4802 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -232,36 +232,47 @@ def forward( # 2. Blocks for block in self.transformer_blocks: - if encoder_hidden_states is None: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=None, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split in hidden_states.split(self._chunk_size_motion_module) - ], - dim=self._chunk_dim_motion_module, - ) + if self._chunk_size_motion_module is not None: + if encoder_hidden_states is None: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split in hidden_states.split(self._chunk_size_motion_module) + ], + dim=self._chunk_dim_motion_module, + ) + else: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=ehs_split, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split, ehs_split in zip( + hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), + encoder_hidden_states.split( + self._chunk_size_motion_module, self._chunk_dim_motion_module + ), + ) + ], + dim=self._chunk_dim_motion_module, + ) else: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=ehs_split, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split, ehs_split in zip( - hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), - encoder_hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), - ) - ], - dim=self._chunk_dim_motion_module, + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, ) # 3. Output diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index e76ae2358b8d..cbf323be15bc 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -364,7 +364,7 @@ def enable_free_noise( TODO """ - allowed_weighting_scheme = ["pyramid"] + allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] allowed_noise_type = ["shuffle_context", "repeat_context", "random"] if context_length > self.motion_adapter.config.motion_max_seq_length: From 65686818ab13e4fa0b17163139566dd2e390a59b Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 23:54:55 +0200 Subject: [PATCH 07/20] update animatediff controlnet with latest changes --- .../pipeline_animatediff_controlnet.py | 64 +++++++++++++------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 6e8b0e3e5fe3..5357d6d5b8d9 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -505,8 +505,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. 
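As a reference for the weighting schemes registered above, here is a minimal standalone sketch that mirrors the `flat` and `pyramid` branches of `_get_frame_weights`; the helper name and printed values are illustrative only, and the `delayed_reverse_sawtooth` branch follows the same shape-building pattern with damped leading weights.

```python
from typing import List, Union


def frame_weights(num_frames: int, scheme: str = "pyramid") -> List[Union[int, float]]:
    # Mirrors the "flat" and "pyramid" frame-weighting logic used when averaging
    # overlapping FreeNoise context windows.
    if scheme == "flat":
        return [1.0] * num_frames
    if scheme == "pyramid":
        if num_frames % 2 == 0:
            half = list(range(1, num_frames // 2 + 1))
            return half + half[::-1]          # e.g. 4 -> [1, 2, 2, 1]
        mid = (num_frames + 1) // 2
        half = list(range(1, mid))
        return half + [mid] + half[::-1]      # e.g. 5 -> [1, 2, 3, 2, 1]
    raise ValueError(f"Unsupported weighting_scheme={scheme}")


print(frame_weights(4))            # [1, 2, 2, 1]
print(frame_weights(5))            # [1, 2, 3, 2, 1]
print(frame_weights(4, "flat"))    # [1.0, 1.0, 1.0, 1.0]
```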
Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -699,6 +699,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() def __call__( self, @@ -858,9 +862,10 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -883,22 +888,39 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( @@ -990,6 +1012,9 @@ def __call__( # 8. 
Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1002,7 +1027,6 @@ def __call__( else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] From 761c44d116665836ef238991ab01f4cb0acf86db Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 21 Aug 2024 11:47:31 +0200 Subject: [PATCH 08/20] refactor chunked inference changes --- .../models/unets/unet_motion_model.py | 385 +++--------------- src/diffusers/pipelines/free_noise_utils.py | 119 +++++- 2 files changed, 154 insertions(+), 350 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 26a29e5e4802..6125feba5899 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -49,41 +49,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _chunked_resnet_forward( - resnet: ResnetBlock2D, hidden_states: torch.Tensor, temb: torch.Tensor, chunk_size: int, chunk_dim: int -) -> torch.Tensor: - return torch.cat( - [ - resnet(hs_split, t_split) - for hs_split, t_split in zip(hidden_states.split(chunk_size, chunk_dim), temb.split(chunk_size, chunk_dim)) - ], - dim=chunk_dim, - ) - - -def _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - chunk_size: int, - chunk_dim: int, -) -> torch.Tensor: - return torch.cat( - [ - attn( - hs_split, ehs_split, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict=False - )[0] - for hs_split, ehs_split in zip( - hidden_states.split(chunk_size, chunk_dim), encoder_hidden_states.split(chunk_size, chunk_dim) - ) - ], - dim=chunk_dim, - ) - - @dataclass class UNetMotionOutput(BaseOutput): """ @@ -175,12 +140,6 @@ def __init__( ) self.proj_out = nn.Linear(inner_dim, in_channels) - self._chunk_size_motion_module = None - self._chunk_dim_motion_module = 0 - - def set_chunk_motion_module(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_motion_module = chunk_size - self._chunk_dim_motion_module = dim def forward( self, @@ -228,55 +187,20 @@ def forward( hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(input=hidden_states) # 2. 
Blocks for block in self.transformer_blocks: - if self._chunk_size_motion_module is not None: - if encoder_hidden_states is None: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=None, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split in hidden_states.split(self._chunk_size_motion_module) - ], - dim=self._chunk_dim_motion_module, - ) - else: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=ehs_split, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split, ehs_split in zip( - hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), - encoder_hidden_states.split( - self._chunk_size_motion_module, self._chunk_dim_motion_module - ), - ) - ], - dim=self._chunk_dim_motion_module, - ) - else: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) # 3. Output - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(input=hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -382,12 +306,6 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size_resnet = None - self._chunk_dim_resnet = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim def forward( self, @@ -426,12 +344,7 @@ def custom_forward(*inputs): ) else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -439,16 +352,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - downsampler(hs_split) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -589,18 +493,6 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size_resnet = None - self._chunk_size_attn = None - self._chunk_dim_resnet = 0 - self._chunk_dim_attn = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim - - def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_attn = chunk_size - self._chunk_dim_attn = dim def forward( self, @@ -639,42 +531,17 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - 
attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self._chunk_size_attn is not None: - hidden_states = _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - self._chunk_size_resnet, - self._chunk_dim_resnet, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, @@ -689,16 +556,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - downsampler(hs_split) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -830,18 +688,6 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size_resnet = None - self._chunk_size_attn = None - self._chunk_dim_resnet = 0 - self._chunk_dim_attn = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim - - def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_attn = chunk_size - self._chunk_dim_attn = dim def forward( self, @@ -904,42 +750,17 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self._chunk_size_attn is not None: - hidden_states = _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - self._chunk_size_resnet, - self._chunk_dim_resnet, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + 
cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, @@ -948,16 +769,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - upsampler(hs_split, upsample_size) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1040,12 +852,6 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size_resnet = None - self._chunk_dim_resnet = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim def forward( self, @@ -1109,27 +915,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - upsampler(hs_split, upsample_size) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1259,18 +1051,6 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False - self._chunk_size_resnet = None - self._chunk_size_attn = None - self._chunk_dim_resnet = 0 - self._chunk_dim_attn = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim - - def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_attn = chunk_size - self._chunk_dim_attn = dim def forward( self, @@ -1286,15 +1066,19 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - self.resnets[0], hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1307,14 +1091,6 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1328,38 +1104,11 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: - if self._chunk_size_attn is not None: - hidden_states = _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - self._chunk_size_resnet, - self._chunk_dim_resnet, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states @@ -2200,48 +1949,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - def enable_resnet_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 16 - - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_resnet"): - logger.debug(f"Enabling chunked resnet inference in: {name}") - module.set_chunk_resnet(chunk_size, dim) - - def disable_resnet_chunking(self) -> None: - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_resnet"): - logger.debug(f"Disabling chunked resnet inference in: {name}") - module.set_chunk_resnet(None) - - def enable_attn_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 16 - - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_attn"): - logger.debug(f"Enabling chunked attn inference in: {name}") - module.set_chunk_attn(chunk_size, dim) - - def disable_attn_chunking(self) -> None: - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_attn"): - logger.debug(f"Disabling chunked attn inference in: {name}") - 
module.set_chunk_attn(None) - - def enable_motion_module_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 256 - - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_motion_module"): - logger.debug(f"Enabling chunked motion module inference in: {name}") - module.set_chunk_motion_module(chunk_size, dim) - - def disable_motion_module_chunking(self) -> None: - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_motion_module"): - logger.debug(f"Disabling chunked motion module inference in: {name}") - module.set_chunk_motion_module(None) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index cbf323be15bc..817360b3eaea 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( + AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -30,6 +34,59 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class ChunkedInferenceModule(nn.Module): + def __init__( + self, + module: nn.Module, + chunk_size: int = 1, + chunk_dim: int = 0, + input_kwargs_to_chunk: List[str] = ["hidden_states"], + ) -> None: + super().__init__() + + self.module = module + self.chunk_size = chunk_size + self.chunk_dim = chunk_dim + self.input_kwargs_to_chunk = set(input_kwargs_to_chunk) + + def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + r"""Forward method of `ChunkedInferenceModule`. + + All inputs that should be chunked should be passed as keyword arguments. Only those keywords arguments will be + chunked that are specified in `inputs_to_chunk` when initializing the module. + """ + chunked_inputs = {} + + for key in list(kwargs.keys()): + if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]): + continue + chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim) + kwargs.pop(key) + + results = [] + for chunked_input in zip(*chunked_inputs.values()): + inputs = dict(zip(chunked_inputs.keys(), chunked_input)) + inputs.update(kwargs) + + for key, input in inputs.items(): + if torch.is_tensor(input): + print(key, input.shape) + else: + print(key) + + intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) + results.append(intermediate_tensor_or_tensor_tuple) + + if isinstance(results[0], torch.Tensor): + return torch.cat(results, dim=self.chunk_dim) + elif isinstance(results[0], tuple): + return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)]) + else: + raise ValueError( + "In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." 
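The constraint spelled out in that error message (the wrapped module must return either a tensor or a tuple of tensors) is what makes the split-and-concatenate pattern possible. Below is a minimal, self-contained sketch of the same pattern; the function and module names are illustrative and simplified, not the exact class introduced here.

```python
import torch
import torch.nn as nn


def split_apply_concat(module: nn.Module, split_size: int, split_dim: int, **tensor_kwargs):
    # Split every tensor keyword argument along `split_dim`, run the module on each
    # slice, and concatenate the per-slice outputs back along the same dimension.
    splits = {name: torch.split(value, split_size, split_dim) for name, value in tensor_kwargs.items()}
    outputs = [module(**dict(zip(splits.keys(), group))) for group in zip(*splits.values())]
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=split_dim)
    # Tuple outputs: concatenate element-wise across slices.
    return tuple(torch.cat(parts, dim=split_dim) for parts in zip(*outputs))


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


block = TinyBlock()
hidden_states = torch.randn(6, 8)
out = split_apply_concat(block, split_size=2, split_dim=0, hidden_states=hidden_states)
assert out.shape == (6, 8)
```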
+ ) + + class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -338,9 +395,6 @@ def enable_free_noise( prompt_interpolation_callback: Optional[ Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, - _chunk_size_attn: Optional[int] = None, - _chunk_size_resnet: Optional[int] = None, - _chunk_size_feed_forward: Optional[int] = None, ) -> None: r""" Enable long video generation using FreeNoise. @@ -388,13 +442,6 @@ def enable_free_noise( for block in blocks: self._enable_free_noise_in_block(block) - if _chunk_size_attn is not None: - self.unet.enable_attn_chunking(_chunk_size_attn, dim=0) - if _chunk_size_resnet is not None: - self.unet.enable_resnet_chunking(_chunk_size_resnet, dim=0) - if _chunk_size_feed_forward is not None: - self.unet.enable_forward_chunking(_chunk_size_feed_forward, dim=0) - def disable_free_noise(self) -> None: self._free_noise_context_length = None @@ -402,6 +449,56 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) + def _enable_chunked_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int + ) -> None: + for motion_module in motion_modules: + motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"]) + + for i in range(len(motion_module.transformer_blocks)): + motion_module.transformer_blocks[i] = ChunkedInferenceModule( + motion_module.transformer_blocks[i], + spatial_chunk_size, + 0, + ["hidden_states", "encoder_hidden_states"], + ) + + motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"]) + + def _enable_chunked_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_chunk_size: int + ) -> None: + for i in range(len(attentions)): + attentions[i] = ChunkedInferenceModule( + attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"] + ) + + def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None: + for i in range(len(resnets)): + resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"]) + + def _enable_chunked_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int + ) -> None: + for i in range(len(samplers)): + samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"]) + + def enable_free_noise_chunked_inference( + self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16 + ) -> None: + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + if getattr(block, "motion_modules", None) is not None: + self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size) + if getattr(block, "attentions", None) is not None: + self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size) + if getattr(block, "resnets", None) is not None: + self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size) + if getattr(block, "downsamplers", None) is not None: + self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size) + if getattr(block, "upsamplers", None) is not None: + self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size) + @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") 
and self._free_noise_context_length is not None From 6830fb08052039c549708f4a8ec77d3d4c29b6d7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 21 Aug 2024 11:52:05 +0200 Subject: [PATCH 09/20] remove print statements --- src/diffusers/pipelines/free_noise_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 817360b3eaea..1c1147a4949e 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -68,12 +68,6 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: inputs = dict(zip(chunked_inputs.keys(), chunked_input)) inputs.update(kwargs) - for key, input in inputs.items(): - if torch.is_tensor(input): - print(key, input.shape) - else: - print(key) - intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) results.append(intermediate_tensor_or_tensor_tuple) From 9e215c070d63f64496fe56d8dd3d1b1db5c11659 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 24 Aug 2024 02:12:52 +0200 Subject: [PATCH 10/20] update --- src/diffusers/models/attention.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 83d7ae6c448c..1823636e402c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -43,12 +43,6 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output -def _experimental_split_feed_forward( - ff: nn.Module, hidden_states: torch.Tensor, split_size: int, split_dim: int -) -> torch.Tensor: - return torch.cat([ff(hs_split) for hs_split in hidden_states.split(split_size, dim=split_dim)], dim=split_dim) - - @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" @@ -531,10 +525,7 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - ff_output = _experimental_split_feed_forward( - self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim - ) + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) @@ -1124,25 +1115,11 @@ def forward( dim=1, ).to(dtype) - # hidden_states = torch.where( - # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values - # ).to(dtype) - # 3. 
Feed-forward norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: - # norm_hidden_states = torch.cat([ - # self.norm3(hs_split) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) - # ], dim=self._chunk_dim) - # ff_output = torch.cat([ - # self.ff(self.norm3(hs_split)) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) - # ], dim=self._chunk_dim) - - # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - ff_output = _experimental_split_feed_forward( - self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim - ) + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) From fb96059eb7d39f5c22deae80436298f099bc2152 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 07:52:28 +0200 Subject: [PATCH 11/20] chunk -> split --- src/diffusers/models/attention.py | 1 - src/diffusers/models/attention_processor.py | 2 - src/diffusers/pipelines/free_noise_utils.py | 80 ++++++++++----------- 3 files changed, 39 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1823636e402c..efeb553c1947 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1121,7 +1121,6 @@ def forward( if self._chunk_size is not None: ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: - norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) hidden_states = ff_output + hidden_states diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 17fbdc526a6d..9f9bc5a46e10 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2372,8 +2372,6 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - # TODO: figure out a better way to do this - # hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0) hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index cf5beb8723ca..7fcd34a9aabe 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -34,50 +34,50 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class ChunkedInferenceModule(nn.Module): +class SplitInferenceModule(nn.Module): def __init__( self, module: nn.Module, - chunk_size: int = 1, - chunk_dim: int = 0, - input_kwargs_to_chunk: List[str] = ["hidden_states"], + split_size: int = 1, + split_dim: int = 0, + input_kwargs_to_split: List[str] = ["hidden_states"], ) -> None: super().__init__() self.module = module - self.chunk_size = chunk_size - self.chunk_dim = chunk_dim - self.input_kwargs_to_chunk = set(input_kwargs_to_chunk) + self.split_size = split_size + self.split_dim = split_dim + self.input_kwargs_to_split = set(input_kwargs_to_split) def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - r"""Forward method of `ChunkedInferenceModule`. + r"""Forward method of `SplitInferenceModule`. - All inputs that should be chunked should be passed as keyword arguments. 
Only those keywords arguments will be - chunked that are specified in `inputs_to_chunk` when initializing the module. + All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be + split that are specified in `inputs_to_split` when initializing the module. """ - chunked_inputs = {} + split_inputs = {} for key in list(kwargs.keys()): - if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]): + if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): continue - chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim) + split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) kwargs.pop(key) results = [] - for chunked_input in zip(*chunked_inputs.values()): - inputs = dict(zip(chunked_inputs.keys(), chunked_input)) + for split_input in zip(*split_inputs.values()): + inputs = dict(zip(split_inputs.keys(), split_input)) inputs.update(kwargs) intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) results.append(intermediate_tensor_or_tensor_tuple) if isinstance(results[0], torch.Tensor): - return torch.cat(results, dim=self.chunk_dim) + return torch.cat(results, dim=self.split_dim) elif isinstance(results[0], tuple): - return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)]) + return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)]) else: raise ValueError( - "In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." + "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." ) @@ -603,55 +603,53 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) - def _enable_chunked_inference_motion_modules_( - self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int + def _enable_split_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int ) -> None: for motion_module in motion_modules: - motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"]) + motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"]) for i in range(len(motion_module.transformer_blocks)): - motion_module.transformer_blocks[i] = ChunkedInferenceModule( + motion_module.transformer_blocks[i] = SplitInferenceModule( motion_module.transformer_blocks[i], - spatial_chunk_size, + spatial_split_size, 0, ["hidden_states", "encoder_hidden_states"], ) - motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"]) + motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"]) - def _enable_chunked_inference_attentions_( - self, attentions: List[Transformer2DModel], temporal_chunk_size: int + def _enable_split_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_split_size: int ) -> None: for i in range(len(attentions)): - attentions[i] = ChunkedInferenceModule( - attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"] + attentions[i] = SplitInferenceModule( + attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"] ) - def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> 
None: + def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None: for i in range(len(resnets)): - resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"]) + resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"]) - def _enable_chunked_inference_samplers_( - self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int + def _enable_split_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int ) -> None: for i in range(len(samplers)): - samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"]) + samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) - def enable_free_noise_chunked_inference( - self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16 - ) -> None: + def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: if getattr(block, "motion_modules", None) is not None: - self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size) + self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) if getattr(block, "attentions", None) is not None: - self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size) + self._enable_split_inference_attentions_(block.attentions, temporal_split_size) if getattr(block, "resnets", None) is not None: - self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size) + self._enable_split_inference_resnets_(block.resnets, temporal_split_size) if getattr(block, "downsamplers", None) is not None: - self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size) + self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) if getattr(block, "upsamplers", None) is not None: - self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size) + self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) @property def free_noise_enabled(self): From dc2c12b108f634c51c390de8dd8f898827999850 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 07:58:26 +0200 Subject: [PATCH 12/20] remove changes from incorrect conflict resolution --- src/diffusers/pipelines/free_noise_utils.py | 100 -------------------- 1 file changed, 100 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 7fcd34a9aabe..3c1cff04c790 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -192,106 +192,6 @@ def _check_inputs_free_noise( f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." 
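For reference, a prompt mapping that satisfies these checks would look like the following, assuming `num_frames=64`; the frame indices and captions are illustrative only.

```python
num_frames = 64

# Integer, 0-based frame indices: the first key must be 0 and the largest key
# must be strictly smaller than num_frames; all values must be strings.
prompt = {
    0: "A snowy mountain at dawn, photorealistic",
    24: "A snowy mountain at noon, bright sunlight, photorealistic",
    48: "A snowy mountain at sunset, warm colors, photorealistic",
}

assert all(isinstance(k, int) for k in prompt)
assert all(isinstance(v, str) for v in prompt.values())
assert min(prompt) == 0 and max(prompt) < num_frames
```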
) - def _encode_prompt_free_noise( - self, - prompt: Union[str, Dict[int, str]], - num_frames: int, - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, Dict[int, str]]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ) -> torch.Tensor: - if negative_prompt is None: - negative_prompt = "" - - # Ensure that we have a dictionary of prompts - if isinstance(prompt, str): - prompt = {0: prompt} - if isinstance(negative_prompt, str): - negative_prompt = {0: negative_prompt} - - self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) - - # Sort the prompts based on frame indices - prompt = dict(sorted(prompt.items())) - negative_prompt = dict(sorted(negative_prompt.items())) - - # Ensure that we have a prompt for the last frame index - prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] - negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] - - frame_indices = list(prompt.keys()) - frame_prompts = list(prompt.values()) - frame_negative_indices = list(negative_prompt.keys()) - frame_negative_prompts = list(negative_prompt.values()) - - # Generate and interpolate positive prompts - prompt_embeds, _ = self.encode_prompt( - prompt=frame_prompts, - device=device, - num_images_per_prompt=num_videos_per_prompt, - do_classifier_free_guidance=False, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - lora_scale=lora_scale, - clip_skip=clip_skip, - ) - - shape = (num_frames, *prompt_embeds.shape[1:]) - prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) - - for i in range(len(frame_indices) - 1): - start_frame = frame_indices[i] - end_frame = frame_indices[i + 1] - start_tensor = prompt_embeds[i].unsqueeze(0) - end_tensor = prompt_embeds[i + 1].unsqueeze(0) - - prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback( - start_frame, end_frame, start_tensor, end_tensor - ) - - # Generate and interpolate negative prompts - negative_prompt_embeds = None - negative_prompt_interpolation_embeds = None - - if do_classifier_free_guidance: - _, negative_prompt_embeds = self.encode_prompt( - prompt=[""] * len(frame_negative_prompts), - device=device, - num_images_per_prompt=num_videos_per_prompt, - do_classifier_free_guidance=True, - negative_prompt=frame_negative_prompts, - prompt_embeds=None, - negative_prompt_embeds=None, - lora_scale=lora_scale, - clip_skip=clip_skip, - ) - - negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) - - for i in range(len(frame_negative_indices) - 1): - start_frame = frame_negative_indices[i] - end_frame = frame_negative_indices[i + 1] - start_tensor = negative_prompt_embeds[i].unsqueeze(0) - end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) - - negative_prompt_interpolation_embeds[ - start_frame : end_frame + 1 - ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) - - prompt_embeds = prompt_interpolation_embeds - negative_prompt_embeds = negative_prompt_interpolation_embeds - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds, negative_prompt_embeds - def _check_inputs_free_noise( self, prompt, From 12f0ae11ba93d7b37ae5b3bb4ef79d8d09d810b8 Mon Sep 
17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 07:59:36 +0200 Subject: [PATCH 13/20] remove changes from incorrect conflict resolution --- src/diffusers/pipelines/free_noise_utils.py | 36 --------------------- 1 file changed, 36 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 3c1cff04c790..b894a44192cc 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -192,42 +192,6 @@ def _check_inputs_free_noise( f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." ) - def _check_inputs_free_noise( - self, - prompt, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - num_frames, - ) -> None: - if not isinstance(prompt, (str, dict)): - raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") - - if negative_prompt is not None: - if not isinstance(negative_prompt, (str, dict)): - raise ValueError( - f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}" - ) - - if prompt_embeds is not None or negative_prompt_embeds is not None: - raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") - - frame_indices = [isinstance(x, int) for x in prompt.keys()] - frame_prompts = [isinstance(x, str) for x in prompt.values()] - min_frame = min(list(prompt.keys())) - max_frame = max(list(prompt.keys())) - - if not all(frame_indices): - raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.") - if not all(frame_prompts): - raise ValueError("Expected str values in `prompt` dict for FreeNoise.") - if min_frame != 0: - raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") - if max_frame >= num_frames: - raise ValueError( - f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." - ) - def _encode_prompt_free_noise( self, prompt: Union[str, Dict[int, str]], From 661a0b389d2cfb50be104c6a3ea2ab810d98c174 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 11:45:55 +0200 Subject: [PATCH 14/20] add explanation of SplitInferenceModule --- src/diffusers/pipelines/free_noise_utils.py | 53 +++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index b894a44192cc..fabd98062bcb 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -50,19 +50,65 @@ def __init__( self.input_kwargs_to_split = set(input_kwargs_to_split) def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - r"""Forward method of `SplitInferenceModule`. + r"""Forward method for the `SplitInferenceModule`. - All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be - split that are specified in `inputs_to_split` when initializing the module. + This method processes the input by splitting specified keyword arguments along a given dimension, running the + underlying module on each split, and then concatenating the results. The splitting is controlled by the + `split_size` and `split_dim` parameters specified during initialization. + + Args: + *args (`Any`): + Positional arguments that are passed directly to the `module` without modification. 
+ **kwargs (`Dict[str, torch.Tensor]`): + Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the + entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword + arguments are passed unchanged. + + Returns: + `Union[torch.Tensor, Tuple[torch.Tensor]]`: + The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred + without it. + - If the underlying module returns a single tensor, the result will be a single concatenated tensor + along the same `split_dim` after processing all splits. + - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated + along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + + >>> model = nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule( + ... model, split_size=2, split_dim=0, input_kwargs_to_split=["input_data"] + ... ) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input_data=input_tensor) + ``` + + This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. + + It is also possible to nest `SplitInferenceModule` across different split dimensions. """ split_inputs = {} + # 1. Split inputs that were specified during initialization and also present in passed kwargs for key in list(kwargs.keys()): if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): continue split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) kwargs.pop(key) + # 2. Invoke forward pass across each split results = [] for split_input in zip(*split_inputs.values()): inputs = dict(zip(split_inputs.keys(), split_input)) @@ -71,6 +117,7 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) results.append(intermediate_tensor_or_tensor_tuple) + # 3. 
Concatenate split restuls to obtain final outputs if isinstance(results[0], torch.Tensor): return torch.cat(results, dim=self.split_dim) elif isinstance(results[0], tuple): From c55a50a271b2cefa8fe340a4f2a3ab9b9d374ec0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 12:02:02 +0200 Subject: [PATCH 15/20] update docs --- docs/source/en/api/pipelines/animatediff.md | 58 +++++++++++++++++++++ src/diffusers/pipelines/free_noise_utils.py | 15 ++++++ 2 files changed, 73 insertions(+) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index bfd6ab973d5e..eeae17a4fb37 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -822,6 +822,64 @@ export_to_gif(frames, "animatelcm-motion-lora.gif") +## Using FreeNoise + +[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu. + +FreeNoise is a sampling mechanism that allows the generation of longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper. + +```python +import torch +from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter +from diffusers.utils import export_to_video, load_image + +# Load pipeline +dtype = torch.float16 +motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype) +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype) + +pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype) +pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + +pipe.load_lora_weights( + "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora" +) +pipe.set_adapters(["lcm_lora"], [0.8]) + +# Enable FreeNoise for long prompt generation +pipe.enable_free_noise(context_length=16, context_stride=4) +pipe.to("cuda") + +# Optionally, enable memory efficient inference +pipe.enable_free_noise_split_inference() +pipe.unet.enable_forward_chunking(16) + +# Can be a single prompt, or a dictionary with frame timesteps +prompt = { + 0: "A caterpillar on a leaf, high quality, photorealistic", + 40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic", + 80: "A cocoon on a leaf, flowers in the backgrond, photorealistic", + 120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic", + 160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic", + 200: "A beautiful butterfly, flying away in a forest, photorealistic", + 240: "A cyberpunk butterfly, neon lights, glowing", +} +negative_prompt = "bad quality, worst quality, jpeg artifacts" + +# Run inference +output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=256, + guidance_scale=2.5, + num_inference_steps=10, + generator=torch.Generator("cpu").manual_seed(0), +) + +# Save video +frames = output.frames[0] +export_to_video(frames, "output.mp4", fps=16) +``` ## Using `from_single_file` with the MotionAdapter diff --git 
a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index fabd98062bcb..9b81312693d9 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -549,6 +549,21 @@ def _enable_split_inference_samplers_( samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: + r""" + Enable FreeNoise memory optimizations by utilizing + [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. + + Args: + spatial_split_size (`int`, defaults to `256`): + The split size across spatial dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion + modeling blocks. + temporal_split_size (`int`, defaults to `16`): + The split size across temporal dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial + attention, resnets, downsampling and upsampling blocks. + """ + # TODO(aryan): Discuss on what's the best way to provide more control to users blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: if getattr(block, "motion_modules", None) is not None: From 32961bed66595f3c8efa555c5d37a7b377abdd44 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 12:04:05 +0200 Subject: [PATCH 16/20] Revert "update docs" This reverts commit c55a50a271b2cefa8fe340a4f2a3ab9b9d374ec0. --- docs/source/en/api/pipelines/animatediff.md | 58 --------------------- src/diffusers/pipelines/free_noise_utils.py | 15 ------ 2 files changed, 73 deletions(-) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index eeae17a4fb37..bfd6ab973d5e 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -822,64 +822,6 @@ export_to_gif(frames, "animatelcm-motion-lora.gif") -## Using FreeNoise - -[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu. - -FreeNoise is a sampling mechanism that allows the generation of longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper. 
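For the `spatial_split_size` and `temporal_split_size` arguments documented in the hunk above, a minimal usage sketch is shown below; the checkpoint IDs, prompt, and split values are illustrative choices rather than recommendations, and the call sequence simply mirrors the public API added in this series:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Load a motion adapter and a text-to-image backbone (both IDs are examples).
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Enable FreeNoise for long videos, then opt into split inference.
# Smaller split sizes lower peak memory at the cost of more sequential forward passes.
pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.enable_free_noise_split_inference(spatial_split_size=128, temporal_split_size=8)

frames = pipe("a panda playing a guitar, high quality", num_frames=64, num_inference_steps=25).frames[0]
export_to_gif(frames, "panda.gif")
```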
- -```python -import torch -from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter -from diffusers.utils import export_to_video, load_image - -# Load pipeline -dtype = torch.float16 -motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype) -vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype) - -pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype) -pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") - -pipe.load_lora_weights( - "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora" -) -pipe.set_adapters(["lcm_lora"], [0.8]) - -# Enable FreeNoise for long prompt generation -pipe.enable_free_noise(context_length=16, context_stride=4) -pipe.to("cuda") - -# Optionally, enable memory efficient inference -pipe.enable_free_noise_split_inference() -pipe.unet.enable_forward_chunking(16) - -# Can be a single prompt, or a dictionary with frame timesteps -prompt = { - 0: "A caterpillar on a leaf, high quality, photorealistic", - 40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic", - 80: "A cocoon on a leaf, flowers in the backgrond, photorealistic", - 120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic", - 160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic", - 200: "A beautiful butterfly, flying away in a forest, photorealistic", - 240: "A cyberpunk butterfly, neon lights, glowing", -} -negative_prompt = "bad quality, worst quality, jpeg artifacts" - -# Run inference -output = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=256, - guidance_scale=2.5, - num_inference_steps=10, - generator=torch.Generator("cpu").manual_seed(0), -) - -# Save video -frames = output.frames[0] -export_to_video(frames, "output.mp4", fps=16) -``` ## Using `from_single_file` with the MotionAdapter diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 9b81312693d9..fabd98062bcb 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -549,21 +549,6 @@ def _enable_split_inference_samplers_( samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: - r""" - Enable FreeNoise memory optimizations by utilizing - [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. - - Args: - spatial_split_size (`int`, defaults to `256`): - The split size across spatial dimensions for internal blocks. This is used in facilitating split - inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion - modeling blocks. - temporal_split_size (`int`, defaults to `16`): - The split size across temporal dimensions for internal blocks. This is used in facilitating split - inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial - attention, resnets, downsampling and upsampling blocks. 
- """ - # TODO(aryan): Discuss on what's the best way to provide more control to users blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: if getattr(block, "motion_modules", None) is not None: From 256ee34408683d4557ca6a696ae5f8d3f347fc4f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 12:10:35 +0200 Subject: [PATCH 17/20] update docstring for freenoise split inference --- src/diffusers/pipelines/free_noise_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index fabd98062bcb..9b81312693d9 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -549,6 +549,21 @@ def _enable_split_inference_samplers_( samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: + r""" + Enable FreeNoise memory optimizations by utilizing + [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. + + Args: + spatial_split_size (`int`, defaults to `256`): + The split size across spatial dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion + modeling blocks. + temporal_split_size (`int`, defaults to `16`): + The split size across temporal dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial + attention, resnets, downsampling and upsampling blocks. + """ + # TODO(aryan): Discuss on what's the best way to provide more control to users blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: if getattr(block, "motion_modules", None) is not None: From c7bf8dd833cbcc064acf0b1f3ffbfef774dfc785 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 12:38:29 +0200 Subject: [PATCH 18/20] apply suggestions from review --- src/diffusers/models/attention.py | 11 +++++++++++ src/diffusers/pipelines/free_noise_utils.py | 7 +++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index efeb553c1947..84db0d061768 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1104,6 +1104,17 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights + # TODO(aryan): Maybe this could be done in a better way. + # + # Previously, this was: + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ) + # + # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory + # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes + # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly + # looked into this deeply because other memory optimizations led to more pronounced reductions. 
hidden_states = torch.cat( [ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 9b81312693d9..4160a964461e 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -83,15 +83,14 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: Example: ```python >>> import torch + >>> import torch.nn as nn >>> model = nn.Linear(1000, 1000) - >>> split_module = SplitInferenceModule( - ... model, split_size=2, split_dim=0, input_kwargs_to_split=["input_data"] - ... ) + >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) >>> input_tensor = torch.randn(42, 1000) >>> # Will split the tensor into 21 slices of shape [2, 1000]. - >>> output = split_module(input_data=input_tensor) + >>> output = split_module(input=input_tensor) ``` This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking From 9e556be3afa56ce379ffed913693b323f92e2809 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 13:31:39 +0200 Subject: [PATCH 19/20] add tests --- .../pipelines/animatediff/test_animatediff.py | 24 ++++++++++++++++ .../test_animatediff_video2video.py | 28 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 677267305373..54c83d6a1b68 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -460,6 +460,30 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device) + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 59146115b90a..c3fd4c73736a 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -492,6 +492,34 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + 
pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_split_inference["num_inference_steps"] = 2 + inputs_enable_split_inference["strength"] = 0.5 + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) From 098bfd15a6e177b59dc60e25f2cfe2b988a7a2e8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 13:40:45 +0200 Subject: [PATCH 20/20] apply suggestions from review --- src/diffusers/pipelines/free_noise_utils.py | 65 +++++++++++++-------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 4160a964461e..dc0071a494e3 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -35,6 +35,46 @@ class SplitInferenceModule(nn.Module): + r""" + A wrapper module class that splits inputs along a specified dimension before performing a forward pass. + + This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. + + Args: + module (`nn.Module`): + The underlying PyTorch module that will be applied to each chunk of split inputs. + split_size (`int`, defaults to `1`): + The size of each chunk after splitting the input tensor. + split_dim (`int`, defaults to `0`): + The dimension along which the input tensors are split. + input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`): + A list of keyword arguments (strings) that represent the input tensors to be split. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + >>> import torch.nn as nn + + >>> model = nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input=input_tensor) + ``` + + It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex + multi-dimensional splitting. 
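+
+    A possible way to nest two wrappers, continuing the example above (the tensor shape and split sizes here
+    are purely illustrative), is to wrap an already-wrapped module along a second dimension:
+
+    ```python
+    >>> inner = SplitInferenceModule(model, split_size=4, split_dim=1, input_kwargs_to_split=["input"])
+    >>> outer = SplitInferenceModule(inner, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
+
+    >>> input_tensor = torch.randn(8, 16, 1000)
+    >>> # Runs the wrapped nn.Linear on 4 x 4 = 16 slices of shape [2, 4, 1000] and stitches them back.
+    >>> output = outer(input=input_tensor)
+    ```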
+ """ + def __init__( self, module: nn.Module, @@ -72,31 +112,6 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: along the same `split_dim` after processing all splits. - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. - - Workflow: - 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using - `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. - 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments - that were passed. - 3. The output tensors from each split are concatenated back together along `split_dim` before returning. - - Example: - ```python - >>> import torch - >>> import torch.nn as nn - - >>> model = nn.Linear(1000, 1000) - >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) - - >>> input_tensor = torch.randn(42, 1000) - >>> # Will split the tensor into 21 slices of shape [2, 1000]. - >>> output = split_module(input=input_tensor) - ``` - - This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking - them into smaller chunks, processing each chunk separately, and then reassembling the results. - - It is also possible to nest `SplitInferenceModule` across different split dimensions. """ split_inputs = {}