diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index d97f92d42c0..fe7a47f3a3f 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -55,7 +55,7 @@ TextConditioningData, TextConditioningRegions, ) -from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 +from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt @@ -884,7 +884,7 @@ def step_callback(state: PipelineIntermediateState) -> None: seed=seed, scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, - attention_processor_cls=CustomAttnProcessor2_0, + attention_processor_cls=CustomAttnProcessor, ), unet=None, scheduler=scheduler, diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 6c39760bdc8..06a197f630e 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -28,11 +28,11 @@ DEFAULT_VRAM_CACHE = 0.25 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] PRECISION = Literal["auto", "float16", "bfloat16", "float32"] -ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] -ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] +ATTENTION_TYPE = Literal["auto", "normal", "xformers", "torch-sdp"] +ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] -CONFIG_SCHEMA_VERSION = "4.0.2" +CONFIG_SCHEMA_VERSION = "4.0.3" def get_default_ram_cache_size() -> float: @@ -107,8 +107,8 @@ class InvokeAIAppConfig(BaseSettings): device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. - attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp` - attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` + attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `torch-sdp` + attention_slice_size: Slice size
Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. @@ -181,7 +181,7 @@ class InvokeAIAppConfig(BaseSettings): # GENERATION sequential_guidance: bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.") attention_type: ATTENTION_TYPE = Field(default="auto", description="Attention type.") - attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size, valid when attention_type=="sliced".') + attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size') force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") @@ -433,6 +433,30 @@ def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str return parsed_config_dict +def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate v4.0.2 config dictionary to a v4.0.3 config dictionary. + + Args: + config_dict: A dictionary of settings from a v4.0.2 config file. + + Returns: + An config dict with the settings migrated to v4.0.3. + """ + parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict) + attention_type = parsed_config_dict.get("attention_type", None) + + # now attention_slice_size means enabling slicing attention + if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict: + del parsed_config_dict["attention_slice_size"] + + # sliced moved to attention_slice_size + if attention_type == "sliced": + parsed_config_dict["attention_type"] = "auto" + + parsed_config_dict["schema_version"] = "4.0.3" + return parsed_config_dict + + def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """Load and migrate a config file to the latest version. @@ -458,6 +482,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: if loaded_config_dict["schema_version"] == "4.0.1": migrated = True loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict) + if loaded_config_dict["schema_version"] == "4.0.2": + migrated = True + loaded_config_dict = migrate_v4_0_2_to_4_0_3_config_dict(loaded_config_dict) if migrated: shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index b3a668518b0..6c2dca11f37 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -1,13 +1,11 @@ from __future__ import annotations import math -from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Callable, List, Optional, Union import einops import PIL.Image -import psutil import torch import torchvision.transforms as T from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL @@ -15,17 +13,13 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -from diffusers.utils.import_utils import is_xformers_available from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState -from invokeai.backend.util.attention import auto_detect_slice_size -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel @@ -167,66 +161,6 @@ def __init__( self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) - def _adjust_memory_efficient_attention(self, latents: torch.Tensor): - """ - if xformers is available, use it, otherwise use sliced attention. - """ - config = get_config() - if config.attention_type == "xformers": - self.enable_xformers_memory_efficient_attention() - return - elif config.attention_type == "sliced": - slice_size = config.attention_slice_size - if slice_size == "auto": - slice_size = auto_detect_slice_size(latents) - elif slice_size == "balanced": - slice_size = "auto" - self.enable_attention_slicing(slice_size=slice_size) - return - elif config.attention_type == "normal": - self.disable_attention_slicing() - return - elif config.attention_type == "torch-sdp": - if hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return - else: - raise Exception("torch-sdp attention slicing not available") - - # the remainder if this code is called when attention_type=='auto' - if self.unet.device.type == "cuda": - if is_xformers_available(): - self.enable_xformers_memory_efficient_attention() - return - elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return - - if self.unet.device.type == "cpu" or self.unet.device.type == "mps": - mem_free = psutil.virtual_memory().free - elif self.unet.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device)) - else: - raise ValueError(f"unrecognized device {self.unet.device}") - # input tensor of [1, 4, h/8, w/8] - # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] - bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 - max_size_required_for_baddbmm = ( - 16 - * latents.size(dim=2) - * latents.size(dim=3) - * latents.size(dim=2) - * latents.size(dim=3) - * bytes_per_element_needed_for_baddbmm_duplication - ) - if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code - self.enable_attention_slicing(slice_size="max") - elif torch.backends.mps.is_available(): - # diffusers recommends always enabling for mps - self.enable_attention_slicing(slice_size="max") - else: - self.disable_attention_slicing() - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): raise Exception("Should not be called") @@ -321,8 +255,6 @@ def latents_from_embeddings( # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) - self._adjust_memory_efficient_attention(latents) - # Handle mask guidance (a.k.a. inpainting). mask_guidance: AddsMaskGuidance | None = None if mask is not None and not is_inpainting_model(self.unet): @@ -347,23 +279,14 @@ def latents_from_embeddings( is_gradient_mask=is_gradient_mask, ) - use_ip_adapter = ip_adapter_data is not None - use_regional_prompting = ( - conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None - ) - unet_attention_patcher = None - attn_ctx = nullcontext() - - if use_ip_adapter or use_regional_prompting: - ip_adapters: Optional[List[UNetIPAdapterData]] = ( - [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data] - if use_ip_adapter - else None - ) - unet_attention_patcher = UNetAttentionPatcher(ip_adapters) - attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) + ip_adapters: Optional[List[UNetIPAdapterData]] = None + if ip_adapter_data is not None: + ip_adapters = [ + {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data + ] - with attn_ctx: + unet_attention_patcher = UNetAttentionPatcher(ip_adapters) + with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model): callback( PipelineIntermediateState( step=-1, diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py new file mode 100644 index 00000000000..68d3bdc7c84 --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -0,0 +1,561 @@ +import math +from dataclasses import dataclass +from typing import List, Optional + +import psutil +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.utils.import_utils import is_xformers_available + +import invokeai.backend.util.logging as logger +from invokeai.app.services.config.config_default import get_config +from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights +from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData +from invokeai.backend.util.devices import TorchDevice + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@dataclass +class IPAdapterAttentionWeights: + ip_adapter_weights: IPAttentionProcessorWeights + skip: bool + + +class CustomAttnProcessor: + """A custom implementation of attention processor that supports additional Invoke features. + This implementation is based on + AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732) + SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616) + XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113) + AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204) + Supported custom features: + - IP-Adapter + - Regional prompt attention + """ + + def __init__( + self, + ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, + ): + """Initialize a CustomAttnProcessor. + Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are + layer-specific are passed to __init__(). + Args: + ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights + for the i'th IP-Adapter. + """ + + self._ip_adapter_attention_weights = ip_adapter_attention_weights + + device = TorchDevice.choose_torch_device() + self.is_old_cuda = device.type == "cuda" and torch.cuda.get_device_capability(device)[0] < 8 + + config = get_config() + self.attention_type = config.attention_type + if self.attention_type == "auto": + self.attention_type = self._select_attention_type() + + self.slice_size = config.attention_slice_size + if self.slice_size == "auto": + self.slice_size = self._select_slice_size() + + if self.attention_type == "xformers" and xformers is None: + raise ImportError("xformers attention requires xformers module to be installed.") + + def _select_attention_type(self) -> str: + device = TorchDevice.choose_torch_device() + # On some mps system normal attention still faster than torch-sdp, on others - on par + # Results torch-sdp vs normal attention + # gogurt: 67.993s vs 67.729s + # Adreitz: 260.868s vs 226.638s + if device.type == "mps": + return "normal" + elif device.type == "cuda": + # In testing on a Tesla P40 (Pascal architecture), torch-sdp is much slower than xformers + # (8.84 s/it vs. 1.81 s/it for SDXL). We have not tested extensively to find the precise GPU architecture or + # compute capability where this performance gap begins. + # Flash Attention is supported from sm80 compute capability onwards in PyTorch + # (https://pytorch.org/blog/accelerated-pytorch-2/). For now, we use this as the cutoff for selecting + # between xformers and torch-sdp. + if self.is_old_cuda: + if xformers is not None: + return "xformers" + logger.warning( + f"xFormers is not installed, but is recommended for best performance with GPU {torch.cuda.get_device_properties(device).name}" + ) + + return "torch-sdp" + else: # cpu + return "torch-sdp" + + def _select_slice_size(self) -> str: + device = TorchDevice.choose_torch_device() + if device.type in ["cpu", "mps"]: + total_ram_gb = math.ceil(psutil.virtual_memory().total / 2**30) + if total_ram_gb <= 16: + return "max" + if total_ram_gb <= 32: + return "balanced" + return "none" + elif device.type == "cuda": + total_vram_gb = math.ceil(torch.cuda.get_device_properties(device).total_memory / 2**30) + if total_vram_gb <= 4: + return "max" + if total_vram_gb <= 6: + return "balanced" + return "none" + else: + raise ValueError(f"Unknown device: {device.type}") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + # For Regional Prompting: + regional_prompt_data: Optional[RegionalPromptData] = None, + # For IP-Adapter: + regional_ip_data: Optional[RegionalIPData] = None, + *args, + **kwargs, + ) -> torch.Tensor: + # If true, we are doing cross-attention, if false we are doing self-attention. + is_cross_attention = encoder_hidden_states is not None + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, key_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + query_length = hidden_states.shape[1] + + # Regional Prompt Attention Mask + if regional_prompt_data is not None and is_cross_attention: + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_length, key_seq_len=key_length + ) + + if attention_mask is None: + attention_mask = prompt_region_attention_mask + else: + attention_mask = prompt_region_attention_mask + attention_mask + + attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + hidden_states = self.run_attention( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + if is_cross_attention: + hidden_states = self.run_ip_adapters( + attn=attn, + hidden_states=hidden_states, + regional_ip_data=regional_ip_data, + query_length=query_length, + query=query, + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + def run_ip_adapters( + self, + attn: Attention, + hidden_states: torch.Tensor, + regional_ip_data: Optional[RegionalIPData], + query_length: int, # TODO: just read from query? + query: torch.Tensor, + ) -> torch.Tensor: + if self._ip_adapter_attention_weights is None: + # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. + assert regional_ip_data is None + return hidden_states + + assert regional_ip_data is not None + ip_masks = regional_ip_data.get_masks(query_seq_len=query_length) + + assert ( + len(regional_ip_data.image_prompt_embeds) + == len(self._ip_adapter_attention_weights) + == len(regional_ip_data.scales) + == ip_masks.shape[1] + ) + + for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds): + # The batch dimensions should match. + # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0] + # The token_len dimensions should match. + # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1] + + if self._ip_adapter_attention_weights[ipa_index].skip: + continue + + ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights + ipa_scale = regional_ip_data.scales[ipa_index] + ip_mask = ip_masks[0, ipa_index, ...] + + # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) + ip_key = ipa_weights.to_k_ip(ip_hidden_states) + ip_value = ipa_weights.to_v_ip(ip_hidden_states) + + ip_hidden_states = self.run_attention( + attn=attn, + query=query, + key=ip_key, + value=ip_value, + attention_mask=None, + ) + + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask + + return hidden_states + + def _get_slice_size(self, attn: Attention) -> Optional[int]: + if self.slice_size == "none": + return None + if isinstance(self.slice_size, int): + return self.slice_size + + if self.slice_size == "max": + return 1 + if self.slice_size == "balanced": + return max(1, attn.sliceable_head_dim // 2) + + raise ValueError(f"Incorrect slice_size value: {self.slice_size}") + + def run_attention( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + no_sliced: bool = False, + ) -> torch.Tensor: + slice_size = self._get_slice_size(attn) + if not no_sliced and slice_size is not None: + return self.run_attention_sliced( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + slice_size=slice_size, + ) + + if self.attention_type == "torch-sdp": + attn_call = self.run_attention_sdp + elif self.attention_type == "normal": + attn_call = self.run_attention_normal + elif self.attention_type == "xformers": + attn_call = self.run_attention_xformers + else: + raise Exception(f"Unknown attention type: {self.attention_type}") + + return attn_call( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + @staticmethod + def _align_attention_mask_memory(attention_mask: torch.Tensor, alignment: int = 8) -> torch.Tensor: + if attention_mask.stride(-2) % alignment == 0 and attention_mask.stride(-2) != 0: + return attention_mask + + last_mask_dim = attention_mask.shape[-1] + new_last_mask_dim = last_mask_dim + (alignment - (last_mask_dim % alignment)) + attention_mask_mem = torch.empty( + attention_mask.shape[:-1] + (new_last_mask_dim,), + device=attention_mask.device, + dtype=attention_mask.dtype, + ) + attention_mask_mem[..., :last_mask_dim] = attention_mask + return attention_mask_mem[..., :last_mask_dim] + + @staticmethod + def _head_to_batch_dim(tensor: torch.Tensor, head_dim: int) -> torch.Tensor: + # [B, S, H*He] -> [B, S, H, He] + tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1, head_dim) + # [B, S, H, He] -> [B, H, S, He] + tensor = tensor.permute(0, 2, 1, 3) + # [B, H, S, He] -> [B*H, S, He] + tensor = tensor.reshape(-1, tensor.shape[2], head_dim) + return tensor + + @staticmethod + def _batch_to_head_dim(tensor: torch.Tensor, batch_size: int) -> torch.Tensor: + # [B*H, S, He] -> [B, H, S, He] + tensor = tensor.reshape(batch_size, -1, tensor.shape[1], tensor.shape[2]) + # [B, H, S, He] -> [B, S, H, He] + tensor = tensor.permute(0, 2, 1, 3) + # [B, S, H, He] -> [B, S, H*He] + tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1) + return tensor + + def run_attention_normal( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + + query = self._head_to_batch_dim(query, head_dim) + key = self._head_to_batch_dim(key, head_dim) + value = self._head_to_batch_dim(value, head_dim) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + + hidden_states = self._batch_to_head_dim(hidden_states, batch_size) + return hidden_states + + def run_attention_xformers( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + multihead: Optional[bool] = None, + ) -> torch.Tensor: + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + + # batched execution on xformers slightly faster for small heads count + # 8 heads, fp16 (100000 attention calls): + # xformers(dim3): 20.155955553054810s vram: 16483328b + # xformers(dim4): 17.558132648468018s vram: 16483328b + # 1 head, fp16 (100000 attention calls): + # xformers(dim3): 5.660739183425903s vram: 9516032b + # xformers(dim4): 6.114191055297852s vram: 9516032b + if multihead is None: + heads_count = query.shape[2] // head_dim + multihead = heads_count >= 4 + + if multihead: + # [B, S, H*He] -> [B, S, H, He] + query = query.view(batch_size, query.shape[1], -1, head_dim) + key = key.view(batch_size, key.shape[1], -1, head_dim) + value = value.view(batch_size, value.shape[1], -1, head_dim) + + if attention_mask is not None: + # [B*H, 1, S_key] -> [B, H, 1, S_key] + attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) + # expand our mask's singleton query dimension: + # [B, H, 1, S_key] -> + # [B, H, S_query, S_key] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + attention_mask = attention_mask.expand(-1, -1, query.shape[1], -1) + # xformers requires mask memory to be aligned to 8 + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale + ) + # [B, S_query, H, He] -> [B, S_query, H*He] + hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,)) + hidden_states = hidden_states.to(query.dtype) + + else: + # contiguous inputs slightly faster in batched execution + # [B, S, H*He] -> [B*H, S, He] + query = self._head_to_batch_dim(query, head_dim).contiguous() + key = self._head_to_batch_dim(key, head_dim).contiguous() + value = self._head_to_batch_dim(value, head_dim).contiguous() + + if attention_mask is not None: + # expand our mask's singleton query dimension: + # [B*H, 1, S_key] -> + # [B*H, S_query, S_key] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + attention_mask = attention_mask.expand(-1, query.shape[1], -1) + # xformers requires mask memory to be aligned to 8 + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + # [B*H, S_query, He] -> [B, S_query, H*He] + hidden_states = self._batch_to_head_dim(hidden_states, batch_size) + + return hidden_states + + def run_attention_sdp( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + multihead: Optional[bool] = None, + ) -> torch.Tensor: + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + + if multihead is None: + # multihead extremely slow on old cuda gpu: + # fp16 (100000 attention calls): + # torch-sdp(dim3): 30.07543110847473s vram: 23954432b + # torch-sdp(dim4): 299.3908393383026s vram: 13861888b + multihead = not self.is_old_cuda + + if multihead: + # [B, S, H*He] -> [B, H, S, He] + query = query.view(batch_size, query.shape[1], -1, head_dim).transpose(1, 2) + key = key.view(batch_size, key.shape[1], -1, head_dim).transpose(1, 2) + value = value.view(batch_size, value.shape[1], -1, head_dim).transpose(1, 2) + + if attention_mask is not None: + # [B*H, 1, S_key] -> [B, H, 1, S_key] + attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) + # mask alignment to 8 decreases memory consumption and increases speed + # fp16 (100000 attention calls): + # torch-sdp(dim4, mask): 6.1701478958129880s vram: 7864320b + # torch-sdp(dim4, aligned mask): 3.3127212524414062s vram: 2621440b + # fp32 (100000 attention calls): + # torch-sdp(dim4, mask): 23.0943229198455800s vram: 16121856b + # torch-sdp(dim4, aligned mask): 17.3104763031005860s vram: 5636096b + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + ) + + # [B, H, S_query, He] -> [B, S_query, H, He] + hidden_states = hidden_states.transpose(1, 2) + # [B, S_query, H, He] -> [B, S_query, H*He] + hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,)) + hidden_states = hidden_states.to(query.dtype) + else: + # [B, S, H*He] -> [B*H, S, He] + query = self._head_to_batch_dim(query, head_dim) + key = self._head_to_batch_dim(key, head_dim) + value = self._head_to_batch_dim(value, head_dim) + + # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key] + # and there no noticable changes from memory alignment in batched run: + # fp16 (100000 attention calls): + # torch-sdp(dim3, mask): 9.7391905784606930s vram: 12713984b + # torch-sdp(dim3, aligned mask): 10.0090200901031500s vram: 12713984b + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + ) + + hidden_states = hidden_states.to(query.dtype) + # [B*H, S_query, He] -> [B, S_query, H*He] + hidden_states = self._batch_to_head_dim(hidden_states, batch_size) + + return hidden_states + + def run_attention_sliced( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + slice_size: int, + ) -> torch.Tensor: + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + heads_count = query.shape[2] // head_dim + + # [B, S, H*He] -> [B, H, S, He] + query = query.reshape(query.shape[0], query.shape[1], -1, head_dim).transpose(1, 2) + key = key.reshape(key.shape[0], key.shape[1], -1, head_dim).transpose(1, 2) + value = value.reshape(value.shape[0], value.shape[1], -1, head_dim).transpose(1, 2) + # [B*H, S_query/1, S_key] -> [B, H, S_query/1, S_key] + if attention_mask is not None: + attention_mask = attention_mask.reshape(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) + + # [B, H, S_query, He] + hidden_states = torch.empty(query.shape, device=query.device, dtype=query.dtype) + + for i in range((heads_count - 1) // slice_size + 1): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + # [B, H_s, S, He] -> [B, S, H_s*He] + query_slice = query[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, query.shape[2], -1) + key_slice = key[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, key.shape[2], -1) + value_slice = value[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, value.shape[2], -1) + + # [B, H_s, S_query/1, S_key] -> [B*H_s, S_query/1, S_key] + attn_mask_slice = None + if attention_mask is not None: + attn_mask_slice = attention_mask[:, start_idx:end_idx, :, :].reshape((-1,) + attention_mask.shape[-2:]) + + # [B, S_query, H_s*He] + hidden_states_slice = self.run_attention( + attn=attn, + query=query_slice, + key=key_slice, + value=value_slice, + attention_mask=attn_mask_slice, + no_sliced=True, + ) + + # [B, S_query, H_s*He] -> [B, H_s, S_query, He] + hidden_states[:, start_idx:end_idx] = hidden_states_slice.reshape( + hidden_states_slice.shape[:-1] + (-1, head_dim) + ).transpose(1, 2) + + # [B, H_s, S_query, He] -> [B, S_query, H_s*He] + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(hidden_states.shape[:-2] + (-1,)) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py deleted file mode 100644 index 1334313fe6e..00000000000 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ /dev/null @@ -1,214 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, cast - -import torch -import torch.nn.functional as F -from diffusers.models.attention_processor import Attention, AttnProcessor2_0 - -from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights -from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData -from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData - - -@dataclass -class IPAdapterAttentionWeights: - ip_adapter_weights: IPAttentionProcessorWeights - skip: bool - - -class CustomAttnProcessor2_0(AttnProcessor2_0): - """A custom implementation of AttnProcessor2_0 that supports additional Invoke features. - This implementation is based on - https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204 - Supported custom features: - - IP-Adapter - - Regional prompt attention - """ - - def __init__( - self, - ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, - ): - """Initialize a CustomAttnProcessor2_0. - Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are - layer-specific are passed to __init__(). - Args: - ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights - for the i'th IP-Adapter. - """ - super().__init__() - self._ip_adapter_attention_weights = ip_adapter_attention_weights - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - # For Regional Prompting: - regional_prompt_data: Optional[RegionalPromptData] = None, - percent_through: Optional[torch.Tensor] = None, - # For IP-Adapter: - regional_ip_data: Optional[RegionalIPData] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - """Apply attention. - Args: - regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to - apply regional prompt masking. - regional_ip_data: The IP-Adapter data for the current batch. - """ - # If true, we are doing cross-attention, if false we are doing self-attention. - is_cross_attention = encoder_hidden_states is not None - - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # End unmodified block from AttnProcessor2_0. - - _, query_seq_len, _ = hidden_states.shape - # Handle regional prompt attention masks. - if regional_prompt_data is not None and is_cross_attention: - assert percent_through is not None - prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( - query_seq_len=query_seq_len, key_seq_len=sequence_length - ) - - if attention_mask is None: - attention_mask = prompt_region_attention_mask - else: - attention_mask = prompt_region_attention_mask + attention_mask - - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # End unmodified block from AttnProcessor2_0. - - # Apply IP-Adapter conditioning. - if is_cross_attention: - if self._ip_adapter_attention_weights: - assert regional_ip_data is not None - ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) - - assert ( - len(regional_ip_data.image_prompt_embeds) - == len(self._ip_adapter_attention_weights) - == len(regional_ip_data.scales) - == ip_masks.shape[1] - ) - - for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds): - ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights - ipa_scale = regional_ip_data.scales[ipa_index] - ip_mask = ip_masks[0, ipa_index, ...] - - # The batch dimensions should match. - assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. - assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] - - ip_hidden_states = ipa_embed - - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - - if not self._ip_adapter_attention_weights[ipa_index].skip: - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - # Expected ip_key and ip_value shape: - # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Expected ip_key and ip_value shape: - # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - - ip_hidden_states = ip_hidden_states.to(query.dtype) - - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask - else: - # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. - assert regional_ip_data is None - - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # End of unmodified block from AttnProcessor2_0 - - # casting torch.Tensor to torch.FloatTensor to avoid type issues - return cast(torch.FloatTensor, hidden_states) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f418133e49f..d5f54ccff16 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -335,7 +335,6 @@ def _apply_standard_conditioning( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( regions=regions, device=x.device, dtype=x.dtype ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds @@ -426,7 +425,6 @@ def _apply_standard_conditioning_sequentially( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count # Run unconditioned UNet denoising (i.e. negative prompt). unconditioned_next_x = self.model_forward_callback( @@ -474,7 +472,6 @@ def _apply_standard_conditioning_sequentially( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count # Run conditioned UNet denoising (i.e. positive prompt). conditioned_next_x = self.model_forward_callback( diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index ac00a8e06ea..8ba8b3acf38 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -4,8 +4,8 @@ from diffusers.models import UNet2DConditionModel from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.stable_diffusion.diffusion.custom_atttention import ( - CustomAttnProcessor2_0, +from invokeai.backend.stable_diffusion.diffusion.custom_attention import ( + CustomAttnProcessor, IPAdapterAttentionWeights, ) @@ -16,7 +16,7 @@ class UNetIPAdapterData(TypedDict): class UNetAttentionPatcher: - """A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" + """A class for patching a UNet with CustomAttnProcessor attention layers.""" def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]): self._ip_adapters = ip_adapter_data @@ -27,11 +27,12 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel): Note that the `unet` param is only used to determine attention block dimensions and naming. """ # Construct a dict of attention processors based on the UNet's architecture. + attn_procs = {} for idx, name in enumerate(unet.attn_processors.keys()): if name.endswith("attn1.processor") or self._ip_adapters is None: # "attn1" processors do not use IP-Adapters. - attn_procs[name] = CustomAttnProcessor2_0() + attn_procs[name] = CustomAttnProcessor() else: # Collect the weights from each IP Adapter for the idx'th attention processor. ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = [] @@ -48,12 +49,14 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel): ) ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights) - attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection) + attn_procs[name] = CustomAttnProcessor( + ip_adapter_attention_weights=ip_adapter_attention_weights_collection, + ) return attn_procs @contextmanager - def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): + def apply_custom_attention(self, unet: UNet2DConditionModel): """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers.""" attn_procs = self._prepare_attention_processors(unet) orig_attn_processors = unet.attn_processors diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4191db734f9..dddf7e555d4 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -114,9 +114,6 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio sample=sample, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditoning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.inputs.timesteps), - ), ) ctx.conditioning_mode = conditioning_mode diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py index a48a681af3f..8728779e00b 100644 --- a/invokeai/backend/stable_diffusion/extensions/controlnet.py +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -112,8 +112,6 @@ def pre_unet_step(self, ctx: DenoiseContext): ctx.unet_kwargs.mid_block_additional_residual += mid_sample def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode): - total_steps = len(ctx.inputs.timesteps) - model_input = ctx.latent_model_input image_tensor = self._image_tensor if conditioning_mode == ConditioningMode.Both: @@ -124,9 +122,6 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con sample=model_input, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditioning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / total_steps, - ), ) ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index 6c07fc1c2c8..f78facaf581 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -13,6 +13,7 @@ StableDiffusionGeneratorPipeline, ) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.tiles.utils import Tile @@ -63,132 +64,132 @@ def multi_diffusion_denoise( latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) assert isinstance(latents, torch.Tensor) # For static type checking. - # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after - # cropping into regions. - self._adjust_memory_efficient_attention(latents) - - # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since - # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a - # separate scheduler state for each region batch. - # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler - # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect - # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when - # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each - # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion. - region_batch_schedulers: list[SchedulerMixin] = [ - copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning - ] - - callback( - PipelineIntermediateState( - step=-1, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=self.scheduler.config.num_train_timesteps, - latents=latents, - ) - ) - - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t = t.expand(batch_size) - - merged_latents = torch.zeros_like(latents) - merged_latents_weights = torch.zeros( - (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype - ) - merged_pred_original: torch.Tensor | None = None - for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): - # Switch to the scheduler for the region batch. - self.scheduler = region_batch_schedulers[region_idx] - - # Crop the inputs to the region. - region_latents = latents[ - :, - :, - region_conditioning.region.coords.top : region_conditioning.region.coords.bottom, - region_conditioning.region.coords.left : region_conditioning.region.coords.right, - ] - - # Run the denoising step on the region. - step_output = self.step( - t=batched_t, - latents=region_latents, - conditioning_data=region_conditioning.text_conditioning_data, - step_index=i, - total_step_count=len(timesteps), - scheduler_step_kwargs=scheduler_step_kwargs, - mask_guidance=None, - mask=None, - masked_latents=None, - control_data=region_conditioning.control_data, - ) - - # Build a region_weight matrix that applies gradient blending to the edges of the region. - region = region_conditioning.region - _, _, region_height, region_width = step_output.prev_sample.shape - region_weight = torch.ones( - (1, 1, region_height, region_width), - dtype=latents.dtype, - device=latents.device, - ) - if region.overlap.left > 0: - left_grad = torch.linspace( - 0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype - ).view((1, 1, 1, -1)) - region_weight[:, :, :, : region.overlap.left] *= left_grad - if region.overlap.top > 0: - top_grad = torch.linspace( - 0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype - ).view((1, 1, -1, 1)) - region_weight[:, :, : region.overlap.top, :] *= top_grad - if region.overlap.right > 0: - right_grad = torch.linspace( - 1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype - ).view((1, 1, 1, -1)) - region_weight[:, :, :, -region.overlap.right :] *= right_grad - if region.overlap.bottom > 0: - bottom_grad = torch.linspace( - 1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype - ).view((1, 1, -1, 1)) - region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad - - # Update the merged results with the region results. - merged_latents[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += step_output.prev_sample * region_weight - merged_latents_weights[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += region_weight - - pred_orig_sample = getattr(step_output, "pred_original_sample", None) - if pred_orig_sample is not None: - # If one region has pred_original_sample, then we can assume that all regions will have it, because - # they all use the same scheduler. - if merged_pred_original is None: - merged_pred_original = torch.zeros_like(latents) - merged_pred_original[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += pred_orig_sample - - # Normalize the merged results. - latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents) - # For debugging, uncomment this line to visualize the region seams: - # latents = torch.where(merged_latents_weights > 1, 0.0, latents) - predicted_original = None - if merged_pred_original is not None: - predicted_original = torch.where( - merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original - ) + unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None) + with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model): + # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since + # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a + # separate scheduler state for each region batch. + # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler + # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect + # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when + # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each + # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion. + region_batch_schedulers: list[SchedulerMixin] = [ + copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning + ] callback( PipelineIntermediateState( - step=i, + step=-1, order=self.scheduler.order, total_steps=len(timesteps), - timestep=int(t), + timestep=self.scheduler.config.num_train_timesteps, latents=latents, - predicted_original=predicted_original, ) ) + for i, t in enumerate(self.progress_bar(timesteps)): + batched_t = t.expand(batch_size) + + merged_latents = torch.zeros_like(latents) + merged_latents_weights = torch.zeros( + (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype + ) + merged_pred_original: torch.Tensor | None = None + for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): + # Switch to the scheduler for the region batch. + self.scheduler = region_batch_schedulers[region_idx] + + # Crop the inputs to the region. + region_latents = latents[ + :, + :, + region_conditioning.region.coords.top : region_conditioning.region.coords.bottom, + region_conditioning.region.coords.left : region_conditioning.region.coords.right, + ] + + # Run the denoising step on the region. + step_output = self.step( + t=batched_t, + latents=region_latents, + conditioning_data=region_conditioning.text_conditioning_data, + step_index=i, + total_step_count=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, + mask_guidance=None, + mask=None, + masked_latents=None, + control_data=region_conditioning.control_data, + ) + + # Build a region_weight matrix that applies gradient blending to the edges of the region. + region = region_conditioning.region + _, _, region_height, region_width = step_output.prev_sample.shape + region_weight = torch.ones( + (1, 1, region_height, region_width), + dtype=latents.dtype, + device=latents.device, + ) + if region.overlap.left > 0: + left_grad = torch.linspace( + 0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype + ).view((1, 1, 1, -1)) + region_weight[:, :, :, : region.overlap.left] *= left_grad + if region.overlap.top > 0: + top_grad = torch.linspace( + 0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype + ).view((1, 1, -1, 1)) + region_weight[:, :, : region.overlap.top, :] *= top_grad + if region.overlap.right > 0: + right_grad = torch.linspace( + 1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype + ).view((1, 1, 1, -1)) + region_weight[:, :, :, -region.overlap.right :] *= right_grad + if region.overlap.bottom > 0: + bottom_grad = torch.linspace( + 1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype + ).view((1, 1, -1, 1)) + region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad + + # Update the merged results with the region results. + merged_latents[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += step_output.prev_sample * region_weight + merged_latents_weights[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += region_weight + + pred_orig_sample = getattr(step_output, "pred_original_sample", None) + if pred_orig_sample is not None: + # If one region has pred_original_sample, then we can assume that all regions will have it, because + # they all use the same scheduler. + if merged_pred_original is None: + merged_pred_original = torch.zeros_like(latents) + merged_pred_original[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += pred_orig_sample + + # Normalize the merged results. + latents = torch.where( + merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents + ) + # For debugging, uncomment this line to visualize the region seams: + # latents = torch.where(merged_latents_weights > 1, 0.0, latents) + predicted_original = None + if merged_pred_original is not None: + predicted_original = torch.where( + merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original + ) + + callback( + PipelineIntermediateState( + step=i, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=int(t), + latents=latents, + predicted_original=predicted_original, + ) + ) + return latents diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 7e362fe9589..1991aed3ccc 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -802,8 +802,11 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0): if xformers_available: # TODO: remove when fixed in diffusers + from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor + _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention + # TODO: remove? or there still possible calls to xformers not by our attention processor? def new_memory_efficient_attention( query: torch.Tensor, key: torch.Tensor, @@ -815,16 +818,8 @@ def new_memory_efficient_attention( op=None, ): # diffusers not align shape to 8, which is required by xformers - if attn_bias is not None and type(attn_bias) is torch.Tensor: - orig_size = attn_bias.shape[-1] - new_size = ((orig_size + 7) // 8) * 8 - aligned_attn_bias = torch.zeros( - (attn_bias.shape[0], attn_bias.shape[1], new_size), - device=attn_bias.device, - dtype=attn_bias.dtype, - ) - aligned_attn_bias[:, :, :orig_size] = attn_bias - attn_bias = aligned_attn_bias[:, :, :orig_size] + if attn_bias is not None and isinstance(attn_bias, torch.Tensor): + attn_bias = CustomAttnProcessor._align_attention_mask_memory(attn_bias) return _xformers_memory_efficient_attention( query=query,