diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index d97f92d42c0..fe7a47f3a3f 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -55,7 +55,7 @@
TextConditioningData,
TextConditioningRegions,
)
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
@@ -884,7 +884,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
- attention_processor_cls=CustomAttnProcessor2_0,
+ attention_processor_cls=CustomAttnProcessor,
),
unet=None,
scheduler=scheduler,
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 6c39760bdc8..06a197f630e 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -28,11 +28,11 @@
DEFAULT_VRAM_CACHE = 0.25
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
-ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
-ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
+ATTENTION_TYPE = Literal["auto", "normal", "xformers", "torch-sdp"]
+ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.2"
+CONFIG_SCHEMA_VERSION = "4.0.3"
def get_default_ram_cache_size() -> float:
@@ -107,8 +107,8 @@ class InvokeAIAppConfig(BaseSettings):
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
-    attention_type: Attention type.
-        Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
-    attention_slice_size: Slice size, valid when attention_type=="sliced".
-        Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
+    attention_type: Attention type.
+        Valid values: `auto`, `normal`, `xformers`, `torch-sdp`
+    attention_slice_size: Slice size
+        Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
@@ -181,7 +181,7 @@ class InvokeAIAppConfig(BaseSettings):
# GENERATION
sequential_guidance: bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.")
attention_type: ATTENTION_TYPE = Field(default="auto", description="Attention type.")
- attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size, valid when attention_type=="sliced".')
+ attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size')
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
@@ -433,6 +433,30 @@ def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str
return parsed_config_dict
+def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
+ """Migrate v4.0.2 config dictionary to a v4.0.3 config dictionary.
+
+ Args:
+ config_dict: A dictionary of settings from a v4.0.2 config file.
+
+ Returns:
+        A config dict with the settings migrated to v4.0.3.
+ """
+ parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
+ attention_type = parsed_config_dict.get("attention_type", None)
+
+    # attention_slice_size now enables sliced attention on its own, so drop any stale value.
+ if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict:
+ del parsed_config_dict["attention_slice_size"]
+
+    # The "sliced" attention type has been removed; fall back to "auto".
+ if attention_type == "sliced":
+ parsed_config_dict["attention_type"] = "auto"
+
+ parsed_config_dict["schema_version"] = "4.0.3"
+ return parsed_config_dict
+
+
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -458,6 +482,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
+ if loaded_config_dict["schema_version"] == "4.0.2":
+ migrated = True
+ loaded_config_dict = migrate_v4_0_2_to_4_0_3_config_dict(loaded_config_dict)
if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
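For reference, here is a minimal sketch (not part of the patch) of how the migration above handles a legacy `sliced` configuration; the input dict is hypothetical:

```python
from invokeai.app.services.config.config_default import migrate_v4_0_2_to_4_0_3_config_dict

# Hypothetical v4.0.2 settings that used the removed "sliced" attention type.
old = {"schema_version": "4.0.2", "attention_type": "sliced", "attention_slice_size": "balanced"}
new = migrate_v4_0_2_to_4_0_3_config_dict(old)

# "sliced" falls back to "auto"; the slice size is kept, since slicing is now
# controlled by attention_slice_size alone.
assert new == {"schema_version": "4.0.3", "attention_type": "auto", "attention_slice_size": "balanced"}
```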
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index b3a668518b0..6c2dca11f37 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -1,13 +1,11 @@
from __future__ import annotations
import math
-from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union
import einops
import PIL.Image
-import psutil
import torch
import torchvision.transforms as T
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
@@ -15,17 +13,13 @@
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
-from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@@ -167,66 +161,6 @@ def __init__(
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
- def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
- """
- if xformers is available, use it, otherwise use sliced attention.
- """
- config = get_config()
- if config.attention_type == "xformers":
- self.enable_xformers_memory_efficient_attention()
- return
- elif config.attention_type == "sliced":
- slice_size = config.attention_slice_size
- if slice_size == "auto":
- slice_size = auto_detect_slice_size(latents)
- elif slice_size == "balanced":
- slice_size = "auto"
- self.enable_attention_slicing(slice_size=slice_size)
- return
- elif config.attention_type == "normal":
- self.disable_attention_slicing()
- return
- elif config.attention_type == "torch-sdp":
- if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
- # diffusers enables sdp automatically
- return
- else:
- raise Exception("torch-sdp attention slicing not available")
-
- # the remainder if this code is called when attention_type=='auto'
- if self.unet.device.type == "cuda":
- if is_xformers_available():
- self.enable_xformers_memory_efficient_attention()
- return
- elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
- # diffusers enables sdp automatically
- return
-
- if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
- mem_free = psutil.virtual_memory().free
- elif self.unet.device.type == "cuda":
- mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
- else:
- raise ValueError(f"unrecognized device {self.unet.device}")
- # input tensor of [1, 4, h/8, w/8]
- # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
- bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
- max_size_required_for_baddbmm = (
- 16
- * latents.size(dim=2)
- * latents.size(dim=3)
- * latents.size(dim=2)
- * latents.size(dim=3)
- * bytes_per_element_needed_for_baddbmm_duplication
- )
- if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
- self.enable_attention_slicing(slice_size="max")
- elif torch.backends.mps.is_available():
- # diffusers recommends always enabling for mps
- self.enable_attention_slicing(slice_size="max")
- else:
- self.disable_attention_slicing()
-
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
@@ -321,8 +255,6 @@ def latents_from_embeddings(
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
- self._adjust_memory_efficient_attention(latents)
-
# Handle mask guidance (a.k.a. inpainting).
mask_guidance: AddsMaskGuidance | None = None
if mask is not None and not is_inpainting_model(self.unet):
@@ -347,23 +279,14 @@ def latents_from_embeddings(
is_gradient_mask=is_gradient_mask,
)
- use_ip_adapter = ip_adapter_data is not None
- use_regional_prompting = (
- conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
- )
- unet_attention_patcher = None
- attn_ctx = nullcontext()
-
- if use_ip_adapter or use_regional_prompting:
- ip_adapters: Optional[List[UNetIPAdapterData]] = (
- [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
- if use_ip_adapter
- else None
- )
- unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
- attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+ ip_adapters: Optional[List[UNetIPAdapterData]] = None
+ if ip_adapter_data is not None:
+ ip_adapters = [
+ {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data
+ ]
- with attn_ctx:
+ unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+ with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model):
callback(
PipelineIntermediateState(
step=-1,
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
new file mode 100644
index 00000000000..68d3bdc7c84
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -0,0 +1,561 @@
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import psutil
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
+from diffusers.utils.import_utils import is_xformers_available
+
+import invokeai.backend.util.logging as logger
+from invokeai.app.services.config.config_default import get_config
+from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
+from invokeai.backend.util.devices import TorchDevice
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+@dataclass
+class IPAdapterAttentionWeights:
+ ip_adapter_weights: IPAttentionProcessorWeights
+ skip: bool
+
+
+class CustomAttnProcessor:
+ """A custom implementation of attention processor that supports additional Invoke features.
+ This implementation is based on
+ AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732)
+ SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
+ XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113)
+ AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
+ Supported custom features:
+ - IP-Adapter
+ - Regional prompt attention
+ """
+
+ def __init__(
+ self,
+ ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
+ ):
+ """Initialize a CustomAttnProcessor.
+ Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
+ layer-specific are passed to __init__().
+ Args:
+            ip_adapter_attention_weights: The IP-Adapter attention weights.
+                ip_adapter_attention_weights[i] contains the attention weights for the i'th IP-Adapter.
+ """
+
+ self._ip_adapter_attention_weights = ip_adapter_attention_weights
+
+ device = TorchDevice.choose_torch_device()
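+        # "Old" CUDA here means compute capability < 8.0 (pre-Ampere); see _select_attention_type()
+        # and run_attention_sdp() for how this affects the attention implementation chosen.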
+ self.is_old_cuda = device.type == "cuda" and torch.cuda.get_device_capability(device)[0] < 8
+
+ config = get_config()
+ self.attention_type = config.attention_type
+ if self.attention_type == "auto":
+ self.attention_type = self._select_attention_type()
+
+ self.slice_size = config.attention_slice_size
+ if self.slice_size == "auto":
+ self.slice_size = self._select_slice_size()
+
+ if self.attention_type == "xformers" and xformers is None:
+ raise ImportError("xformers attention requires xformers module to be installed.")
+
+ def _select_attention_type(self) -> str:
+ device = TorchDevice.choose_torch_device()
+        # On some MPS systems normal attention is still faster than torch-sdp; on others they are on par.
+        # Benchmark results, torch-sdp vs normal attention:
+ # gogurt: 67.993s vs 67.729s
+ # Adreitz: 260.868s vs 226.638s
+ if device.type == "mps":
+ return "normal"
+ elif device.type == "cuda":
+ # In testing on a Tesla P40 (Pascal architecture), torch-sdp is much slower than xformers
+ # (8.84 s/it vs. 1.81 s/it for SDXL). We have not tested extensively to find the precise GPU architecture or
+ # compute capability where this performance gap begins.
+ # Flash Attention is supported from sm80 compute capability onwards in PyTorch
+ # (https://pytorch.org/blog/accelerated-pytorch-2/). For now, we use this as the cutoff for selecting
+ # between xformers and torch-sdp.
+ if self.is_old_cuda:
+ if xformers is not None:
+ return "xformers"
+ logger.warning(
+ f"xFormers is not installed, but is recommended for best performance with GPU {torch.cuda.get_device_properties(device).name}"
+ )
+
+ return "torch-sdp"
+ else: # cpu
+ return "torch-sdp"
+
+ def _select_slice_size(self) -> str:
+ device = TorchDevice.choose_torch_device()
+ if device.type in ["cpu", "mps"]:
+ total_ram_gb = math.ceil(psutil.virtual_memory().total / 2**30)
+ if total_ram_gb <= 16:
+ return "max"
+ if total_ram_gb <= 32:
+ return "balanced"
+ return "none"
+ elif device.type == "cuda":
+ total_vram_gb = math.ceil(torch.cuda.get_device_properties(device).total_memory / 2**30)
+ if total_vram_gb <= 4:
+ return "max"
+ if total_vram_gb <= 6:
+ return "balanced"
+ return "none"
+ else:
+ raise ValueError(f"Unknown device: {device.type}")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ # For Regional Prompting:
+ regional_prompt_data: Optional[RegionalPromptData] = None,
+ # For IP-Adapter:
+ regional_ip_data: Optional[RegionalIPData] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ # If true, we are doing cross-attention, if false we are doing self-attention.
+ is_cross_attention = encoder_hidden_states is not None
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ query_length = hidden_states.shape[1]
+
+ # Regional Prompt Attention Mask
+ if regional_prompt_data is not None and is_cross_attention:
+ prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+ query_seq_len=query_length, key_seq_len=key_length
+ )
+
+ if attention_mask is None:
+ attention_mask = prompt_region_attention_mask
+ else:
+ attention_mask = prompt_region_attention_mask + attention_mask
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ hidden_states = self.run_attention(
+ attn=attn,
+ query=query,
+ key=key,
+ value=value,
+ attention_mask=attention_mask,
+ )
+
+ if is_cross_attention:
+ hidden_states = self.run_ip_adapters(
+ attn=attn,
+ hidden_states=hidden_states,
+ regional_ip_data=regional_ip_data,
+ query_length=query_length,
+ query=query,
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+ def run_ip_adapters(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ regional_ip_data: Optional[RegionalIPData],
+ query_length: int, # TODO: just read from query?
+ query: torch.Tensor,
+ ) -> torch.Tensor:
+ if self._ip_adapter_attention_weights is None:
+ # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
+ assert regional_ip_data is None
+ return hidden_states
+
+ assert regional_ip_data is not None
+ ip_masks = regional_ip_data.get_masks(query_seq_len=query_length)
+
+ assert (
+ len(regional_ip_data.image_prompt_embeds)
+ == len(self._ip_adapter_attention_weights)
+ == len(regional_ip_data.scales)
+ == ip_masks.shape[1]
+ )
+
+ for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds):
+ # The batch dimensions should match.
+ # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0]
+ # The token_len dimensions should match.
+ # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1]
+
+ if self._ip_adapter_attention_weights[ipa_index].skip:
+ continue
+
+ ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
+ ipa_scale = regional_ip_data.scales[ipa_index]
+ ip_mask = ip_masks[0, ipa_index, ...]
+
+ # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
+ ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+ ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+
+ ip_hidden_states = self.run_attention(
+ attn=attn,
+ query=query,
+ key=ip_key,
+ value=ip_value,
+ attention_mask=None,
+ )
+
+ # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+ hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
+
+ return hidden_states
+
+ def _get_slice_size(self, attn: Attention) -> Optional[int]:
+ if self.slice_size == "none":
+ return None
+ if isinstance(self.slice_size, int):
+ return self.slice_size
+
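+        # "max" runs one attention head per slice (lowest peak memory);
+        # "balanced" halves the number of heads per slice.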
+ if self.slice_size == "max":
+ return 1
+ if self.slice_size == "balanced":
+ return max(1, attn.sliceable_head_dim // 2)
+
+ raise ValueError(f"Incorrect slice_size value: {self.slice_size}")
+
+ def run_attention(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ no_sliced: bool = False,
+ ) -> torch.Tensor:
+ slice_size = self._get_slice_size(attn)
+ if not no_sliced and slice_size is not None:
+ return self.run_attention_sliced(
+ attn=attn,
+ query=query,
+ key=key,
+ value=value,
+ attention_mask=attention_mask,
+ slice_size=slice_size,
+ )
+
+ if self.attention_type == "torch-sdp":
+ attn_call = self.run_attention_sdp
+ elif self.attention_type == "normal":
+ attn_call = self.run_attention_normal
+ elif self.attention_type == "xformers":
+ attn_call = self.run_attention_xformers
+ else:
+ raise Exception(f"Unknown attention type: {self.attention_type}")
+
+ return attn_call(
+ attn=attn,
+ query=query,
+ key=key,
+ value=value,
+ attention_mask=attention_mask,
+ )
+
+ @staticmethod
+ def _align_attention_mask_memory(attention_mask: torch.Tensor, alignment: int = 8) -> torch.Tensor:
+ if attention_mask.stride(-2) % alignment == 0 and attention_mask.stride(-2) != 0:
+ return attention_mask
+
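+        # Allocate a buffer whose last dimension is padded up to a multiple of `alignment`, copy the
+        # mask into it, and return a view sliced back to the original width: the logical shape is
+        # unchanged, but the underlying storage now satisfies the alignment requirement.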
+ last_mask_dim = attention_mask.shape[-1]
+ new_last_mask_dim = last_mask_dim + (alignment - (last_mask_dim % alignment))
+ attention_mask_mem = torch.empty(
+ attention_mask.shape[:-1] + (new_last_mask_dim,),
+ device=attention_mask.device,
+ dtype=attention_mask.dtype,
+ )
+ attention_mask_mem[..., :last_mask_dim] = attention_mask
+ return attention_mask_mem[..., :last_mask_dim]
+
+ @staticmethod
+ def _head_to_batch_dim(tensor: torch.Tensor, head_dim: int) -> torch.Tensor:
+ # [B, S, H*He] -> [B, S, H, He]
+ tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1, head_dim)
+ # [B, S, H, He] -> [B, H, S, He]
+ tensor = tensor.permute(0, 2, 1, 3)
+ # [B, H, S, He] -> [B*H, S, He]
+ tensor = tensor.reshape(-1, tensor.shape[2], head_dim)
+ return tensor
+
+ @staticmethod
+ def _batch_to_head_dim(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
+ # [B*H, S, He] -> [B, H, S, He]
+ tensor = tensor.reshape(batch_size, -1, tensor.shape[1], tensor.shape[2])
+ # [B, H, S, He] -> [B, S, H, He]
+ tensor = tensor.permute(0, 2, 1, 3)
+ # [B, S, H, He] -> [B, S, H*He]
+ tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
+ return tensor
+
+ def run_attention_normal(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ batch_size = query.shape[0]
+ head_dim = attn.to_q.weight.shape[0] // attn.heads
+
+ query = self._head_to_batch_dim(query, head_dim)
+ key = self._head_to_batch_dim(key, head_dim)
+ value = self._head_to_batch_dim(value, head_dim)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+
+ hidden_states = self._batch_to_head_dim(hidden_states, batch_size)
+ return hidden_states
+
+ def run_attention_xformers(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ multihead: Optional[bool] = None,
+ ) -> torch.Tensor:
+ batch_size = query.shape[0]
+ head_dim = attn.to_q.weight.shape[0] // attn.heads
+
+        # Batched execution in xformers is slightly faster for small head counts.
+ # 8 heads, fp16 (100000 attention calls):
+ # xformers(dim3): 20.155955553054810s vram: 16483328b
+ # xformers(dim4): 17.558132648468018s vram: 16483328b
+ # 1 head, fp16 (100000 attention calls):
+ # xformers(dim3): 5.660739183425903s vram: 9516032b
+ # xformers(dim4): 6.114191055297852s vram: 9516032b
+ if multihead is None:
+ heads_count = query.shape[2] // head_dim
+ multihead = heads_count >= 4
+
+ if multihead:
+ # [B, S, H*He] -> [B, S, H, He]
+ query = query.view(batch_size, query.shape[1], -1, head_dim)
+ key = key.view(batch_size, key.shape[1], -1, head_dim)
+ value = value.view(batch_size, value.shape[1], -1, head_dim)
+
+ if attention_mask is not None:
+ # [B*H, 1, S_key] -> [B, H, 1, S_key]
+ attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
+ # expand our mask's singleton query dimension:
+ # [B, H, 1, S_key] ->
+ # [B, H, S_query, S_key]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ attention_mask = attention_mask.expand(-1, -1, query.shape[1], -1)
+ # xformers requires mask memory to be aligned to 8
+ attention_mask = self._align_attention_mask_memory(attention_mask)
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
+ )
+ # [B, S_query, H, He] -> [B, S_query, H*He]
+ hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,))
+ hidden_states = hidden_states.to(query.dtype)
+
+ else:
+            # Contiguous inputs are slightly faster in batched execution.
+ # [B, S, H*He] -> [B*H, S, He]
+ query = self._head_to_batch_dim(query, head_dim).contiguous()
+ key = self._head_to_batch_dim(key, head_dim).contiguous()
+ value = self._head_to_batch_dim(value, head_dim).contiguous()
+
+ if attention_mask is not None:
+ # expand our mask's singleton query dimension:
+ # [B*H, 1, S_key] ->
+ # [B*H, S_query, S_key]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ attention_mask = attention_mask.expand(-1, query.shape[1], -1)
+ # xformers requires mask memory to be aligned to 8
+ attention_mask = self._align_attention_mask_memory(attention_mask)
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ # [B*H, S_query, He] -> [B, S_query, H*He]
+ hidden_states = self._batch_to_head_dim(hidden_states, batch_size)
+
+ return hidden_states
+
+ def run_attention_sdp(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ multihead: Optional[bool] = None,
+ ) -> torch.Tensor:
+ batch_size = query.shape[0]
+ head_dim = attn.to_q.weight.shape[0] // attn.heads
+
+ if multihead is None:
+            # Multi-head (4-dim) execution is extremely slow on old CUDA GPUs:
+ # fp16 (100000 attention calls):
+ # torch-sdp(dim3): 30.07543110847473s vram: 23954432b
+ # torch-sdp(dim4): 299.3908393383026s vram: 13861888b
+ multihead = not self.is_old_cuda
+
+ if multihead:
+ # [B, S, H*He] -> [B, H, S, He]
+ query = query.view(batch_size, query.shape[1], -1, head_dim).transpose(1, 2)
+ key = key.view(batch_size, key.shape[1], -1, head_dim).transpose(1, 2)
+ value = value.view(batch_size, value.shape[1], -1, head_dim).transpose(1, 2)
+
+ if attention_mask is not None:
+ # [B*H, 1, S_key] -> [B, H, 1, S_key]
+ attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
+                # Aligning the mask memory to 8 decreases memory consumption and increases speed:
+ # fp16 (100000 attention calls):
+ # torch-sdp(dim4, mask): 6.1701478958129880s vram: 7864320b
+ # torch-sdp(dim4, aligned mask): 3.3127212524414062s vram: 2621440b
+ # fp32 (100000 attention calls):
+ # torch-sdp(dim4, mask): 23.0943229198455800s vram: 16121856b
+ # torch-sdp(dim4, aligned mask): 17.3104763031005860s vram: 5636096b
+ attention_mask = self._align_attention_mask_memory(attention_mask)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
+ )
+
+ # [B, H, S_query, He] -> [B, S_query, H, He]
+ hidden_states = hidden_states.transpose(1, 2)
+ # [B, S_query, H, He] -> [B, S_query, H*He]
+ hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,))
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ # [B, S, H*He] -> [B*H, S, He]
+ query = self._head_to_batch_dim(query, head_dim)
+ key = self._head_to_batch_dim(key, head_dim)
+ value = self._head_to_batch_dim(value, head_dim)
+
+            # The attention mask is already in shape [B*H, 1, S_key] or [B*H, S_query, S_key],
+            # and there is no noticeable change from memory alignment in the batched run:
+ # fp16 (100000 attention calls):
+ # torch-sdp(dim3, mask): 9.7391905784606930s vram: 12713984b
+ # torch-sdp(dim3, aligned mask): 10.0090200901031500s vram: 12713984b
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
+ )
+
+ hidden_states = hidden_states.to(query.dtype)
+ # [B*H, S_query, He] -> [B, S_query, H*He]
+ hidden_states = self._batch_to_head_dim(hidden_states, batch_size)
+
+ return hidden_states
+
+ def run_attention_sliced(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ slice_size: int,
+ ) -> torch.Tensor:
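+        # Process the attention heads in groups of `slice_size`: each group is reshaped back to
+        # [B, S, H_s*He], run through run_attention(no_sliced=True), and written into a
+        # preallocated output tensor. This trades some speed for lower peak memory use.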
+ batch_size = query.shape[0]
+ head_dim = attn.to_q.weight.shape[0] // attn.heads
+ heads_count = query.shape[2] // head_dim
+
+ # [B, S, H*He] -> [B, H, S, He]
+ query = query.reshape(query.shape[0], query.shape[1], -1, head_dim).transpose(1, 2)
+ key = key.reshape(key.shape[0], key.shape[1], -1, head_dim).transpose(1, 2)
+ value = value.reshape(value.shape[0], value.shape[1], -1, head_dim).transpose(1, 2)
+ # [B*H, S_query/1, S_key] -> [B, H, S_query/1, S_key]
+ if attention_mask is not None:
+ attention_mask = attention_mask.reshape(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
+
+ # [B, H, S_query, He]
+ hidden_states = torch.empty(query.shape, device=query.device, dtype=query.dtype)
+
+ for i in range((heads_count - 1) // slice_size + 1):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ # [B, H_s, S, He] -> [B, S, H_s*He]
+ query_slice = query[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, query.shape[2], -1)
+ key_slice = key[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, key.shape[2], -1)
+ value_slice = value[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, value.shape[2], -1)
+
+ # [B, H_s, S_query/1, S_key] -> [B*H_s, S_query/1, S_key]
+ attn_mask_slice = None
+ if attention_mask is not None:
+ attn_mask_slice = attention_mask[:, start_idx:end_idx, :, :].reshape((-1,) + attention_mask.shape[-2:])
+
+ # [B, S_query, H_s*He]
+ hidden_states_slice = self.run_attention(
+ attn=attn,
+ query=query_slice,
+ key=key_slice,
+ value=value_slice,
+ attention_mask=attn_mask_slice,
+ no_sliced=True,
+ )
+
+ # [B, S_query, H_s*He] -> [B, H_s, S_query, He]
+ hidden_states[:, start_idx:end_idx] = hidden_states_slice.reshape(
+ hidden_states_slice.shape[:-1] + (-1, head_dim)
+ ).transpose(1, 2)
+
+ # [B, H_s, S_query, He] -> [B, S_query, H_s*He]
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states.reshape(hidden_states.shape[:-2] + (-1,))
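As a quick aside (not part of the patch), a small sketch of the head/batch reshapes used by the processor above; the shapes are arbitrary:

```python
import torch

from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor

# Arbitrary shapes: batch 2, sequence length 77, 8 heads of size 40.
B, S, H, He = 2, 77, 8, 40
x = torch.randn(B, S, H * He)

batched = CustomAttnProcessor._head_to_batch_dim(x, head_dim=He)          # [B*H, S, He]
assert batched.shape == (B * H, S, He)

restored = CustomAttnProcessor._batch_to_head_dim(batched, batch_size=B)  # [B, S, H*He]
assert torch.equal(restored, x)  # the two reshapes are exact inverses
```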
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
deleted file mode 100644
index 1334313fe6e..00000000000
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ /dev/null
@@ -1,214 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Optional, cast
-
-import torch
-import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0
-
-from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
-
-
-@dataclass
-class IPAdapterAttentionWeights:
- ip_adapter_weights: IPAttentionProcessorWeights
- skip: bool
-
-
-class CustomAttnProcessor2_0(AttnProcessor2_0):
- """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
- This implementation is based on
- https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
- Supported custom features:
- - IP-Adapter
- - Regional prompt attention
- """
-
- def __init__(
- self,
- ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
- ):
- """Initialize a CustomAttnProcessor2_0.
- Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
- layer-specific are passed to __init__().
- Args:
- ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
- for the i'th IP-Adapter.
- """
- super().__init__()
- self._ip_adapter_attention_weights = ip_adapter_attention_weights
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- # For Regional Prompting:
- regional_prompt_data: Optional[RegionalPromptData] = None,
- percent_through: Optional[torch.Tensor] = None,
- # For IP-Adapter:
- regional_ip_data: Optional[RegionalIPData] = None,
- *args,
- **kwargs,
- ) -> torch.FloatTensor:
- """Apply attention.
- Args:
- regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
- apply regional prompt masking.
- regional_ip_data: The IP-Adapter data for the current batch.
- """
- # If true, we are doing cross-attention, if false we are doing self-attention.
- is_cross_attention = encoder_hidden_states is not None
-
- # Start unmodified block from AttnProcessor2_0.
- # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- # End unmodified block from AttnProcessor2_0.
-
- _, query_seq_len, _ = hidden_states.shape
- # Handle regional prompt attention masks.
- if regional_prompt_data is not None and is_cross_attention:
- assert percent_through is not None
- prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
- query_seq_len=query_seq_len, key_seq_len=sequence_length
- )
-
- if attention_mask is None:
- attention_mask = prompt_region_attention_mask
- else:
- attention_mask = prompt_region_attention_mask + attention_mask
-
- # Start unmodified block from AttnProcessor2_0.
- # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- # End unmodified block from AttnProcessor2_0.
-
- # Apply IP-Adapter conditioning.
- if is_cross_attention:
- if self._ip_adapter_attention_weights:
- assert regional_ip_data is not None
- ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
-
- assert (
- len(regional_ip_data.image_prompt_embeds)
- == len(self._ip_adapter_attention_weights)
- == len(regional_ip_data.scales)
- == ip_masks.shape[1]
- )
-
- for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
- ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
- ipa_scale = regional_ip_data.scales[ipa_index]
- ip_mask = ip_masks[0, ipa_index, ...]
-
- # The batch dimensions should match.
- assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
- # The token_len dimensions should match.
- assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
-
- ip_hidden_states = ipa_embed
-
- # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
-
- if not self._ip_adapter_attention_weights[ipa_index].skip:
- ip_key = ipa_weights.to_k_ip(ip_hidden_states)
- ip_value = ipa_weights.to_v_ip(ip_hidden_states)
-
- # Expected ip_key and ip_value shape:
- # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # Expected ip_key and ip_value shape:
- # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
-
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
-
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
- # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
- hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
- else:
- # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
- assert regional_ip_data is None
-
- # Start unmodified block from AttnProcessor2_0.
- # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- # End of unmodified block from AttnProcessor2_0
-
- # casting torch.Tensor to torch.FloatTensor to avoid type issues
- return cast(torch.FloatTensor, hidden_states)
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f418133e49f..d5f54ccff16 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -335,7 +335,6 @@ def _apply_standard_conditioning(
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype
)
- cross_attention_kwargs["percent_through"] = step_index / total_step_count
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
@@ -426,7 +425,6 @@ def _apply_standard_conditioning_sequentially(
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
- cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
@@ -474,7 +472,6 @@ def _apply_standard_conditioning_sequentially(
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
- cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index ac00a8e06ea..8ba8b3acf38 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -4,8 +4,8 @@
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
- CustomAttnProcessor2_0,
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import (
+ CustomAttnProcessor,
IPAdapterAttentionWeights,
)
@@ -16,7 +16,7 @@ class UNetIPAdapterData(TypedDict):
class UNetAttentionPatcher:
- """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
+ """A class for patching a UNet with CustomAttnProcessor attention layers."""
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data
@@ -27,11 +27,12 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel):
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
+
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor") or self._ip_adapters is None:
# "attn1" processors do not use IP-Adapters.
- attn_procs[name] = CustomAttnProcessor2_0()
+ attn_procs[name] = CustomAttnProcessor()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
@@ -48,12 +49,14 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel):
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
- attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
+ attn_procs[name] = CustomAttnProcessor(
+ ip_adapter_attention_weights=ip_adapter_attention_weights_collection,
+ )
return attn_procs
@contextmanager
- def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
+ def apply_custom_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
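For orientation, a hedged usage sketch mirroring how `diffusers_pipeline.py` now applies the patcher; `unet` stands in for an already-loaded `UNet2DConditionModel`:

```python
from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

patcher = UNetAttentionPatcher(ip_adapter_data=None)  # no IP-Adapters active
with patcher.apply_custom_attention(unet):
    # Inside the context, every attention layer is served by CustomAttnProcessor.
    assert all(isinstance(p, CustomAttnProcessor) for p in unet.attn_processors.values())
# On exit, the original attention processors are restored.
```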
diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py
index 4191db734f9..dddf7e555d4 100644
--- a/invokeai/backend/stable_diffusion/diffusion_backend.py
+++ b/invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -114,9 +114,6 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
sample=sample,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
- cross_attention_kwargs=dict( # noqa: C408
- percent_through=ctx.step_index / len(ctx.inputs.timesteps),
- ),
)
ctx.conditioning_mode = conditioning_mode
diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py
index a48a681af3f..8728779e00b 100644
--- a/invokeai/backend/stable_diffusion/extensions/controlnet.py
+++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py
@@ -112,8 +112,6 @@ def pre_unet_step(self, ctx: DenoiseContext):
ctx.unet_kwargs.mid_block_additional_residual += mid_sample
def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
- total_steps = len(ctx.inputs.timesteps)
-
model_input = ctx.latent_model_input
image_tensor = self._image_tensor
if conditioning_mode == ConditioningMode.Both:
@@ -124,9 +122,6 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
sample=model_input,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditioning
- cross_attention_kwargs=dict( # noqa: C408
- percent_through=ctx.step_index / total_steps,
- ),
)
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
index 6c07fc1c2c8..f78facaf581 100644
--- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
+++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
@@ -13,6 +13,7 @@
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.tiles.utils import Tile
@@ -63,132 +64,132 @@ def multi_diffusion_denoise(
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
assert isinstance(latents, torch.Tensor) # For static type checking.
- # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
- # cropping into regions.
- self._adjust_memory_efficient_attention(latents)
-
- # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
- # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
- # separate scheduler state for each region batch.
- # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
- # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
- # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
- # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
- # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion.
- region_batch_schedulers: list[SchedulerMixin] = [
- copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
- ]
-
- callback(
- PipelineIntermediateState(
- step=-1,
- order=self.scheduler.order,
- total_steps=len(timesteps),
- timestep=self.scheduler.config.num_train_timesteps,
- latents=latents,
- )
- )
-
- for i, t in enumerate(self.progress_bar(timesteps)):
- batched_t = t.expand(batch_size)
-
- merged_latents = torch.zeros_like(latents)
- merged_latents_weights = torch.zeros(
- (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
- )
- merged_pred_original: torch.Tensor | None = None
- for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
- # Switch to the scheduler for the region batch.
- self.scheduler = region_batch_schedulers[region_idx]
-
- # Crop the inputs to the region.
- region_latents = latents[
- :,
- :,
- region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
- region_conditioning.region.coords.left : region_conditioning.region.coords.right,
- ]
-
- # Run the denoising step on the region.
- step_output = self.step(
- t=batched_t,
- latents=region_latents,
- conditioning_data=region_conditioning.text_conditioning_data,
- step_index=i,
- total_step_count=len(timesteps),
- scheduler_step_kwargs=scheduler_step_kwargs,
- mask_guidance=None,
- mask=None,
- masked_latents=None,
- control_data=region_conditioning.control_data,
- )
-
- # Build a region_weight matrix that applies gradient blending to the edges of the region.
- region = region_conditioning.region
- _, _, region_height, region_width = step_output.prev_sample.shape
- region_weight = torch.ones(
- (1, 1, region_height, region_width),
- dtype=latents.dtype,
- device=latents.device,
- )
- if region.overlap.left > 0:
- left_grad = torch.linspace(
- 0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
- ).view((1, 1, 1, -1))
- region_weight[:, :, :, : region.overlap.left] *= left_grad
- if region.overlap.top > 0:
- top_grad = torch.linspace(
- 0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
- ).view((1, 1, -1, 1))
- region_weight[:, :, : region.overlap.top, :] *= top_grad
- if region.overlap.right > 0:
- right_grad = torch.linspace(
- 1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
- ).view((1, 1, 1, -1))
- region_weight[:, :, :, -region.overlap.right :] *= right_grad
- if region.overlap.bottom > 0:
- bottom_grad = torch.linspace(
- 1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
- ).view((1, 1, -1, 1))
- region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
-
- # Update the merged results with the region results.
- merged_latents[
- :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
- ] += step_output.prev_sample * region_weight
- merged_latents_weights[
- :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
- ] += region_weight
-
- pred_orig_sample = getattr(step_output, "pred_original_sample", None)
- if pred_orig_sample is not None:
- # If one region has pred_original_sample, then we can assume that all regions will have it, because
- # they all use the same scheduler.
- if merged_pred_original is None:
- merged_pred_original = torch.zeros_like(latents)
- merged_pred_original[
- :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
- ] += pred_orig_sample
-
- # Normalize the merged results.
- latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
- # For debugging, uncomment this line to visualize the region seams:
- # latents = torch.where(merged_latents_weights > 1, 0.0, latents)
- predicted_original = None
- if merged_pred_original is not None:
- predicted_original = torch.where(
- merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
- )
+ unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None)
+ with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model):
+ # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
+ # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
+ # separate scheduler state for each region batch.
+ # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
+ # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
+ # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
+ # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
+            # scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
+ region_batch_schedulers: list[SchedulerMixin] = [
+ copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
+ ]
callback(
PipelineIntermediateState(
- step=i,
+ step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
- timestep=int(t),
+ timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
- predicted_original=predicted_original,
)
)
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ batched_t = t.expand(batch_size)
+
+ merged_latents = torch.zeros_like(latents)
+ merged_latents_weights = torch.zeros(
+ (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
+ )
+ merged_pred_original: torch.Tensor | None = None
+ for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
+ # Switch to the scheduler for the region batch.
+ self.scheduler = region_batch_schedulers[region_idx]
+
+ # Crop the inputs to the region.
+ region_latents = latents[
+ :,
+ :,
+ region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
+ region_conditioning.region.coords.left : region_conditioning.region.coords.right,
+ ]
+
+ # Run the denoising step on the region.
+ step_output = self.step(
+ t=batched_t,
+ latents=region_latents,
+ conditioning_data=region_conditioning.text_conditioning_data,
+ step_index=i,
+ total_step_count=len(timesteps),
+ scheduler_step_kwargs=scheduler_step_kwargs,
+ mask_guidance=None,
+ mask=None,
+ masked_latents=None,
+ control_data=region_conditioning.control_data,
+ )
+
+ # Build a region_weight matrix that applies gradient blending to the edges of the region.
+ region = region_conditioning.region
+ _, _, region_height, region_width = step_output.prev_sample.shape
+ region_weight = torch.ones(
+ (1, 1, region_height, region_width),
+ dtype=latents.dtype,
+ device=latents.device,
+ )
+ if region.overlap.left > 0:
+ left_grad = torch.linspace(
+ 0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
+ ).view((1, 1, 1, -1))
+ region_weight[:, :, :, : region.overlap.left] *= left_grad
+ if region.overlap.top > 0:
+ top_grad = torch.linspace(
+ 0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
+ ).view((1, 1, -1, 1))
+ region_weight[:, :, : region.overlap.top, :] *= top_grad
+ if region.overlap.right > 0:
+ right_grad = torch.linspace(
+ 1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
+ ).view((1, 1, 1, -1))
+ region_weight[:, :, :, -region.overlap.right :] *= right_grad
+ if region.overlap.bottom > 0:
+ bottom_grad = torch.linspace(
+ 1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
+ ).view((1, 1, -1, 1))
+ region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
+
+ # Update the merged results with the region results.
+ merged_latents[
+ :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+ ] += step_output.prev_sample * region_weight
+ merged_latents_weights[
+ :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+ ] += region_weight
+
+ pred_orig_sample = getattr(step_output, "pred_original_sample", None)
+ if pred_orig_sample is not None:
+ # If one region has pred_original_sample, then we can assume that all regions will have it, because
+ # they all use the same scheduler.
+ if merged_pred_original is None:
+ merged_pred_original = torch.zeros_like(latents)
+ merged_pred_original[
+ :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+ ] += pred_orig_sample
+
+ # Normalize the merged results.
+ latents = torch.where(
+ merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents
+ )
+ # For debugging, uncomment this line to visualize the region seams:
+ # latents = torch.where(merged_latents_weights > 1, 0.0, latents)
+ predicted_original = None
+ if merged_pred_original is not None:
+ predicted_original = torch.where(
+ merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
+ )
+
+ callback(
+ PipelineIntermediateState(
+ step=i,
+ order=self.scheduler.order,
+ total_steps=len(timesteps),
+ timestep=int(t),
+ latents=latents,
+ predicted_original=predicted_original,
+ )
+ )
+
return latents
diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py
index 7e362fe9589..1991aed3ccc 100644
--- a/invokeai/backend/util/hotfixes.py
+++ b/invokeai/backend/util/hotfixes.py
@@ -802,8 +802,11 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0):
if xformers_available:
# TODO: remove when fixed in diffusers
+ from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor
+
_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
+    # TODO: remove? Or are there still calls to xformers that do not go through our attention processor?
def new_memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -815,16 +818,8 @@ def new_memory_efficient_attention(
op=None,
):
# diffusers not align shape to 8, which is required by xformers
- if attn_bias is not None and type(attn_bias) is torch.Tensor:
- orig_size = attn_bias.shape[-1]
- new_size = ((orig_size + 7) // 8) * 8
- aligned_attn_bias = torch.zeros(
- (attn_bias.shape[0], attn_bias.shape[1], new_size),
- device=attn_bias.device,
- dtype=attn_bias.dtype,
- )
- aligned_attn_bias[:, :, :orig_size] = attn_bias
- attn_bias = aligned_attn_bias[:, :, :orig_size]
+ if attn_bias is not None and isinstance(attn_bias, torch.Tensor):
+ attn_bias = CustomAttnProcessor._align_attention_mask_memory(attn_bias)
return _xformers_memory_efficient_attention(
query=query,