From cd2dccf6f294d070f750788b3e5ca296706b66e1 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 27 Jun 2024 16:18:21 +0300 Subject: [PATCH 01/25] Redo attention processor to support other attention types --- .../stable_diffusion/diffusers_pipeline.py | 4 +- .../diffusion/custom_atttention.py | 416 +++++++++++++----- .../diffusion/unet_attention_patcher.py | 43 +- 3 files changed, 347 insertions(+), 116 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index ee464f73e1f..fe5310c216d 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -330,8 +330,6 @@ def latents_from_embeddings( # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) - self._adjust_memory_efficient_attention(latents) - # Handle mask guidance (a.k.a. inpainting). mask_guidance: AddsMaskGuidance | None = None if mask is not None and not is_inpainting_model(self.unet): @@ -371,6 +369,8 @@ def latents_from_embeddings( ) unet_attention_patcher = UNetAttentionPatcher(ip_adapters) attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) + else: + self._adjust_memory_efficient_attention(latents) with attn_ctx: callback( diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 1334313fe6e..8a2c62354d0 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -1,14 +1,20 @@ from dataclasses import dataclass -from typing import List, Optional, cast +from typing import List, Optional, Union, Callable, cast import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention, AttnProcessor2_0 +from diffusers.utils.import_utils import is_xformers_available from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None @dataclass class IPAdapterAttentionWeights: @@ -16,10 +22,13 @@ class IPAdapterAttentionWeights: skip: bool -class CustomAttnProcessor2_0(AttnProcessor2_0): - """A custom implementation of AttnProcessor2_0 that supports additional Invoke features. +class CustomAttnProcessor: + """A custom implementation of attention processor that supports additional Invoke features. This implementation is based on - https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204 + AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732) + SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616) + XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113) + AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204) Supported custom features: - IP-Adapter - Regional prompt attention @@ -27,17 +36,48 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): def __init__( self, + attention_type: str, ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, + # xformers + attention_op: Optional[Callable] = None, + # sliced + slice_size: Optional[Union[str, int]] = None, # TODO: or "auto"? + ): - """Initialize a CustomAttnProcessor2_0. + """Initialize a CustomAttnProcessor. Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are layer-specific are passed to __init__(). Args: ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights for the i'th IP-Adapter. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. """ - super().__init__() + if attention_type not in ["normal", "sliced", "xformers", "torch-sdp"]: + raise Exception(f"Unknown attention type: {attention_type}") + + if attention_type == "xformers" and xformers is None: + raise ImportError("xformers attention requires xformers module to be installed.") + + if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + if attention_type == "sliced": + if slice_size is None: + raise Exception(f"slice_size required for sliced attention") + if slice_size not in ["auto", "max"] and not isinstance(slice_size, int): + raise Exception(f"Unsupported slice_size: {slice_size}") + self._ip_adapter_attention_weights = ip_adapter_attention_weights + self.attention_type = attention_type + self.attention_op = attention_op + self.slice_size = slice_size def __call__( self, @@ -53,19 +93,12 @@ def __call__( regional_ip_data: Optional[RegionalIPData] = None, *args, **kwargs, - ) -> torch.FloatTensor: - """Apply attention. - Args: - regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to - apply regional prompt masking. - regional_ip_data: The IP-Adapter data for the current batch. - """ + ) -> torch.Tensor: # If true, we are doing cross-attention, if false we are doing self-attention. is_cross_attention = encoder_hidden_states is not None - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv residual = hidden_states + if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -75,18 +108,134 @@ def __call__( batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( + batch_size, key_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # End unmodified block from AttnProcessor2_0. + query_length = hidden_states.shape[1] + + attention_mask = self.prepare_attention_mask( + attn=attn, + attention_mask=attention_mask, + batch_size=batch_size, + key_length=key_length, + query_length=query_length, + is_cross_attention=is_cross_attention, + regional_prompt_data=regional_prompt_data, + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + hidden_states = self.run_attention( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + if is_cross_attention: + hidden_states = self.run_ip_adapters( + attn=attn, + hidden_states=hidden_states, + regional_ip_data=regional_ip_data, + query_length=query_length, + query=query, + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + + def run_ip_adapters( + self, + attn: Attention, + hidden_states: torch.Tensor, + regional_ip_data: Optional[RegionalIPData], + query_length: int, # TODO: just read from query? + query: torch.Tensor, + ) -> torch.Tensor: + if self._ip_adapter_attention_weights is None: + # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. + assert regional_ip_data is None + return hidden_states + + ip_masks = regional_ip_data.get_masks(query_seq_len=query_length) + + assert ( + len(regional_ip_data.image_prompt_embeds) + == len(self._ip_adapter_attention_weights) + == len(regional_ip_data.scales) + == ip_masks.shape[1] + ) + + for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds): + # The batch dimensions should match. + #assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0] + # The token_len dimensions should match. + #assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1] + + if self._ip_adapter_attention_weights[ipa_index].skip: + continue + + ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights + ipa_scale = regional_ip_data.scales[ipa_index] + ip_mask = ip_masks[0, ipa_index, ...] + + # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) + ip_key = ipa_weights.to_k_ip(ip_hidden_states) + ip_value = ipa_weights.to_v_ip(ip_hidden_states) + + ip_hidden_states = self.run_attention( + attn=attn, + query=query, + key=ip_key, + value=ip_value, + attention_mask=None, + ) + + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask + + return hidden_states + + + def prepare_attention_mask( + self, + attn: Attention, + attention_mask: Optional[torch.Tensor], + batch_size: int, + key_length: int, + query_length: int, + is_cross_attention: bool, + regional_prompt_data: Optional[RegionalPromptData], + ) -> Optional[torch.Tensor]: - _, query_seq_len, _ = hidden_states.shape - # Handle regional prompt attention masks. if regional_prompt_data is not None and is_cross_attention: - assert percent_through is not None prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( - query_seq_len=query_seq_len, key_seq_len=sequence_length + query_seq_len=query_length, key_seq_len=key_length ) if attention_mask is None: @@ -94,32 +243,112 @@ def __call__( else: attention_mask = prompt_region_attention_mask + attention_mask - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) - query = attn.to_q(hidden_states) + if self.attention_type in ["normal", "sliced"]: + pass - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + elif self.attention_type == "xformers": + if attention_mask is not None: + # expand our mask's singleton query_length dimension: + # [batch*heads, 1, key_length] -> + # [batch*heads, query_length, key_length] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_length, key_length] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + attention_mask = attention_mask.expand(-1, query_length, -1) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + elif self.attention_type == "torch-sdp": + if attention_mask is not None: + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + else: + raise Exception(f"Unknown attention type: {self.attention_type}") + + return attention_mask + def run_attention( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.attention_type == "normal": + attn_call = self.run_attention_normal + elif self.attention_type == "xformers": + attn_call = self.run_attention_xformers + elif self.attention_type == "torch-sdp": + attn_call = self.run_attention_sdp + elif self.attention_type == "sliced": + attn_call = self.run_attention_sliced + else: + raise Exception(f"Unknown attention type: {self.attention_type}") + + return attn_call( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + def run_attention_normal( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + return hidden_states + + def run_attention_xformers( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + # attention_op + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + return hidden_states + + def run_attention_sdp( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + batch_size = key.shape[0] inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -131,84 +360,51 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # End unmodified block from AttnProcessor2_0. - - # Apply IP-Adapter conditioning. - if is_cross_attention: - if self._ip_adapter_attention_weights: - assert regional_ip_data is not None - ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) - - assert ( - len(regional_ip_data.image_prompt_embeds) - == len(self._ip_adapter_attention_weights) - == len(regional_ip_data.scales) - == ip_masks.shape[1] - ) - - for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds): - ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights - ipa_scale = regional_ip_data.scales[ipa_index] - ip_mask = ip_masks[0, ipa_index, ...] - - # The batch dimensions should match. - assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. - assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] - ip_hidden_states = ipa_embed + return hidden_states - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - - if not self._ip_adapter_attention_weights[ipa_index].skip: - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - # Expected ip_key and ip_value shape: - # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Expected ip_key and ip_value shape: - # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + def run_attention_sliced( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + # slice_size + if self.slice_size == "max": + slice_size = 1 + elif self.slice_size == "auto": + slice_size = max(1, attn.sliceable_head_dim // 2) + else: + slice_size = min(self.slice_size, attn.sliceable_head_dim) + + dim = query.shape[-1] + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + for i in range(batch_size_attention // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size - ip_hidden_states = ip_hidden_states.to(query.dtype) + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask - else: - # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. - assert regional_ip_data is None + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + hidden_states[start_idx:end_idx] = attn_slice - if attn.residual_connection: - hidden_states = hidden_states + residual + hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = hidden_states / attn.rescale_output_factor - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # End of unmodified block from AttnProcessor2_0 + return hidden_states - # casting torch.Tensor to torch.FloatTensor to avoid type issues - return cast(torch.FloatTensor, hidden_states) diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index ac00a8e06ea..c2f79607c02 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -2,10 +2,12 @@ from typing import List, Optional, TypedDict from diffusers.models import UNet2DConditionModel +from diffusers.utils.import_utils import is_xformers_available +from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.stable_diffusion.diffusion.custom_atttention import ( - CustomAttnProcessor2_0, + CustomAttnProcessor, IPAdapterAttentionWeights, ) @@ -16,22 +18,52 @@ class UNetIPAdapterData(TypedDict): class UNetAttentionPatcher: - """A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" + """A class for patching a UNet with CustomAttnProcessor attention layers.""" def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]): self._ip_adapters = ip_adapter_data + def get_attention_processor_kwargs(self, unet: UNet2DConditionModel): + config = get_config() + kwargs = dict() + + # TODO: + attention_type = config.attention_type + if attention_type == "auto": + if self.unet.device.type == "cuda": + if is_xformers_available(): + attention_type = "xformers" + elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): + attention_type = "torch-sdp" + else: + attention_type = "normal" + else: + attention_type = "sliced" + + kwargs["attention_type"] = attention_type + + if attention_type == "sliced": + slice_size = config.attention_slice_size + if slice_size == "balanced": + slice_size = "auto" + kwargs["slice_size"] = slice_size + + return kwargs + def _prepare_attention_processors(self, unet: UNet2DConditionModel): """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention weights into them (if IP-Adapters are being applied). Note that the `unet` param is only used to determine attention block dimensions and naming. """ # Construct a dict of attention processors based on the UNet's architecture. + + attn_processor_kwargs = self.get_attention_processor_kwargs(unet) + attn_procs = {} for idx, name in enumerate(unet.attn_processors.keys()): if name.endswith("attn1.processor") or self._ip_adapters is None: # "attn1" processors do not use IP-Adapters. - attn_procs[name] = CustomAttnProcessor2_0() + attn_procs[name] = CustomAttnProcessor(**attn_processor_kwargs) else: # Collect the weights from each IP Adapter for the idx'th attention processor. ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = [] @@ -48,7 +80,10 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel): ) ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights) - attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection) + attn_procs[name] = CustomAttnProcessor( + ip_adapter_attention_weights=ip_adapter_attention_weights_collection, + **attn_processor_kwargs, + ) return attn_procs From 9f40c2da8d81fdcb3af48e741a88159eee53508e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 28 Jul 2024 02:24:40 +0300 Subject: [PATCH 02/25] Remove xformers and normal attention --- invokeai/app/api/routers/app_info.py | 8 +- .../stable_diffusion/diffusers_pipeline.py | 36 +++--- .../diffusion/custom_atttention.py | 107 +++--------------- .../diffusion/unet_attention_patcher.py | 22 ++-- invokeai/backend/util/hotfixes.py | 46 -------- 5 files changed, 38 insertions(+), 181 deletions(-) diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index c3bc98a0387..22827ef7879 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -1,6 +1,6 @@ import typing from enum import Enum -from importlib.metadata import PackageNotFoundError, version +from importlib.metadata import version from pathlib import Path from platform import python_version from typing import Optional @@ -76,10 +76,6 @@ async def get_version() -> AppVersion: @app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions) async def get_app_deps() -> AppDependencyVersions: - try: - xformers = version("xformers") - except PackageNotFoundError: - xformers = None return AppDependencyVersions( accelerate=version("accelerate"), compel=version("compel"), @@ -93,7 +89,7 @@ async def get_app_deps() -> AppDependencyVersions: torch=torch.version.__version__, torchvision=version("torchvision"), transformers=version("transformers"), - xformers=xformers, + xformers=None, # TODO: ask frontend ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index fe5310c216d..33115cbeb4d 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -10,15 +10,16 @@ import psutil import torch import torchvision.transforms as T +from diffusers.models.attention_processor import AttnProcessor2_0 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -from diffusers.utils.import_utils import is_xformers_available from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent @@ -177,14 +178,13 @@ def __init__( self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) def _adjust_memory_efficient_attention(self, latents: torch.Tensor): - """ - if xformers is available, use it, otherwise use sliced attention. - """ config = get_config() - if config.attention_type == "xformers": - self.enable_xformers_memory_efficient_attention() - return - elif config.attention_type == "sliced": + attention_type = config.attention_type + if attention_type in ["normal", "xformers"]: + logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') + attention_type = "torch-sdp" + + if config.attention_type == "sliced": slice_size = config.attention_slice_size if slice_size == "auto": slice_size = auto_detect_slice_size(latents) @@ -192,24 +192,14 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor): slice_size = "auto" self.enable_attention_slicing(slice_size=slice_size) return - elif config.attention_type == "normal": - self.disable_attention_slicing() - return elif config.attention_type == "torch-sdp": - if hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return - else: - raise Exception("torch-sdp attention slicing not available") + self.unet.set_attn_processor(AttnProcessor2_0()) + return # the remainder if this code is called when attention_type=='auto' if self.unet.device.type == "cuda": - if is_xformers_available(): - self.enable_xformers_memory_efficient_attention() - return - elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return + self.unet.set_attn_processor(AttnProcessor2_0()) + return if self.unet.device.type == "cpu" or self.unet.device.type == "mps": mem_free = psutil.virtual_memory().free @@ -234,7 +224,7 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor): # diffusers recommends always enabling for mps self.enable_attention_slicing(slice_size="max") else: - self.disable_attention_slicing() + self.unet.set_attn_processor(AttnProcessor2_0()) def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): raise Exception("Should not be called") diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 8a2c62354d0..d5f78ded70f 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -1,20 +1,14 @@ from dataclasses import dataclass -from typing import List, Optional, Union, Callable, cast +from typing import List, Optional, Union import torch import torch.nn.functional as F -from diffusers.models.attention_processor import Attention, AttnProcessor2_0 -from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention_processor import Attention from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None @dataclass class IPAdapterAttentionWeights: @@ -25,9 +19,7 @@ class IPAdapterAttentionWeights: class CustomAttnProcessor: """A custom implementation of attention processor that supports additional Invoke features. This implementation is based on - AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732) SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616) - XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113) AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204) Supported custom features: - IP-Adapter @@ -38,11 +30,8 @@ def __init__( self, attention_type: str, ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, - # xformers - attention_op: Optional[Callable] = None, # sliced - slice_size: Optional[Union[str, int]] = None, # TODO: or "auto"? - + slice_size: Optional[Union[str, int]] = None, ): """Initialize a CustomAttnProcessor. Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are @@ -50,33 +39,24 @@ def __init__( Args: ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights for the i'th IP-Adapter. - attention_op (`Callable`, *optional*, defaults to `None`): - The base - [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to - use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best - operator. slice_size (`int`, *optional*): The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and `attention_head_dim` must be a multiple of the `slice_size`. """ - if attention_type not in ["normal", "sliced", "xformers", "torch-sdp"]: - raise Exception(f"Unknown attention type: {attention_type}") - - if attention_type == "xformers" and xformers is None: - raise ImportError("xformers attention requires xformers module to be installed.") + if attention_type not in ["sliced", "torch-sdp"]: + raise ValueError(f"Unknown attention type: {attention_type}") if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") if attention_type == "sliced": if slice_size is None: - raise Exception(f"slice_size required for sliced attention") + raise ValueError("slice_size required for sliced attention") if slice_size not in ["auto", "max"] and not isinstance(slice_size, int): - raise Exception(f"Unsupported slice_size: {slice_size}") + raise ValueError(f"Unsupported slice_size: {slice_size}") self._ip_adapter_attention_weights = ip_adapter_attention_weights self.attention_type = attention_type - self.attention_op = attention_op self.slice_size = slice_size def __call__( @@ -165,16 +145,14 @@ def __call__( hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - def run_ip_adapters( self, attn: Attention, hidden_states: torch.Tensor, regional_ip_data: Optional[RegionalIPData], - query_length: int, # TODO: just read from query? + query_length: int, # TODO: just read from query? query: torch.Tensor, ) -> torch.Tensor: if self._ip_adapter_attention_weights is None: @@ -193,9 +171,9 @@ def run_ip_adapters( for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds): # The batch dimensions should match. - #assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0] + # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0] # The token_len dimensions should match. - #assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1] + # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1] if self._ip_adapter_attention_weights[ipa_index].skip: continue @@ -221,7 +199,6 @@ def run_ip_adapters( return hidden_states - def prepare_attention_mask( self, attn: Attention, @@ -232,7 +209,6 @@ def prepare_attention_mask( is_cross_attention: bool, regional_prompt_data: Optional[RegionalPromptData], ) -> Optional[torch.Tensor]: - if regional_prompt_data is not None and is_cross_attention: prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( query_seq_len=query_length, key_seq_len=key_length @@ -243,22 +219,11 @@ def prepare_attention_mask( else: attention_mask = prompt_region_attention_mask + attention_mask - attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) - if self.attention_type in ["normal", "sliced"]: + if self.attention_type == "sliced": pass - elif self.attention_type == "xformers": - if attention_mask is not None: - # expand our mask's singleton query_length dimension: - # [batch*heads, 1, key_length] -> - # [batch*heads, query_length, key_length] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_length, key_length] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - attention_mask = attention_mask.expand(-1, query_length, -1) - elif self.attention_type == "torch-sdp": if attention_mask is not None: # scaled_dot_product_attention expects attention_mask shape to be @@ -278,11 +243,7 @@ def run_attention( value: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: - if self.attention_type == "normal": - attn_call = self.run_attention_normal - elif self.attention_type == "xformers": - attn_call = self.run_attention_xformers - elif self.attention_type == "torch-sdp": + if self.attention_type == "torch-sdp": attn_call = self.run_attention_sdp elif self.attention_type == "sliced": attn_call = self.run_attention_sliced @@ -297,45 +258,6 @@ def run_attention( attention_mask=attention_mask, ) - def run_attention_normal( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - return hidden_states - - def run_attention_xformers( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - # attention_op - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale - ) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) - - return hidden_states - def run_attention_sdp( self, attn: Attention, @@ -382,7 +304,7 @@ def run_attention_sliced( dim = query.shape[-1] query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) + key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) batch_size_attention, query_tokens, _ = query.shape @@ -399,12 +321,9 @@ def run_attention_sliced( attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice hidden_states = attn.batch_to_head_dim(hidden_states) - return hidden_states - diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index c2f79607c02..065b0563578 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -2,8 +2,8 @@ from typing import List, Optional, TypedDict from diffusers.models import UNet2DConditionModel -from diffusers.utils.import_utils import is_xformers_available +import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.stable_diffusion.diffusion.custom_atttention import ( @@ -25,23 +25,21 @@ def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]): def get_attention_processor_kwargs(self, unet: UNet2DConditionModel): config = get_config() - kwargs = dict() - - # TODO: + kwargs = {} + attention_type = config.attention_type + if attention_type in ["normal", "xformers"]: + logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') + attention_type = "torch-sdp" + if attention_type == "auto": - if self.unet.device.type == "cuda": - if is_xformers_available(): - attention_type = "xformers" - elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): - attention_type = "torch-sdp" - else: - attention_type = "normal" + if unet.device.type == "cuda": + attention_type = "torch-sdp" else: attention_type = "sliced" kwargs["attention_type"] = attention_type - + if attention_type == "sliced": slice_size = config.attention_slice_size if slice_size == "balanced": diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 7e362fe9589..a9ed2538825 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -791,49 +791,3 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0): diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward - -try: - import xformers - - xformers_available = True -except Exception: - xformers_available = False - - -if xformers_available: - # TODO: remove when fixed in diffusers - _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention - - def new_memory_efficient_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias=None, - p: float = 0.0, - scale: Optional[float] = None, - *, - op=None, - ): - # diffusers not align shape to 8, which is required by xformers - if attn_bias is not None and type(attn_bias) is torch.Tensor: - orig_size = attn_bias.shape[-1] - new_size = ((orig_size + 7) // 8) * 8 - aligned_attn_bias = torch.zeros( - (attn_bias.shape[0], attn_bias.shape[1], new_size), - device=attn_bias.device, - dtype=attn_bias.dtype, - ) - aligned_attn_bias[:, :, :orig_size] = attn_bias - attn_bias = aligned_attn_bias[:, :, :orig_size] - - return _xformers_memory_efficient_attention( - query=query, - key=key, - value=value, - attn_bias=attn_bias, - p=p, - scale=scale, - op=op, - ) - - xformers.ops.memory_efficient_attention = new_memory_efficient_attention From 1ab827619c10ea3014cdadc8d89e6766568f7d46 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 28 Jul 2024 02:26:32 +0300 Subject: [PATCH 03/25] Fix file name --- .../diffusion/{custom_atttention.py => custom_attention.py} | 0 .../stable_diffusion/diffusion/unet_attention_patcher.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename invokeai/backend/stable_diffusion/diffusion/{custom_atttention.py => custom_attention.py} (100%) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py similarity index 100% rename from invokeai/backend/stable_diffusion/diffusion/custom_atttention.py rename to invokeai/backend/stable_diffusion/diffusion/custom_attention.py diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index 065b0563578..3365fafef26 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -6,7 +6,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.stable_diffusion.diffusion.custom_atttention import ( +from invokeai.backend.stable_diffusion.diffusion.custom_attention import ( CustomAttnProcessor, IPAdapterAttentionWeights, ) From 89c37c3979338df6a78f4b2b78a912e3a1e1fba1 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 28 Jul 2024 02:56:49 +0300 Subject: [PATCH 04/25] Sync fixes --- invokeai/app/invocations/denoise_latents.py | 4 +- .../diffusion/custom_attention.py | 44 ++++++++++++------- .../diffusion/unet_attention_patcher.py | 32 +------------- 3 files changed, 31 insertions(+), 49 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 2787074265c..c446f714b60 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -55,7 +55,7 @@ TextConditioningData, TextConditioningRegions, ) -from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 +from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt @@ -810,7 +810,7 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput: seed=seed, scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, - attention_processor_cls=CustomAttnProcessor2_0, + attention_processor_cls=CustomAttnProcessor, ), unet=None, scheduler=scheduler, diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index d5f78ded70f..4ba53e7f784 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -1,13 +1,16 @@ from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention +import invokeai.backend.util.logging as logger +from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData +from invokeai.backend.util.devices import TorchDevice @dataclass @@ -28,10 +31,7 @@ class CustomAttnProcessor: def __init__( self, - attention_type: str, ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, - # sliced - slice_size: Optional[Union[str, int]] = None, ): """Initialize a CustomAttnProcessor. Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are @@ -39,25 +39,37 @@ def __init__( Args: ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights for the i'th IP-Adapter. - slice_size (`int`, *optional*): - The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and - `attention_head_dim` must be a multiple of the `slice_size`. """ - if attention_type not in ["sliced", "torch-sdp"]: - raise ValueError(f"Unknown attention type: {attention_type}") + + self._ip_adapter_attention_weights = ip_adapter_attention_weights + self.attention_type, self.slice_size = self._select_attention() + + def _select_attention(self): + config = get_config() + attention_type = config.attention_type + if attention_type in ["normal", "xformers"]: + logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') + attention_type = "torch-sdp" + + if attention_type == "auto": + exec_device = TorchDevice.choose_torch_device() + if exec_device.type == "mps": + attention_type = "sliced" + else: + attention_type = "torch-sdp" if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + slice_size = None if attention_type == "sliced": - if slice_size is None: - raise ValueError("slice_size required for sliced attention") - if slice_size not in ["auto", "max"] and not isinstance(slice_size, int): - raise ValueError(f"Unsupported slice_size: {slice_size}") + slice_size = config.attention_slice_size + if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int): + raise ValueError(f"Unsupported attention_slice_size: {slice_size}") + if slice_size == "balanced": + slice_size = "auto" - self._ip_adapter_attention_weights = ip_adapter_attention_weights - self.attention_type = attention_type - self.slice_size = slice_size + return attention_type, slice_size def __call__( self, diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index 3365fafef26..ce45ac157c2 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -3,8 +3,6 @@ from diffusers.models import UNet2DConditionModel -import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.stable_diffusion.diffusion.custom_attention import ( CustomAttnProcessor, @@ -23,31 +21,6 @@ class UNetAttentionPatcher: def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]): self._ip_adapters = ip_adapter_data - def get_attention_processor_kwargs(self, unet: UNet2DConditionModel): - config = get_config() - kwargs = {} - - attention_type = config.attention_type - if attention_type in ["normal", "xformers"]: - logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') - attention_type = "torch-sdp" - - if attention_type == "auto": - if unet.device.type == "cuda": - attention_type = "torch-sdp" - else: - attention_type = "sliced" - - kwargs["attention_type"] = attention_type - - if attention_type == "sliced": - slice_size = config.attention_slice_size - if slice_size == "balanced": - slice_size = "auto" - kwargs["slice_size"] = slice_size - - return kwargs - def _prepare_attention_processors(self, unet: UNet2DConditionModel): """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention weights into them (if IP-Adapters are being applied). @@ -55,13 +28,11 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel): """ # Construct a dict of attention processors based on the UNet's architecture. - attn_processor_kwargs = self.get_attention_processor_kwargs(unet) - attn_procs = {} for idx, name in enumerate(unet.attn_processors.keys()): if name.endswith("attn1.processor") or self._ip_adapters is None: # "attn1" processors do not use IP-Adapters. - attn_procs[name] = CustomAttnProcessor(**attn_processor_kwargs) + attn_procs[name] = CustomAttnProcessor() else: # Collect the weights from each IP Adapter for the idx'th attention processor. ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = [] @@ -80,7 +51,6 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel): attn_procs[name] = CustomAttnProcessor( ip_adapter_attention_weights=ip_adapter_attention_weights_collection, - **attn_processor_kwargs, ) return attn_procs From e9cc750f8b1b732ddd4277051d89bacfce73cf57 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 28 Jul 2024 23:13:33 +0300 Subject: [PATCH 05/25] Update app config --- .../app/services/config/config_default.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 6c39760bdc8..36cb56c9dbe 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -28,11 +28,11 @@ DEFAULT_VRAM_CACHE = 0.25 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] PRECISION = Literal["auto", "float16", "bfloat16", "float32"] -ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] +ATTENTION_TYPE = Literal["auto", "sliced", "torch-sdp"] ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] -CONFIG_SCHEMA_VERSION = "4.0.2" +CONFIG_SCHEMA_VERSION = "4.0.3" def get_default_ram_cache_size() -> float: @@ -107,7 +107,7 @@ class InvokeAIAppConfig(BaseSettings): device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. - attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp` + attention_type: Attention type.
Valid values: `auto`, `sliced`, `torch-sdp` attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. @@ -433,6 +433,24 @@ def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str return parsed_config_dict +def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate v4.0.2 config dictionary to a v4.0.3 config dictionary. + + Args: + config_dict: A dictionary of settings from a v4.0.2 config file. + + Returns: + An config dict with the settings migrated to v4.0.3. + """ + parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict) + # normal and xformers attentions removed in 4.0.3 + attention_type = parsed_config_dict.get("attention_type", None) + if attention_type in ["normal", "xformers"]: + parsed_config_dict["attention_type"] = "torch-sdp" + parsed_config_dict["schema_version"] = "4.0.3" + return parsed_config_dict + + def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """Load and migrate a config file to the latest version. @@ -458,6 +476,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: if loaded_config_dict["schema_version"] == "4.0.1": migrated = True loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict) + if loaded_config_dict["schema_version"] == "4.0.2": + migrated = True + loaded_config_dict = migrate_v4_0_2_to_4_0_3_config_dict(loaded_config_dict) if migrated: shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) From 4b6d61377af4142cb58ef03d0b420220704a34cf Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 29 Jul 2024 13:47:51 +0300 Subject: [PATCH 06/25] Remove remaining references to xformers --- docker/Dockerfile | 7 +------ flake.nix | 2 +- installer/lib/installer.py | 4 ++-- invokeai/app/api/routers/app_info.py | 2 -- pyproject.toml | 6 ------ 5 files changed, 4 insertions(+), 17 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 7ea078af0d9..24f2ff9e2f7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -43,12 +43,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \ fi &&\ - # xformers + triton fails to install on arm64 - if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ - pip install $extra_index_url_arg -e ".[xformers]"; \ - else \ - pip install $extra_index_url_arg -e "."; \ - fi + pip install $extra_index_url_arg -e "."; # #### Build the Web UI ------------------------------------ diff --git a/flake.nix b/flake.nix index 3ccc6658121..bf8d2ae9466 100644 --- a/flake.nix +++ b/flake.nix @@ -84,7 +84,7 @@ in { devShells.${system} = rec { - develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; }; + develop = mkShell { dir = "venv"; install = "-e '.' --extra-index-url https://download.pytorch.org/whl/cu118"; }; default = develop; }; }; diff --git a/installer/lib/installer.py b/installer/lib/installer.py index 11823b413e0..504c801df6d 100644 --- a/installer/lib/installer.py +++ b/installer/lib/installer.py @@ -418,11 +418,11 @@ def get_torch_source() -> Tuple[str | None, str | None]: url = "https://download.pytorch.org/whl/cpu" elif device.value == "cuda": # CUDA uses the default PyPi index - optional_modules = "[xformers,onnx-cuda]" + optional_modules = "[onnx-cuda]" elif OS == "Windows": if device.value == "cuda": url = "https://download.pytorch.org/whl/cu121" - optional_modules = "[xformers,onnx-cuda]" + optional_modules = "[onnx-cuda]" elif device.value == "cpu": # CPU uses the default PyPi index, no optional modules pass diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index e2556ecaa7c..9f87e2cdec0 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -56,7 +56,6 @@ class AppDependencyVersions(BaseModel): torch: str = Field(description="PyTorch version") torchvision: str = Field(description="PyTorch Vision version") transformers: str = Field(description="transformers version") - xformers: Optional[str] = Field(description="xformers version") class AppConfig(BaseModel): @@ -88,7 +87,6 @@ async def get_app_deps() -> AppDependencyVersions: torch=torch.version.__version__, torchvision=version("torchvision"), transformers=version("transformers"), - xformers=None, # TODO: ask frontend ) diff --git a/pyproject.toml b/pyproject.toml index 9acaa17e44d..d1be2215f0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,12 +95,6 @@ dependencies = [ ] [project.optional-dependencies] -"xformers" = [ - # Core generation dependencies, pinned for reproducible builds. - "xformers==0.0.25post1; sys_platform!='darwin'", - # Auxiliary dependencies, pinned only if necessary. - "triton; sys_platform=='linux'", -] "onnx" = ["onnxruntime"] "onnx-cuda" = ["onnxruntime-gpu"] "onnx-directml" = ["onnxruntime-directml"] From d5fa938eb005a11ff9ff03588ae6ae27c2483bdd Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 30 Jul 2024 04:09:02 +0300 Subject: [PATCH 07/25] Run api regen --- invokeai/frontend/web/src/services/api/schema.ts | 5 ----- 1 file changed, 5 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 59f9897f740..5334db35e1e 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -725,11 +725,6 @@ export type components = { * @description transformers version */ transformers: string; - /** - * Xformers - * @description xformers version - */ - xformers: string | null; }; /** * AppVersion From 5a9cc04e79623f95ee1f07d330879636199ea097 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 2 Aug 2024 00:46:17 +0300 Subject: [PATCH 08/25] Small rearrangement --- .../diffusion/custom_attention.py | 62 +++++-------------- 1 file changed, 17 insertions(+), 45 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 4ba53e7f784..3884fe06ee6 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -105,15 +105,18 @@ def __call__( ) query_length = hidden_states.shape[1] - attention_mask = self.prepare_attention_mask( - attn=attn, - attention_mask=attention_mask, - batch_size=batch_size, - key_length=key_length, - query_length=query_length, - is_cross_attention=is_cross_attention, - regional_prompt_data=regional_prompt_data, - ) + # Regional Prompt Attention Mask + if regional_prompt_data is not None and is_cross_attention: + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_length, key_seq_len=key_length + ) + + if attention_mask is None: + attention_mask = prompt_region_attention_mask + else: + attention_mask = prompt_region_attention_mask + attention_mask + + attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) @@ -211,42 +214,6 @@ def run_ip_adapters( return hidden_states - def prepare_attention_mask( - self, - attn: Attention, - attention_mask: Optional[torch.Tensor], - batch_size: int, - key_length: int, - query_length: int, - is_cross_attention: bool, - regional_prompt_data: Optional[RegionalPromptData], - ) -> Optional[torch.Tensor]: - if regional_prompt_data is not None and is_cross_attention: - prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( - query_seq_len=query_length, key_seq_len=key_length - ) - - if attention_mask is None: - attention_mask = prompt_region_attention_mask - else: - attention_mask = prompt_region_attention_mask + attention_mask - - attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) - - if self.attention_type == "sliced": - pass - - elif self.attention_type == "torch-sdp": - if attention_mask is not None: - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - else: - raise Exception(f"Unknown attention type: {self.attention_type}") - - return attention_mask - def run_attention( self, attn: Attention, @@ -286,6 +253,11 @@ def run_attention_sdp( key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attention_mask is not None: + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( From be84746e6783a9259e8cbde96d78fbb8514058bb Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 2 Aug 2024 00:57:19 +0300 Subject: [PATCH 09/25] Add assert check Co-Authored-By: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> --- invokeai/backend/stable_diffusion/diffusion/custom_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 3884fe06ee6..ad2f68627a8 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -175,6 +175,7 @@ def run_ip_adapters( assert regional_ip_data is None return hidden_states + assert regional_ip_data is not None ip_masks = regional_ip_data.get_masks(query_seq_len=query_length) assert ( From bf2f798341a40ff58b7df9dc10463237593d2d64 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 3 Aug 2024 01:27:01 +0300 Subject: [PATCH 10/25] Fix bad generation on slice_size not factor of heads count --- .gitignore | 8 + invokeai.yaml.bak | 6 + .../diffusion/custom_attention.py | 2 +- .../diffusion/custom_atttention.py | 383 ++++++++++++++++++ invokeai/frontend/web/scripts/typegen.js | 2 +- invokeai/frontend/web/vite.config.mts | 6 +- 6 files changed, 402 insertions(+), 5 deletions(-) create mode 100644 invokeai.yaml.bak create mode 100644 invokeai/backend/stable_diffusion/diffusion/custom_atttention.py diff --git a/.gitignore b/.gitignore index 29d27d78ed5..a9739a7294b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,13 @@ .idea/ +models/ +nodes/ +configs/ +databases/ +invokeai.yaml +invokeai.example.yaml +outputs/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/invokeai.yaml.bak b/invokeai.yaml.bak new file mode 100644 index 00000000000..b348590cae6 --- /dev/null +++ b/invokeai.yaml.bak @@ -0,0 +1,6 @@ +# Internal metadata - do not edit: +schema_version: 4.0.2 + +# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/: +host: 0.0.0.0 +attention_type: torch-sdp diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index ad2f68627a8..743a1d5658c 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -297,7 +297,7 @@ def run_attention_sliced( (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - for i in range(batch_size_attention // slice_size): + for i in range((batch_size_attention - 1) // slice_size + 1): start_idx = i * slice_size end_idx = (i + 1) * slice_size diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py new file mode 100644 index 00000000000..c5a48847f8d --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -0,0 +1,383 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention + +import invokeai.backend.util.logging as logger +from invokeai.app.services.config.config_default import get_config +from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights +from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData +from invokeai.backend.util.devices import TorchDevice + + +@dataclass +class IPAdapterAttentionWeights: + ip_adapter_weights: IPAttentionProcessorWeights + skip: bool + + +class CustomAttnProcessor2_0: + """A custom implementation of attention processor that supports additional Invoke features. + This implementation is based on + SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616) + AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204) + Supported custom features: + - IP-Adapter + - Regional prompt attention + """ + + def __init__( + self, + ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, + ): + """Initialize a CustomAttnProcessor. + Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are + layer-specific are passed to __init__(). + Args: + ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights + for the i'th IP-Adapter. + """ + + self._ip_adapter_attention_weights = ip_adapter_attention_weights + self.attention_type, self.slice_size = self._select_attention() + + def _select_attention(self): + config = get_config() + attention_type = config.attention_type + if attention_type in ["normal", "xformers"]: + logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') + attention_type = "torch-sdp" + + if attention_type == "auto": + exec_device = TorchDevice.choose_torch_device() + if exec_device.type == "mps": + attention_type = "sliced" + else: + attention_type = "torch-sdp" + + if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + slice_size = None + if attention_type == "sliced": + slice_size = config.attention_slice_size + if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int): + raise ValueError(f"Unsupported attention_slice_size: {slice_size}") + if slice_size == "balanced": + slice_size = "auto" + + return attention_type, slice_size + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + # For Regional Prompting: + regional_prompt_data: Optional[RegionalPromptData] = None, + percent_through: Optional[torch.Tensor] = None, + # For IP-Adapter: + regional_ip_data: Optional[RegionalIPData] = None, + *args, + **kwargs, + ) -> torch.Tensor: + # If true, we are doing cross-attention, if false we are doing self-attention. + is_cross_attention = encoder_hidden_states is not None + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, key_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + query_length = hidden_states.shape[1] + + # Regional Prompt Attention Mask + if regional_prompt_data is not None and is_cross_attention: + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_length, key_seq_len=key_length + ) + + if attention_mask is None: + attention_mask = prompt_region_attention_mask + else: + attention_mask = prompt_region_attention_mask + attention_mask + + attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + hidden_states = self.run_attention( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + if is_cross_attention: + hidden_states = self.run_ip_adapters( + attn=attn, + hidden_states=hidden_states, + regional_ip_data=regional_ip_data, + query_length=query_length, + query=query, + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + def run_ip_adapters( + self, + attn: Attention, + hidden_states: torch.Tensor, + regional_ip_data: Optional[RegionalIPData], + query_length: int, # TODO: just read from query? + query: torch.Tensor, + ) -> torch.Tensor: + if self._ip_adapter_attention_weights is None: + # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. + assert regional_ip_data is None + return hidden_states + + assert regional_ip_data is not None + ip_masks = regional_ip_data.get_masks(query_seq_len=query_length) + + assert ( + len(regional_ip_data.image_prompt_embeds) + == len(self._ip_adapter_attention_weights) + == len(regional_ip_data.scales) + == ip_masks.shape[1] + ) + + for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds): + # The batch dimensions should match. + # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0] + # The token_len dimensions should match. + # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1] + + if self._ip_adapter_attention_weights[ipa_index].skip: + continue + + ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights + ipa_scale = regional_ip_data.scales[ipa_index] + ip_mask = ip_masks[0, ipa_index, ...] + + # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) + ip_key = ipa_weights.to_k_ip(ip_hidden_states) + ip_value = ipa_weights.to_v_ip(ip_hidden_states) + + ip_hidden_states = self.run_attention( + attn=attn, + query=query, + key=ip_key, + value=ip_value, + attention_mask=None, + ) + + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask + + return hidden_states + + def run_attention( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.attention_type == "torch-sdp": + attn_call = self.run_attention_sdp + elif self.attention_type == "sliced": + attn_call = self.run_attention_sliced + else: + raise Exception(f"Unknown attention type: {self.attention_type}") + + return attn_call( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + def run_attention_sdp( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + batch_size = key.shape[0] + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attention_mask is not None: + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + return hidden_states + + def run_attention_sliced( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + if True: + func = self._run_attention_sliced_norm + else: + func = self._run_attention_sliced_sdp + + return func( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + ) + + def _run_attention_sliced_norm( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + # slice_size + if self.slice_size == "max": + slice_size = 1 + elif self.slice_size == "auto": + slice_size = max(1, attn.sliceable_head_dim // 2) + else: + slice_size = min(self.slice_size, attn.sliceable_head_dim) + + dim = query.shape[-1] + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(batch_size_attention // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + return hidden_states + + + def _run_attention_sliced_sdp( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + # slice_size + if self.slice_size == "max": + slice_size = 1 + elif self.slice_size == "auto": + slice_size = max(1, attn.sliceable_head_dim // 2) + else: + slice_size = min(self.slice_size, attn.sliceable_head_dim) + + dim = query.shape[-1] + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(batch_size_attention // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + value_slice = value[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + attn_slice = F.scaled_dot_product_attention( + query_slice, key_slice, value_slice, attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False + ) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + return hidden_states diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js index fa2d791350d..435c82a1abb 100644 --- a/invokeai/frontend/web/scripts/typegen.js +++ b/invokeai/frontend/web/scripts/typegen.js @@ -3,7 +3,7 @@ import fs from 'node:fs'; import openapiTS from 'openapi-typescript'; -const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json'; +const OPENAPI_URL = 'http://192.168.5.199:9090/openapi.json'; const OUTPUT_FILE = 'src/services/api/schema.ts'; async function generateTypes(schema) { diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index a40c515465c..59a3cf1901f 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -71,18 +71,18 @@ export default defineConfig(({ mode }) => { proxy: { // Proxy socket.io to the nodes socketio server '/ws/socket.io': { - target: 'ws://127.0.0.1:9090', + target: 'ws://192.168.5.199:9090', ws: true, }, // Proxy openapi schema definiton '/openapi.json': { - target: 'http://127.0.0.1:9090/openapi.json', + target: 'http://192.168.5.199:9090/openapi.json', rewrite: (path) => path.replace(/^\/openapi.json/, ''), changeOrigin: true, }, // proxy nodes api '/api/': { - target: 'http://127.0.0.1:9090/api/', + target: 'http://192.168.5.199:9090/api/', rewrite: (path) => path.replace(/^\/api/, ''), changeOrigin: true, }, From 91cc89a75a6d70414e1a3414a6d32560b37b534b Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 3 Aug 2024 01:27:40 +0300 Subject: [PATCH 11/25] Use invoke slice_size values, to have less confusion --- .../backend/stable_diffusion/diffusion/custom_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 743a1d5658c..bf5e5fa5949 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -66,8 +66,8 @@ def _select_attention(self): slice_size = config.attention_slice_size if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int): raise ValueError(f"Unsupported attention_slice_size: {slice_size}") - if slice_size == "balanced": - slice_size = "auto" + if slice_size == "auto": + slice_size = "balanced" return attention_type, slice_size @@ -281,7 +281,7 @@ def run_attention_sliced( # slice_size if self.slice_size == "max": slice_size = 1 - elif self.slice_size == "auto": + elif self.slice_size == "balanced": slice_size = max(1, attn.sliceable_head_dim // 2) else: slice_size = min(self.slice_size, attn.sliceable_head_dim) From 719daebd187a4671d7d54d2b7b488cd858180716 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 3 Aug 2024 01:28:24 +0300 Subject: [PATCH 12/25] Add torch-sdp scale parameter support(added in torch 2.1) --- .../diffusion/custom_attention.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index bf5e5fa5949..eed3bb3fb9a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -44,6 +44,15 @@ def __init__( self._ip_adapter_attention_weights = ip_adapter_attention_weights self.attention_type, self.slice_size = self._select_attention() + # inspect didn't work because it's native function + # In 2.0 torch there no scale argument in sdp, it's added in 2.1 + # Probably can selected based on torch version instead + try: + F.scaled_dot_product_attention(torch.zeros(1,1), torch.zeros(1,1), torch.zeros(1,1), scale=0.5) + self.scaled_sdp = True + except: + self.scaled_sdp = False + def _select_attention(self): config = get_config() attention_type = config.attention_type @@ -260,9 +269,11 @@ def run_attention_sdp( attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 + scale_kwargs = {} + if self.scaled_sdp: + scale_kwargs["scale"] = attn.scale hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, **scale_kwargs ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) From a16fa31479ee1e0a65f43e033aec65f36209ff08 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 3 Aug 2024 01:28:48 +0300 Subject: [PATCH 13/25] Test implementation of sliced attention using torch-sdp --- .../diffusion/custom_attention.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index eed3bb3fb9a..460910a1078 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -314,12 +314,29 @@ def run_attention_sliced( query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] + value_slice = value[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice + # TODO: compare speed/memory on mps + # cuda, sd1, 31 step, 1024x1024 + # denoise_latents 1 19.667s 3.418G + # denoise_latents 1 11.601s 2.133G (sdp) + # cpu, sd1, 10 steps, 512x512 + # denoise_latents 1 43.859s 0.000G + # denoise_latents 1 40.696s 0.000G (sdp) + if False: + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx]) + else: + if attn_mask_slice is not None: + attn_mask_slice = attn_mask_slice.unsqueeze(0) + + scale_kwargs = {} + if self.scaled_sdp: + scale_kwargs["scale"] = attn.scale + hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention( + query_slice.unsqueeze(0), key_slice.unsqueeze(0), value_slice.unsqueeze(0), attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False, **scale_kwargs + ).squeeze(0) hidden_states = attn.batch_to_head_dim(hidden_states) return hidden_states From 7ffceaa7ff27a0b24bb6f0e8c2970d51609eebe4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 3 Aug 2024 02:33:30 +0300 Subject: [PATCH 14/25] Fix slice_size handling --- invokeai/backend/stable_diffusion/diffusers_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index ace300ee03d..42c2fdba397 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -179,7 +179,7 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor): slice_size = config.attention_slice_size if slice_size == "auto": slice_size = auto_detect_slice_size(latents) - elif slice_size == "balanced": + if slice_size == "balanced": slice_size = "auto" self.enable_attention_slicing(slice_size=slice_size) return From c7e71038dd6836fd4c8376cbc639a99c00650a45 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 3 Aug 2024 17:04:00 +0300 Subject: [PATCH 15/25] Revert "Fix bad generation on slice_size not factor of heads count" This reverts commit bf2f798341a40ff58b7df9dc10463237593d2d64. --- .gitignore | 8 - invokeai.yaml.bak | 6 - .../diffusion/custom_atttention.py | 383 ------------------ invokeai/frontend/web/scripts/typegen.js | 2 +- invokeai/frontend/web/vite.config.mts | 6 +- 5 files changed, 4 insertions(+), 401 deletions(-) delete mode 100644 invokeai.yaml.bak delete mode 100644 invokeai/backend/stable_diffusion/diffusion/custom_atttention.py diff --git a/.gitignore b/.gitignore index a9739a7294b..29d27d78ed5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,5 @@ .idea/ -models/ -nodes/ -configs/ -databases/ -invokeai.yaml -invokeai.example.yaml -outputs/ - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/invokeai.yaml.bak b/invokeai.yaml.bak deleted file mode 100644 index b348590cae6..00000000000 --- a/invokeai.yaml.bak +++ /dev/null @@ -1,6 +0,0 @@ -# Internal metadata - do not edit: -schema_version: 4.0.2 - -# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/: -host: 0.0.0.0 -attention_type: torch-sdp diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py deleted file mode 100644 index c5a48847f8d..00000000000 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ /dev/null @@ -1,383 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import torch -import torch.nn.functional as F -from diffusers.models.attention_processor import Attention - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights -from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData -from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData -from invokeai.backend.util.devices import TorchDevice - - -@dataclass -class IPAdapterAttentionWeights: - ip_adapter_weights: IPAttentionProcessorWeights - skip: bool - - -class CustomAttnProcessor2_0: - """A custom implementation of attention processor that supports additional Invoke features. - This implementation is based on - SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616) - AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204) - Supported custom features: - - IP-Adapter - - Regional prompt attention - """ - - def __init__( - self, - ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, - ): - """Initialize a CustomAttnProcessor. - Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are - layer-specific are passed to __init__(). - Args: - ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights - for the i'th IP-Adapter. - """ - - self._ip_adapter_attention_weights = ip_adapter_attention_weights - self.attention_type, self.slice_size = self._select_attention() - - def _select_attention(self): - config = get_config() - attention_type = config.attention_type - if attention_type in ["normal", "xformers"]: - logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') - attention_type = "torch-sdp" - - if attention_type == "auto": - exec_device = TorchDevice.choose_torch_device() - if exec_device.type == "mps": - attention_type = "sliced" - else: - attention_type = "torch-sdp" - - if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - slice_size = None - if attention_type == "sliced": - slice_size = config.attention_slice_size - if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int): - raise ValueError(f"Unsupported attention_slice_size: {slice_size}") - if slice_size == "balanced": - slice_size = "auto" - - return attention_type, slice_size - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - # For Regional Prompting: - regional_prompt_data: Optional[RegionalPromptData] = None, - percent_through: Optional[torch.Tensor] = None, - # For IP-Adapter: - regional_ip_data: Optional[RegionalIPData] = None, - *args, - **kwargs, - ) -> torch.Tensor: - # If true, we are doing cross-attention, if false we are doing self-attention. - is_cross_attention = encoder_hidden_states is not None - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, key_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - query_length = hidden_states.shape[1] - - # Regional Prompt Attention Mask - if regional_prompt_data is not None and is_cross_attention: - prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( - query_seq_len=query_length, key_seq_len=key_length - ) - - if attention_mask is None: - attention_mask = prompt_region_attention_mask - else: - attention_mask = prompt_region_attention_mask + attention_mask - - attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - hidden_states = self.run_attention( - attn=attn, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - ) - - if is_cross_attention: - hidden_states = self.run_ip_adapters( - attn=attn, - hidden_states=hidden_states, - regional_ip_data=regional_ip_data, - query_length=query_length, - query=query, - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - - def run_ip_adapters( - self, - attn: Attention, - hidden_states: torch.Tensor, - regional_ip_data: Optional[RegionalIPData], - query_length: int, # TODO: just read from query? - query: torch.Tensor, - ) -> torch.Tensor: - if self._ip_adapter_attention_weights is None: - # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. - assert regional_ip_data is None - return hidden_states - - assert regional_ip_data is not None - ip_masks = regional_ip_data.get_masks(query_seq_len=query_length) - - assert ( - len(regional_ip_data.image_prompt_embeds) - == len(self._ip_adapter_attention_weights) - == len(regional_ip_data.scales) - == ip_masks.shape[1] - ) - - for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds): - # The batch dimensions should match. - # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. - # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1] - - if self._ip_adapter_attention_weights[ipa_index].skip: - continue - - ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights - ipa_scale = regional_ip_data.scales[ipa_index] - ip_mask = ip_masks[0, ipa_index, ...] - - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - ip_hidden_states = self.run_attention( - attn=attn, - query=query, - key=ip_key, - value=ip_value, - attention_mask=None, - ) - - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask - - return hidden_states - - def run_attention( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - if self.attention_type == "torch-sdp": - attn_call = self.run_attention_sdp - elif self.attention_type == "sliced": - attn_call = self.run_attention_sliced - else: - raise Exception(f"Unknown attention type: {self.attention_type}") - - return attn_call( - attn=attn, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - ) - - def run_attention_sdp( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - batch_size = key.shape[0] - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attention_mask is not None: - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - return hidden_states - - def run_attention_sliced( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - if True: - func = self._run_attention_sliced_norm - else: - func = self._run_attention_sliced_sdp - - return func( - attn=attn, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - ) - - def _run_attention_sliced_norm( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - # slice_size - if self.slice_size == "max": - slice_size = 1 - elif self.slice_size == "auto": - slice_size = max(1, attn.sliceable_head_dim // 2) - else: - slice_size = min(self.slice_size, attn.sliceable_head_dim) - - dim = query.shape[-1] - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, _ = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - for i in range(batch_size_attention // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - return hidden_states - - - def _run_attention_sliced_sdp( - self, - attn: Attention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - ) -> torch.Tensor: - # slice_size - if self.slice_size == "max": - slice_size = 1 - elif self.slice_size == "auto": - slice_size = max(1, attn.sliceable_head_dim // 2) - else: - slice_size = min(self.slice_size, attn.sliceable_head_dim) - - dim = query.shape[-1] - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, _ = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - for i in range(batch_size_attention // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - value_slice = value[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - attn_slice = F.scaled_dot_product_attention( - query_slice, key_slice, value_slice, attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False - ) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - return hidden_states diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js index 435c82a1abb..fa2d791350d 100644 --- a/invokeai/frontend/web/scripts/typegen.js +++ b/invokeai/frontend/web/scripts/typegen.js @@ -3,7 +3,7 @@ import fs from 'node:fs'; import openapiTS from 'openapi-typescript'; -const OPENAPI_URL = 'http://192.168.5.199:9090/openapi.json'; +const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json'; const OUTPUT_FILE = 'src/services/api/schema.ts'; async function generateTypes(schema) { diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index 59a3cf1901f..a40c515465c 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -71,18 +71,18 @@ export default defineConfig(({ mode }) => { proxy: { // Proxy socket.io to the nodes socketio server '/ws/socket.io': { - target: 'ws://192.168.5.199:9090', + target: 'ws://127.0.0.1:9090', ws: true, }, // Proxy openapi schema definiton '/openapi.json': { - target: 'http://192.168.5.199:9090/openapi.json', + target: 'http://127.0.0.1:9090/openapi.json', rewrite: (path) => path.replace(/^\/openapi.json/, ''), changeOrigin: true, }, // proxy nodes api '/api/': { - target: 'http://192.168.5.199:9090/api/', + target: 'http://127.0.0.1:9090/api/', rewrite: (path) => path.replace(/^\/api/, ''), changeOrigin: true, }, From 302dc9faeec3116987d626418a50d2bb3f53868e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 4 Aug 2024 02:05:32 +0300 Subject: [PATCH 16/25] Return normal attention, change slicing logic, remove old attention code --- .../app/services/config/config_default.py | 33 ++- .../stable_diffusion/diffusers_pipeline.py | 81 +----- .../diffusion/custom_attention.py | 152 +++++++---- .../diffusion/unet_attention_patcher.py | 2 +- .../multi_diffusion_pipeline.py | 239 +++++++++--------- 5 files changed, 253 insertions(+), 254 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 36cb56c9dbe..352b1b46830 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -13,6 +13,7 @@ from typing import Any, Literal, Optional import psutil +import torch import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict @@ -28,8 +29,8 @@ DEFAULT_VRAM_CACHE = 0.25 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] PRECISION = Literal["auto", "float16", "bfloat16", "float32"] -ATTENTION_TYPE = Literal["auto", "sliced", "torch-sdp"] -ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] +ATTENTION_TYPE = Literal["auto", "normal", "torch-sdp"] +ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] CONFIG_SCHEMA_VERSION = "4.0.3" @@ -181,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings): # GENERATION sequential_guidance: bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.") attention_type: ATTENTION_TYPE = Field(default="auto", description="Attention type.") - attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size, valid when attention_type=="sliced".') + attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size') force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") @@ -443,10 +444,30 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str An config dict with the settings migrated to v4.0.3. """ parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict) - # normal and xformers attentions removed in 4.0.3 attention_type = parsed_config_dict.get("attention_type", None) - if attention_type in ["normal", "xformers"]: - parsed_config_dict["attention_type"] = "torch-sdp" + + # now attention_slice_size means enabling slicing attention + if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict: + del parsed_config_dict["attention_slice_size"] + + # xformers attention removed, on mps better works normal attention + if attention_type == "xformers": + if torch.backends.mps.is_available(): + parsed_config_dict["attention_type"] = "normal" + else: + parsed_config_dict["attention_type"] = "torch-sdp" + + # slicing attention now enabled by `attention_slice_size` + if attention_type == "sliced": + if torch.backends.mps.is_available(): + parsed_config_dict["attention_type"] = "normal" + else: + parsed_config_dict["attention_type"] = "torch-sdp" + + # if no attention_slise_size in config, use balanced as default option + if "attention_slice_size" not in parsed_config_dict: + parsed_config_dict["attention_slice_size"] = "balanced" + parsed_config_dict["schema_version"] = "4.0.3" return parsed_config_dict diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 42c2fdba397..6c2dca11f37 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -1,16 +1,13 @@ from __future__ import annotations import math -from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Callable, List, Optional, Union import einops import PIL.Image -import psutil import torch import torchvision.transforms as T -from diffusers.models.attention_processor import AttnProcessor2_0 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline @@ -19,14 +16,10 @@ from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState -from invokeai.backend.util.attention import auto_detect_slice_size -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel @@ -168,55 +161,6 @@ def __init__( self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) - def _adjust_memory_efficient_attention(self, latents: torch.Tensor): - config = get_config() - attention_type = config.attention_type - if attention_type in ["normal", "xformers"]: - logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') - attention_type = "torch-sdp" - - if config.attention_type == "sliced": - slice_size = config.attention_slice_size - if slice_size == "auto": - slice_size = auto_detect_slice_size(latents) - if slice_size == "balanced": - slice_size = "auto" - self.enable_attention_slicing(slice_size=slice_size) - return - elif config.attention_type == "torch-sdp": - self.unet.set_attn_processor(AttnProcessor2_0()) - return - - # the remainder if this code is called when attention_type=='auto' - if self.unet.device.type == "cuda": - self.unet.set_attn_processor(AttnProcessor2_0()) - return - - if self.unet.device.type == "cpu" or self.unet.device.type == "mps": - mem_free = psutil.virtual_memory().free - elif self.unet.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device)) - else: - raise ValueError(f"unrecognized device {self.unet.device}") - # input tensor of [1, 4, h/8, w/8] - # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] - bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 - max_size_required_for_baddbmm = ( - 16 - * latents.size(dim=2) - * latents.size(dim=3) - * latents.size(dim=2) - * latents.size(dim=3) - * bytes_per_element_needed_for_baddbmm_duplication - ) - if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code - self.enable_attention_slicing(slice_size="max") - elif torch.backends.mps.is_available(): - # diffusers recommends always enabling for mps - self.enable_attention_slicing(slice_size="max") - else: - self.unet.set_attn_processor(AttnProcessor2_0()) - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): raise Exception("Should not be called") @@ -335,25 +279,14 @@ def latents_from_embeddings( is_gradient_mask=is_gradient_mask, ) - use_ip_adapter = ip_adapter_data is not None - use_regional_prompting = ( - conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None - ) - unet_attention_patcher = None - attn_ctx = nullcontext() - - if use_ip_adapter or use_regional_prompting: - ip_adapters: Optional[List[UNetIPAdapterData]] = ( - [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data] - if use_ip_adapter - else None - ) - unet_attention_patcher = UNetAttentionPatcher(ip_adapters) - attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) - else: - self._adjust_memory_efficient_attention(latents) + ip_adapters: Optional[List[UNetIPAdapterData]] = None + if ip_adapter_data is not None: + ip_adapters = [ + {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data + ] - with attn_ctx: + unet_attention_patcher = UNetAttentionPatcher(ip_adapters) + with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model): callback( PipelineIntermediateState( step=-1, diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 460910a1078..370c10b1d0b 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -1,11 +1,13 @@ +import math from dataclasses import dataclass from typing import List, Optional +import psutil import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention +from packaging.version import Version -import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData @@ -42,43 +44,48 @@ def __init__( """ self._ip_adapter_attention_weights = ip_adapter_attention_weights - self.attention_type, self.slice_size = self._select_attention() - - # inspect didn't work because it's native function - # In 2.0 torch there no scale argument in sdp, it's added in 2.1 - # Probably can selected based on torch version instead - try: - F.scaled_dot_product_attention(torch.zeros(1,1), torch.zeros(1,1), torch.zeros(1,1), scale=0.5) - self.scaled_sdp = True - except: - self.scaled_sdp = False - - def _select_attention(self): + config = get_config() - attention_type = config.attention_type - if attention_type in ["normal", "xformers"]: - logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.') - attention_type = "torch-sdp" - - if attention_type == "auto": - exec_device = TorchDevice.choose_torch_device() - if exec_device.type == "mps": - attention_type = "sliced" - else: - attention_type = "torch-sdp" + self.attention_type = config.attention_type + if self.attention_type == "auto": + self.attention_type = self._select_attention_type() - if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.slice_size = config.attention_slice_size + if self.slice_size == "auto": + self.slice_size = self._select_slice_size() - slice_size = None - if attention_type == "sliced": - slice_size = config.attention_slice_size - if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int): - raise ValueError(f"Unsupported attention_slice_size: {slice_size}") - if slice_size == "auto": - slice_size = "balanced" + if self.attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - return attention_type, slice_size + # In 2.0 torch there no `scale` argument in sdp, it's added in 2.1 + self.scaled_sdp = Version(torch.__version__) >= Version("2.1") + + def _select_attention_type(self) -> str: + device = TorchDevice.choose_torch_device() + # On some mps system normal attention still faster than torch-sdp on others - on par + if device.type == "mps": + return "normal" + else: # cuda, cpu + return "torch-sdp" + + def _select_slice_size(self) -> str: + device = TorchDevice.choose_torch_device() + if device.type in ["cpu", "mps"]: + total_ram_gb = math.ceil(psutil.virtual_memory().total / 2**30) + if total_ram_gb <= 16: + return "max" + if total_ram_gb <= 32: + return "balanced" + return "none" + elif device.type == "cuda": + total_vram_gb = math.ceil(torch.cuda.get_device_properties(device).total_memory / 2**30) + if total_vram_gb <= 4: + return "max" + if total_vram_gb <= 6: + return "balanced" + return "none" + else: + raise ValueError(f"Unknown device: {device.type}") def __call__( self, @@ -224,6 +231,19 @@ def run_ip_adapters( return hidden_states + def _get_slice_size(self, attn) -> Optional[int]: + if self.slice_size == "none": + return None + if isinstance(self.slice_size, int): + return self.slice_size + + if self.slice_size == "max": + return 1 + if self.slice_size == "balanced": + return max(1, attn.sliceable_head_dim // 2) + + raise ValueError(f"Incorrect slice_size value: {self.slice_size}") + def run_attention( self, attn: Attention, @@ -232,10 +252,21 @@ def run_attention( value: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: + slice_size = self._get_slice_size(attn) + if slice_size is not None: + return self.run_attention_sliced( + attn=attn, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + slice_size=slice_size, + ) + if self.attention_type == "torch-sdp": attn_call = self.run_attention_sdp - elif self.attention_type == "sliced": - attn_call = self.run_attention_sliced + elif self.attention_type == "normal": + attn_call = self.run_attention_normal else: raise Exception(f"Unknown attention type: {self.attention_type}") @@ -247,6 +278,24 @@ def run_attention( attention_mask=attention_mask, ) + def run_attention_normal( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + return hidden_states + def run_attention_sdp( self, attn: Attention, @@ -288,15 +337,8 @@ def run_attention_sliced( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + slice_size: int, ) -> torch.Tensor: - # slice_size - if self.slice_size == "max": - slice_size = 1 - elif self.slice_size == "balanced": - slice_size = max(1, attn.sliceable_head_dim // 2) - else: - slice_size = min(self.slice_size, attn.sliceable_head_dim) - dim = query.shape[-1] query = attn.head_to_batch_dim(query) @@ -317,17 +359,11 @@ def run_attention_sliced( value_slice = value[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - # TODO: compare speed/memory on mps - # cuda, sd1, 31 step, 1024x1024 - # denoise_latents 1 19.667s 3.418G - # denoise_latents 1 11.601s 2.133G (sdp) - # cpu, sd1, 10 steps, 512x512 - # denoise_latents 1 43.859s 0.000G - # denoise_latents 1 40.696s 0.000G (sdp) - if False: + if self.attention_type == "normal": attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx]) - else: + del attn_slice + elif self.attention_type == "torch-sdp": if attn_mask_slice is not None: attn_mask_slice = attn_mask_slice.unsqueeze(0) @@ -335,8 +371,16 @@ def run_attention_sliced( if self.scaled_sdp: scale_kwargs["scale"] = attn.scale hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention( - query_slice.unsqueeze(0), key_slice.unsqueeze(0), value_slice.unsqueeze(0), attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False, **scale_kwargs + query_slice.unsqueeze(0), + key_slice.unsqueeze(0), + value_slice.unsqueeze(0), + attn_mask=attn_mask_slice, + dropout_p=0.0, + is_causal=False, + **scale_kwargs, ).squeeze(0) + else: + raise ValueError(f"Unknown attention type: {self.attention_type}") hidden_states = attn.batch_to_head_dim(hidden_states) return hidden_states diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index ce45ac157c2..8ba8b3acf38 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -56,7 +56,7 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel): return attn_procs @contextmanager - def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): + def apply_custom_attention(self, unet: UNet2DConditionModel): """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers.""" attn_procs = self._prepare_attention_processors(unet) orig_attn_processors = unet.attn_processors diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index 6c07fc1c2c8..f78facaf581 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -13,6 +13,7 @@ StableDiffusionGeneratorPipeline, ) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.tiles.utils import Tile @@ -63,132 +64,132 @@ def multi_diffusion_denoise( latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) assert isinstance(latents, torch.Tensor) # For static type checking. - # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after - # cropping into regions. - self._adjust_memory_efficient_attention(latents) - - # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since - # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a - # separate scheduler state for each region batch. - # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler - # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect - # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when - # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each - # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion. - region_batch_schedulers: list[SchedulerMixin] = [ - copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning - ] - - callback( - PipelineIntermediateState( - step=-1, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=self.scheduler.config.num_train_timesteps, - latents=latents, - ) - ) - - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t = t.expand(batch_size) - - merged_latents = torch.zeros_like(latents) - merged_latents_weights = torch.zeros( - (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype - ) - merged_pred_original: torch.Tensor | None = None - for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): - # Switch to the scheduler for the region batch. - self.scheduler = region_batch_schedulers[region_idx] - - # Crop the inputs to the region. - region_latents = latents[ - :, - :, - region_conditioning.region.coords.top : region_conditioning.region.coords.bottom, - region_conditioning.region.coords.left : region_conditioning.region.coords.right, - ] - - # Run the denoising step on the region. - step_output = self.step( - t=batched_t, - latents=region_latents, - conditioning_data=region_conditioning.text_conditioning_data, - step_index=i, - total_step_count=len(timesteps), - scheduler_step_kwargs=scheduler_step_kwargs, - mask_guidance=None, - mask=None, - masked_latents=None, - control_data=region_conditioning.control_data, - ) - - # Build a region_weight matrix that applies gradient blending to the edges of the region. - region = region_conditioning.region - _, _, region_height, region_width = step_output.prev_sample.shape - region_weight = torch.ones( - (1, 1, region_height, region_width), - dtype=latents.dtype, - device=latents.device, - ) - if region.overlap.left > 0: - left_grad = torch.linspace( - 0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype - ).view((1, 1, 1, -1)) - region_weight[:, :, :, : region.overlap.left] *= left_grad - if region.overlap.top > 0: - top_grad = torch.linspace( - 0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype - ).view((1, 1, -1, 1)) - region_weight[:, :, : region.overlap.top, :] *= top_grad - if region.overlap.right > 0: - right_grad = torch.linspace( - 1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype - ).view((1, 1, 1, -1)) - region_weight[:, :, :, -region.overlap.right :] *= right_grad - if region.overlap.bottom > 0: - bottom_grad = torch.linspace( - 1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype - ).view((1, 1, -1, 1)) - region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad - - # Update the merged results with the region results. - merged_latents[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += step_output.prev_sample * region_weight - merged_latents_weights[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += region_weight - - pred_orig_sample = getattr(step_output, "pred_original_sample", None) - if pred_orig_sample is not None: - # If one region has pred_original_sample, then we can assume that all regions will have it, because - # they all use the same scheduler. - if merged_pred_original is None: - merged_pred_original = torch.zeros_like(latents) - merged_pred_original[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += pred_orig_sample - - # Normalize the merged results. - latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents) - # For debugging, uncomment this line to visualize the region seams: - # latents = torch.where(merged_latents_weights > 1, 0.0, latents) - predicted_original = None - if merged_pred_original is not None: - predicted_original = torch.where( - merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original - ) + unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None) + with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model): + # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since + # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a + # separate scheduler state for each region batch. + # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler + # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect + # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when + # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each + # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion. + region_batch_schedulers: list[SchedulerMixin] = [ + copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning + ] callback( PipelineIntermediateState( - step=i, + step=-1, order=self.scheduler.order, total_steps=len(timesteps), - timestep=int(t), + timestep=self.scheduler.config.num_train_timesteps, latents=latents, - predicted_original=predicted_original, ) ) + for i, t in enumerate(self.progress_bar(timesteps)): + batched_t = t.expand(batch_size) + + merged_latents = torch.zeros_like(latents) + merged_latents_weights = torch.zeros( + (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype + ) + merged_pred_original: torch.Tensor | None = None + for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): + # Switch to the scheduler for the region batch. + self.scheduler = region_batch_schedulers[region_idx] + + # Crop the inputs to the region. + region_latents = latents[ + :, + :, + region_conditioning.region.coords.top : region_conditioning.region.coords.bottom, + region_conditioning.region.coords.left : region_conditioning.region.coords.right, + ] + + # Run the denoising step on the region. + step_output = self.step( + t=batched_t, + latents=region_latents, + conditioning_data=region_conditioning.text_conditioning_data, + step_index=i, + total_step_count=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, + mask_guidance=None, + mask=None, + masked_latents=None, + control_data=region_conditioning.control_data, + ) + + # Build a region_weight matrix that applies gradient blending to the edges of the region. + region = region_conditioning.region + _, _, region_height, region_width = step_output.prev_sample.shape + region_weight = torch.ones( + (1, 1, region_height, region_width), + dtype=latents.dtype, + device=latents.device, + ) + if region.overlap.left > 0: + left_grad = torch.linspace( + 0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype + ).view((1, 1, 1, -1)) + region_weight[:, :, :, : region.overlap.left] *= left_grad + if region.overlap.top > 0: + top_grad = torch.linspace( + 0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype + ).view((1, 1, -1, 1)) + region_weight[:, :, : region.overlap.top, :] *= top_grad + if region.overlap.right > 0: + right_grad = torch.linspace( + 1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype + ).view((1, 1, 1, -1)) + region_weight[:, :, :, -region.overlap.right :] *= right_grad + if region.overlap.bottom > 0: + bottom_grad = torch.linspace( + 1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype + ).view((1, 1, -1, 1)) + region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad + + # Update the merged results with the region results. + merged_latents[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += step_output.prev_sample * region_weight + merged_latents_weights[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += region_weight + + pred_orig_sample = getattr(step_output, "pred_original_sample", None) + if pred_orig_sample is not None: + # If one region has pred_original_sample, then we can assume that all regions will have it, because + # they all use the same scheduler. + if merged_pred_original is None: + merged_pred_original = torch.zeros_like(latents) + merged_pred_original[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += pred_orig_sample + + # Normalize the merged results. + latents = torch.where( + merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents + ) + # For debugging, uncomment this line to visualize the region seams: + # latents = torch.where(merged_latents_weights > 1, 0.0, latents) + predicted_original = None + if merged_pred_original is not None: + predicted_original = torch.where( + merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original + ) + + callback( + PipelineIntermediateState( + step=i, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=int(t), + latents=latents, + predicted_original=predicted_original, + ) + ) + return latents From 18fc36dbcd88cdfa69bc56ca7b9cd33ea9053aac Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 4 Aug 2024 04:26:05 +0300 Subject: [PATCH 17/25] Suggested changes Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --- .../app/services/config/config_default.py | 21 +++---------------- .../diffusion/custom_attention.py | 19 +++-------------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 352b1b46830..7624d6a22dd 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -13,7 +13,6 @@ from typing import Any, Literal, Optional import psutil -import torch import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict @@ -450,23 +449,9 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict: del parsed_config_dict["attention_slice_size"] - # xformers attention removed, on mps better works normal attention - if attention_type == "xformers": - if torch.backends.mps.is_available(): - parsed_config_dict["attention_type"] = "normal" - else: - parsed_config_dict["attention_type"] = "torch-sdp" - - # slicing attention now enabled by `attention_slice_size` - if attention_type == "sliced": - if torch.backends.mps.is_available(): - parsed_config_dict["attention_type"] = "normal" - else: - parsed_config_dict["attention_type"] = "torch-sdp" - - # if no attention_slise_size in config, use balanced as default option - if "attention_slice_size" not in parsed_config_dict: - parsed_config_dict["attention_slice_size"] = "balanced" + # xformers attention removed, sliced moved to attention_slice_size + if attention_type in ["sliced", "xformers"]: + parsed_config_dict["attention_type"] = "auto" parsed_config_dict["schema_version"] = "4.0.3" return parsed_config_dict diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 370c10b1d0b..0ef6deb3988 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -6,7 +6,6 @@ import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention -from packaging.version import Version from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights @@ -54,12 +53,6 @@ def __init__( if self.slice_size == "auto": self.slice_size = self._select_slice_size() - if self.attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - # In 2.0 torch there no `scale` argument in sdp, it's added in 2.1 - self.scaled_sdp = Version(torch.__version__) >= Version("2.1") - def _select_attention_type(self) -> str: device = TorchDevice.choose_torch_device() # On some mps system normal attention still faster than torch-sdp on others - on par @@ -231,7 +224,7 @@ def run_ip_adapters( return hidden_states - def _get_slice_size(self, attn) -> Optional[int]: + def _get_slice_size(self, attn: Attention) -> Optional[int]: if self.slice_size == "none": return None if isinstance(self.slice_size, int): @@ -318,11 +311,8 @@ def run_attention_sdp( attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # the output of sdp = (batch, num_heads, seq_len, head_dim) - scale_kwargs = {} - if self.scaled_sdp: - scale_kwargs["scale"] = attn.scale hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, **scale_kwargs + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -367,9 +357,6 @@ def run_attention_sliced( if attn_mask_slice is not None: attn_mask_slice = attn_mask_slice.unsqueeze(0) - scale_kwargs = {} - if self.scaled_sdp: - scale_kwargs["scale"] = attn.scale hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention( query_slice.unsqueeze(0), key_slice.unsqueeze(0), @@ -377,7 +364,7 @@ def run_attention_sliced( attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False, - **scale_kwargs, + scale=attn.scale, ).squeeze(0) else: raise ValueError(f"Unknown attention type: {self.attention_type}") From f44e0cd01423382b77e0a16ac0ddd22168711ce6 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 4 Aug 2024 13:17:41 +0300 Subject: [PATCH 18/25] Update config docstring Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --- invokeai/app/services/config/config_default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 7624d6a22dd..6281facbc28 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -107,8 +107,8 @@ class InvokeAIAppConfig(BaseSettings): device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. - attention_type: Attention type.
Valid values: `auto`, `sliced`, `torch-sdp` - attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` + attention_type: Attention type.
Valid values: `auto`, `normal`, `torch-sdp` + attention_slice_size: Slice size
Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. From 9618b6e11f56370ed23dfe730ee7648ebdd7c7c8 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 6 Aug 2024 20:31:26 +0300 Subject: [PATCH 19/25] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- docs/installation/020_INSTALL_MANUAL.md | 8 -------- .../stable_diffusion/diffusion/custom_attention.py | 6 ++++-- .../diffusion/shared_invokeai_diffusion.py | 3 --- .../backend/stable_diffusion/diffusion_backend.py | 3 --- .../stable_diffusion/extensions/controlnet.py | 5 ----- invokeai/version/__init__.py | 12 ------------ scripts/invokeai-web.py | 3 --- 7 files changed, 4 insertions(+), 36 deletions(-) diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index 059834eb453..8b7eeb0cbf7 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -87,14 +87,6 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121 ``` - - If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance. - - !!! example "Install with `xformers`" - - ```bash - pip install "InvokeAI[xformers]" --use-pep517 - ``` - 1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment: === "Linux/macOS" diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 0ef6deb3988..ca79f4a6f55 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -55,7 +55,10 @@ def __init__( def _select_attention_type(self) -> str: device = TorchDevice.choose_torch_device() - # On some mps system normal attention still faster than torch-sdp on others - on par + # On some mps system normal attention still faster than torch-sdp, on others - on par + # Results torch-sdp vs normal attention + # gogurt: 67.993s vs 67.729s + # Adreitz: 260.868s vs 226.638s if device.type == "mps": return "normal" else: # cuda, cpu @@ -89,7 +92,6 @@ def __call__( temb: Optional[torch.Tensor] = None, # For Regional Prompting: regional_prompt_data: Optional[RegionalPromptData] = None, - percent_through: Optional[torch.Tensor] = None, # For IP-Adapter: regional_ip_data: Optional[RegionalIPData] = None, *args, diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f418133e49f..d5f54ccff16 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -335,7 +335,6 @@ def _apply_standard_conditioning( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( regions=regions, device=x.device, dtype=x.dtype ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds @@ -426,7 +425,6 @@ def _apply_standard_conditioning_sequentially( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count # Run unconditioned UNet denoising (i.e. negative prompt). unconditioned_next_x = self.model_forward_callback( @@ -474,7 +472,6 @@ def _apply_standard_conditioning_sequentially( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count # Run conditioned UNet denoising (i.e. positive prompt). conditioned_next_x = self.model_forward_callback( diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4191db734f9..dddf7e555d4 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -114,9 +114,6 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio sample=sample, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditoning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.inputs.timesteps), - ), ) ctx.conditioning_mode = conditioning_mode diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py index a48a681af3f..8728779e00b 100644 --- a/invokeai/backend/stable_diffusion/extensions/controlnet.py +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -112,8 +112,6 @@ def pre_unet_step(self, ctx: DenoiseContext): ctx.unet_kwargs.mid_block_additional_residual += mid_sample def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode): - total_steps = len(ctx.inputs.timesteps) - model_input = ctx.latent_model_input image_tensor = self._image_tensor if conditioning_mode == ConditioningMode.Both: @@ -124,9 +122,6 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con sample=model_input, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditioning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / total_steps, - ), ) ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) diff --git a/invokeai/version/__init__.py b/invokeai/version/__init__.py index 57efb1af95f..8720b915320 100644 --- a/invokeai/version/__init__.py +++ b/invokeai/version/__init__.py @@ -6,15 +6,3 @@ __app_id__ = "invoke-ai/InvokeAI" __app_name__ = "InvokeAI" - - -def _ignore_xformers_triton_message_on_windows(): - import logging - - logging.getLogger("xformers").addFilter( - lambda record: "A matching Triton is not available" not in record.getMessage() - ) - - -# In order to be effective, this needs to happen before anything could possibly import xformers. -_ignore_xformers_triton_message_on_windows() diff --git a/scripts/invokeai-web.py b/scripts/invokeai-web.py index 691e58f7d17..cf68004cc6c 100755 --- a/scripts/invokeai-web.py +++ b/scripts/invokeai-web.py @@ -2,13 +2,10 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -import logging import os from invokeai.app.run_app import run_app -logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) - def main(): # Change working directory to the repo root From 09aef431f414665abdff5ff745dab9a0e3a92f07 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 7 Aug 2024 20:53:54 +0300 Subject: [PATCH 20/25] Restore xformers --- docker/Dockerfile | 7 ++- docs/installation/020_INSTALL_MANUAL.md | 8 +++ flake.nix | 2 +- installer/lib/installer.py | 4 +- invokeai/app/api/routers/app_info.py | 8 ++- .../app/services/config/config_default.py | 6 +- .../diffusion/custom_attention.py | 56 ++++++++++++++++++- invokeai/backend/util/hotfixes.py | 46 +++++++++++++++ .../frontend/web/src/services/api/schema.ts | 5 ++ invokeai/version/__init__.py | 12 ++++ pyproject.toml | 6 ++ scripts/invokeai-web.py | 3 + 12 files changed, 154 insertions(+), 9 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 24f2ff9e2f7..7ea078af0d9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -43,7 +43,12 @@ RUN --mount=type=cache,target=/root/.cache/pip \ extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \ fi &&\ - pip install $extra_index_url_arg -e "."; + # xformers + triton fails to install on arm64 + if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ + pip install $extra_index_url_arg -e ".[xformers]"; \ + else \ + pip install $extra_index_url_arg -e "."; \ + fi # #### Build the Web UI ------------------------------------ diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index 8b7eeb0cbf7..059834eb453 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -87,6 +87,14 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121 ``` + - If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance. + + !!! example "Install with `xformers`" + + ```bash + pip install "InvokeAI[xformers]" --use-pep517 + ``` + 1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment: === "Linux/macOS" diff --git a/flake.nix b/flake.nix index bf8d2ae9466..3ccc6658121 100644 --- a/flake.nix +++ b/flake.nix @@ -84,7 +84,7 @@ in { devShells.${system} = rec { - develop = mkShell { dir = "venv"; install = "-e '.' --extra-index-url https://download.pytorch.org/whl/cu118"; }; + develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; }; default = develop; }; }; diff --git a/installer/lib/installer.py b/installer/lib/installer.py index 504c801df6d..11823b413e0 100644 --- a/installer/lib/installer.py +++ b/installer/lib/installer.py @@ -418,11 +418,11 @@ def get_torch_source() -> Tuple[str | None, str | None]: url = "https://download.pytorch.org/whl/cpu" elif device.value == "cuda": # CUDA uses the default PyPi index - optional_modules = "[onnx-cuda]" + optional_modules = "[xformers,onnx-cuda]" elif OS == "Windows": if device.value == "cuda": url = "https://download.pytorch.org/whl/cu121" - optional_modules = "[onnx-cuda]" + optional_modules = "[xformers,onnx-cuda]" elif device.value == "cpu": # CPU uses the default PyPi index, no optional modules pass diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index 9f87e2cdec0..3206adb2421 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -1,6 +1,6 @@ import typing from enum import Enum -from importlib.metadata import version +from importlib.metadata import PackageNotFoundError, version from pathlib import Path from platform import python_version from typing import Optional @@ -56,6 +56,7 @@ class AppDependencyVersions(BaseModel): torch: str = Field(description="PyTorch version") torchvision: str = Field(description="PyTorch Vision version") transformers: str = Field(description="transformers version") + xformers: Optional[str] = Field(description="xformers version") class AppConfig(BaseModel): @@ -74,6 +75,10 @@ async def get_version() -> AppVersion: @app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions) async def get_app_deps() -> AppDependencyVersions: + try: + xformers = version("xformers") + except PackageNotFoundError: + xformers = None return AppDependencyVersions( accelerate=version("accelerate"), compel=version("compel"), @@ -87,6 +92,7 @@ async def get_app_deps() -> AppDependencyVersions: torch=torch.version.__version__, torchvision=version("torchvision"), transformers=version("transformers"), + xformers=xformers, ) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 6281facbc28..4f703b59262 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -28,7 +28,7 @@ DEFAULT_VRAM_CACHE = 0.25 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] PRECISION = Literal["auto", "float16", "bfloat16", "float32"] -ATTENTION_TYPE = Literal["auto", "normal", "torch-sdp"] +ATTENTION_TYPE = Literal["auto", "normal", "xformers", "torch-sdp"] ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] @@ -449,8 +449,8 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict: del parsed_config_dict["attention_slice_size"] - # xformers attention removed, sliced moved to attention_slice_size - if attention_type in ["sliced", "xformers"]: + # sliced moved to attention_slice_size + if attention_type == "sliced": parsed_config_dict["attention_type"] = "auto" parsed_config_dict["schema_version"] = "4.0.3" diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index ca79f4a6f55..35f81db9ba0 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention +from diffusers.utils.import_utils import is_xformers_available from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights @@ -13,6 +14,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from invokeai.backend.util.devices import TorchDevice +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + @dataclass class IPAdapterAttentionWeights: @@ -23,7 +30,9 @@ class IPAdapterAttentionWeights: class CustomAttnProcessor: """A custom implementation of attention processor that supports additional Invoke features. This implementation is based on + AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732) SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616) + XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113) AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204) Supported custom features: - IP-Adapter @@ -53,6 +62,9 @@ def __init__( if self.slice_size == "auto": self.slice_size = self._select_slice_size() + if self.attention_type == "xformers" and xformers is None: + raise ImportError("xformers attention requires xformers module to be installed.") + def _select_attention_type(self) -> str: device = TorchDevice.choose_torch_device() # On some mps system normal attention still faster than torch-sdp, on others - on par @@ -61,7 +73,14 @@ def _select_attention_type(self) -> str: # Adreitz: 260.868s vs 226.638s if device.type == "mps": return "normal" - else: # cuda, cpu + elif device.type == "cuda": + # Flash Attention is supported from sm80 compute capability onwards in PyTorch + # https://pytorch.org/blog/accelerated-pytorch-2/ + if torch.cuda.get_device_capability("cuda")[0] < 8 and xformers is not None: + return "xformers" + else: + return "torch-sdp" + else: # cpu return "torch-sdp" def _select_slice_size(self) -> str: @@ -262,6 +281,8 @@ def run_attention( attn_call = self.run_attention_sdp elif self.attention_type == "normal": attn_call = self.run_attention_normal + elif self.attention_type == "xformers": + attn_call = self.run_attention_xformers else: raise Exception(f"Unknown attention type: {self.attention_type}") @@ -291,6 +312,35 @@ def run_attention_normal( return hidden_states + def run_attention_xformers( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + if attention_mask is not None: + # expand our mask's singleton query_length dimension: + # [batch*heads, 1, key_length] -> + # [batch*heads, query_length, key_length] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_length, key_length] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + attention_mask = attention_mask.expand(-1, query.shape[1], -1) + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + return hidden_states + def run_attention_sdp( self, attn: Attention, @@ -355,6 +405,10 @@ def run_attention_sliced( attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx]) del attn_slice + elif self.attention_type == "xformers": + hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention( + query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale + ) elif self.attention_type == "torch-sdp": if attn_mask_slice is not None: attn_mask_slice = attn_mask_slice.unsqueeze(0) diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index a9ed2538825..7e362fe9589 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -791,3 +791,49 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0): diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward + +try: + import xformers + + xformers_available = True +except Exception: + xformers_available = False + + +if xformers_available: + # TODO: remove when fixed in diffusers + _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention + + def new_memory_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias=None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op=None, + ): + # diffusers not align shape to 8, which is required by xformers + if attn_bias is not None and type(attn_bias) is torch.Tensor: + orig_size = attn_bias.shape[-1] + new_size = ((orig_size + 7) // 8) * 8 + aligned_attn_bias = torch.zeros( + (attn_bias.shape[0], attn_bias.shape[1], new_size), + device=attn_bias.device, + dtype=attn_bias.dtype, + ) + aligned_attn_bias[:, :, :orig_size] = attn_bias + attn_bias = aligned_attn_bias[:, :, :orig_size] + + return _xformers_memory_efficient_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=p, + scale=scale, + op=op, + ) + + xformers.ops.memory_efficient_attention = new_memory_efficient_attention diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index b4b39eae32b..79b82a23fae 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -725,6 +725,11 @@ export type components = { * @description transformers version */ transformers: string; + /** + * Xformers + * @description xformers version + */ + xformers: string | null; }; /** * AppVersion diff --git a/invokeai/version/__init__.py b/invokeai/version/__init__.py index 8720b915320..57efb1af95f 100644 --- a/invokeai/version/__init__.py +++ b/invokeai/version/__init__.py @@ -6,3 +6,15 @@ __app_id__ = "invoke-ai/InvokeAI" __app_name__ = "InvokeAI" + + +def _ignore_xformers_triton_message_on_windows(): + import logging + + logging.getLogger("xformers").addFilter( + lambda record: "A matching Triton is not available" not in record.getMessage() + ) + + +# In order to be effective, this needs to happen before anything could possibly import xformers. +_ignore_xformers_triton_message_on_windows() diff --git a/pyproject.toml b/pyproject.toml index 5bcf74d88cd..cdf032b301b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,12 @@ dependencies = [ ] [project.optional-dependencies] +"xformers" = [ + # Core generation dependencies, pinned for reproducible builds. + "xformers==0.0.25post1; sys_platform!='darwin'", + # Auxiliary dependencies, pinned only if necessary. + "triton; sys_platform=='linux'", +] "onnx" = ["onnxruntime"] "onnx-cuda" = ["onnxruntime-gpu"] "onnx-directml" = ["onnxruntime-directml"] diff --git a/scripts/invokeai-web.py b/scripts/invokeai-web.py index cf68004cc6c..691e58f7d17 100755 --- a/scripts/invokeai-web.py +++ b/scripts/invokeai-web.py @@ -2,10 +2,13 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +import logging import os from invokeai.app.run_app import run_app +logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) + def main(): # Change working directory to the repo root From 37dfab7cb1d5ed03ec027cec5fb6daacf6d978df Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 7 Aug 2024 21:23:32 +0300 Subject: [PATCH 21/25] Small fixes --- invokeai/app/services/config/config_default.py | 2 +- .../backend/stable_diffusion/diffusion/custom_attention.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 4f703b59262..06a197f630e 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -107,7 +107,7 @@ class InvokeAIAppConfig(BaseSettings): device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. - attention_type: Attention type.
Valid values: `auto`, `normal`, `torch-sdp` + attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `torch-sdp` attention_slice_size: Slice size
Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 35f81db9ba0..d29ec07815c 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -321,7 +321,7 @@ def run_attention_xformers( attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() + key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() if attention_mask is not None: @@ -406,6 +406,9 @@ def run_attention_sliced( torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx]) del attn_slice elif self.attention_type == "xformers": + if attn_mask_slice is not None: + attn_mask_slice = attn_mask_slice.expand(-1, query.shape[1], -1) + hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention( query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale ) From 192fba4fe3e2673d22ef69291a52f8859cd13bba Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 20 Aug 2024 02:03:34 +0300 Subject: [PATCH 22/25] Rewrite sliced attention, more optimizations(batched torch-sdp for old cuda, multihead xformers for high heads count) --- .../diffusion/custom_attention.py | 272 ++++++++++++------ invokeai/backend/util/hotfixes.py | 15 +- 2 files changed, 193 insertions(+), 94 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index d29ec07815c..d854e20efde 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -53,6 +53,9 @@ def __init__( self._ip_adapter_attention_weights = ip_adapter_attention_weights + device = TorchDevice.choose_torch_device() + self.is_old_cuda = device.type == "cuda" and torch.cuda.get_device_capability(device)[0] < 8 + config = get_config() self.attention_type = config.attention_type if self.attention_type == "auto": @@ -76,7 +79,7 @@ def _select_attention_type(self) -> str: elif device.type == "cuda": # Flash Attention is supported from sm80 compute capability onwards in PyTorch # https://pytorch.org/blog/accelerated-pytorch-2/ - if torch.cuda.get_device_capability("cuda")[0] < 8 and xformers is not None: + if self.is_old_cuda and xformers is not None: return "xformers" else: return "torch-sdp" @@ -265,9 +268,10 @@ def run_attention( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + no_sliced: bool = False, ) -> torch.Tensor: slice_size = self._get_slice_size(attn) - if slice_size is not None: + if not no_sliced and slice_size is not None: return self.run_attention_sliced( attn=attn, query=query, @@ -294,6 +298,41 @@ def run_attention( attention_mask=attention_mask, ) + @staticmethod + def _align_attention_mask_memory(attention_mask: torch.Tensor, alignment: int = 8) -> torch.Tensor: + if attention_mask.stride(-2) % alignment == 0 and attention_mask.stride(-2) != 0: + return attention_mask + + last_mask_dim = attention_mask.shape[-1] + new_last_mask_dim = last_mask_dim + (alignment - (last_mask_dim % alignment)) + attention_mask_mem = torch.empty( + attention_mask.shape[:-1] + (new_last_mask_dim,), + device=attention_mask.device, + dtype=attention_mask.dtype, + ) + attention_mask_mem[..., :last_mask_dim] = attention_mask + return attention_mask_mem[..., :last_mask_dim] + + @staticmethod + def _head_to_batch_dim(tensor: torch.Tensor, head_dim: int) -> torch.Tensor: + # [B, S, H*He] -> [B, S, H, He] + tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1, head_dim) + # [B, S, H, He] -> [B, H, S, He] + tensor = tensor.permute(0, 2, 1, 3) + # [B, H, S, He] -> [B*H, S, He] + tensor = tensor.reshape(-1, tensor.shape[2], head_dim) + return tensor + + @staticmethod + def _batch_to_head_dim(tensor: torch.Tensor, batch_size: int) -> torch.Tensor: + # [B*H, S, He] -> [B, H, S, He] + tensor = tensor.reshape(batch_size, -1, tensor.shape[1], tensor.shape[2]) + # [B, H, S, He] -> [B, S, H, He] + tensor = tensor.permute(0, 2, 1, 3) + # [B, S, H, He] -> [B, S, H*He] + tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1) + return tensor + def run_attention_normal( self, attn: Attention, @@ -302,14 +341,17 @@ def run_attention_normal( value: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + + query = self._head_to_batch_dim(query, head_dim) + key = self._head_to_batch_dim(key, head_dim) + value = self._head_to_batch_dim(value, head_dim) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = self._batch_to_head_dim(hidden_states, batch_size) return hidden_states def run_attention_xformers( @@ -319,25 +361,62 @@ def run_attention_xformers( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + multihead: Optional[bool] = None, ) -> torch.Tensor: - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + + # batched execution on xformers slightly faster for small heads count + if multihead is None: + heads_count = query.shape[2] // head_dim + multihead = heads_count >= 4 + + if multihead: + # [B, S, H*He] -> [B, S, H, He] + query = query.view(batch_size, query.shape[1], -1, head_dim) + key = key.view(batch_size, key.shape[1], -1, head_dim) + value = value.view(batch_size, value.shape[1], -1, head_dim) + + if attention_mask is not None: + # [B*H, 1, S_key] -> [B, H, 1, S_key] + attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) + # expand our mask's singleton query dimension: + # [B, H, 1, S_key] -> + # [B, H, S_query, S_key] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + attention_mask = attention_mask.expand(-1, -1, query.shape[1], -1) + # xformers requires mask memory to be aligned to 8 + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale + ) + # [B, S_query, H, He] -> [B, S_query, H*He] + hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,)) + hidden_states = hidden_states.to(query.dtype) - if attention_mask is not None: - # expand our mask's singleton query_length dimension: - # [batch*heads, 1, key_length] -> - # [batch*heads, query_length, key_length] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_length, key_length] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - attention_mask = attention_mask.expand(-1, query.shape[1], -1) - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale - ) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) + else: + # contiguous inputs slightly faster in batched execution + # [B, S, H*He] -> [B*H, S, He] + query = self._head_to_batch_dim(query, head_dim).contiguous() + key = self._head_to_batch_dim(key, head_dim).contiguous() + value = self._head_to_batch_dim(value, head_dim).contiguous() + + if attention_mask is not None: + # expand our mask's singleton query dimension: + # [B*H, 1, S_key] -> + # [B*H, S_query, S_key] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + attention_mask = attention_mask.expand(-1, query.shape[1], -1) + # xformers requires mask memory to be aligned to 8 + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + # [B*H, S_query, He] -> [B, S_query, H*He] + hidden_states = self._batch_to_head_dim(hidden_states, batch_size) return hidden_states @@ -348,27 +427,54 @@ def run_attention_sdp( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + multihead: Optional[bool] = None, ) -> torch.Tensor: - batch_size = key.shape[0] - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attention_mask is not None: - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + + if multihead is None: + # multihead extremely slow on old cuda gpu + multihead = not self.is_old_cuda + + if multihead: + # [B, S, H*He] -> [B, H, S, He] + query = query.view(batch_size, query.shape[1], -1, head_dim).transpose(1, 2) + key = key.view(batch_size, key.shape[1], -1, head_dim).transpose(1, 2) + value = value.view(batch_size, value.shape[1], -1, head_dim).transpose(1, 2) + + if attention_mask is not None: + # [B*H, 1, S_key] -> [B, H, 1, S_key] + attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) + # mask alignment to 8 decreases memory consumption and increases speed + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + ) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale - ) + # [B, H, S_query, He] -> [B, S_query, H, He] + hidden_states = hidden_states.transpose(1, 2) + # [B, S_query, H, He] -> [B, S_query, H*He] + hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,)) + hidden_states = hidden_states.to(query.dtype) + else: + # [B, S, H*He] -> [B*H, S, He] + query = self._head_to_batch_dim(query, head_dim) + key = self._head_to_batch_dim(key, head_dim) + value = self._head_to_batch_dim(value, head_dim) + + if attention_mask is not None: + # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key] + # mask alignment to 8 decreases memory consumption and increases speed + attention_mask = self._align_attention_mask_memory(attention_mask) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) + hidden_states = hidden_states.to(query.dtype) + # [B*H, S_query, He] -> [B, S_query, H*He] + hidden_states = self._batch_to_head_dim(hidden_states, batch_size) return hidden_states @@ -381,52 +487,50 @@ def run_attention_sliced( attention_mask: Optional[torch.Tensor], slice_size: int, ) -> torch.Tensor: - dim = query.shape[-1] - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) + batch_size = query.shape[0] + head_dim = attn.to_q.weight.shape[0] // attn.heads + heads_count = query.shape[2] // head_dim + + # [B, S, H*He] -> [B, H, S, He] + query = query.reshape(query.shape[0], query.shape[1], -1, head_dim).transpose(1, 2) + key = key.reshape(key.shape[0], key.shape[1], -1, head_dim).transpose(1, 2) + value = value.reshape(value.shape[0], value.shape[1], -1, head_dim).transpose(1, 2) + # [B*H, S_query/1, S_key] -> [B, H, S_query/1, S_key] + if attention_mask is not None: + attention_mask = attention_mask.reshape(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) - batch_size_attention, query_tokens, _ = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) + # [B, H, S_query, He] + hidden_states = torch.empty(query.shape, device=query.device, dtype=query.dtype) - for i in range((batch_size_attention - 1) // slice_size + 1): + for i in range((heads_count - 1) // slice_size + 1): start_idx = i * slice_size end_idx = (i + 1) * slice_size - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - value_slice = value[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - if self.attention_type == "normal": - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx]) - del attn_slice - elif self.attention_type == "xformers": - if attn_mask_slice is not None: - attn_mask_slice = attn_mask_slice.expand(-1, query.shape[1], -1) - - hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention( - query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale - ) - elif self.attention_type == "torch-sdp": - if attn_mask_slice is not None: - attn_mask_slice = attn_mask_slice.unsqueeze(0) - - hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention( - query_slice.unsqueeze(0), - key_slice.unsqueeze(0), - value_slice.unsqueeze(0), - attn_mask=attn_mask_slice, - dropout_p=0.0, - is_causal=False, - scale=attn.scale, - ).squeeze(0) - else: - raise ValueError(f"Unknown attention type: {self.attention_type}") + # [B, H_s, S, He] -> [B, S, H_s*He] + query_slice = query[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, query.shape[2], -1) + key_slice = key[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, key.shape[2], -1) + value_slice = value[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, value.shape[2], -1) - hidden_states = attn.batch_to_head_dim(hidden_states) - return hidden_states + # [B, H_s, S_query/1, S_key] -> [B*H_s, S_query/1, S_key] + attn_mask_slice = None + if attention_mask is not None: + attn_mask_slice = attention_mask[:, start_idx:end_idx, :, :].reshape((-1,) + attention_mask.shape[-2:]) + + # [B, S_query, H_s*He] + hidden_states_slice = self.run_attention( + attn=attn, + query=query_slice, + key=key_slice, + value=value_slice, + attention_mask=attn_mask_slice, + no_sliced=True, + ) + + # [B, S_query, H_s*He] -> [B, H_s, S_query, He] + hidden_states[:, start_idx:end_idx] = hidden_states_slice.reshape( + hidden_states_slice.shape[:-1] + (-1, head_dim) + ).transpose(1, 2) + + # [B, H_s, S_query, He] -> [B, S_query, H_s*He] + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(hidden_states.shape[:-2] + (-1,)) diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 7e362fe9589..1991aed3ccc 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -802,8 +802,11 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0): if xformers_available: # TODO: remove when fixed in diffusers + from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor + _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention + # TODO: remove? or there still possible calls to xformers not by our attention processor? def new_memory_efficient_attention( query: torch.Tensor, key: torch.Tensor, @@ -815,16 +818,8 @@ def new_memory_efficient_attention( op=None, ): # diffusers not align shape to 8, which is required by xformers - if attn_bias is not None and type(attn_bias) is torch.Tensor: - orig_size = attn_bias.shape[-1] - new_size = ((orig_size + 7) // 8) * 8 - aligned_attn_bias = torch.zeros( - (attn_bias.shape[0], attn_bias.shape[1], new_size), - device=attn_bias.device, - dtype=attn_bias.dtype, - ) - aligned_attn_bias[:, :, :orig_size] = attn_bias - attn_bias = aligned_attn_bias[:, :, :orig_size] + if attn_bias is not None and isinstance(attn_bias, torch.Tensor): + attn_bias = CustomAttnProcessor._align_attention_mask_memory(attn_bias) return _xformers_memory_efficient_attention( query=query, From 0b1ff8f659f5bf605e7da3932b6c7652238ada5a Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 20 Aug 2024 02:19:57 +0300 Subject: [PATCH 23/25] Remove redundant alignment in batched torch-sdp execution, add comments --- .../diffusion/custom_attention.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index d854e20efde..8b1b0d3872d 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -367,6 +367,12 @@ def run_attention_xformers( head_dim = attn.to_q.weight.shape[0] // attn.heads # batched execution on xformers slightly faster for small heads count + # 8 heads: + # xformers(dim3): 20.155955553054810 vram: 16483328 + # xformers(dim4): 17.558132648468018 vram: 16483328 + # 1 head: + # xformers(dim3): 5.660739183425903 vram: 9516032 + # xformers(dim4): 6.114191055297852 vram: 9516032 if multihead is None: heads_count = query.shape[2] // head_dim multihead = heads_count >= 4 @@ -433,7 +439,9 @@ def run_attention_sdp( head_dim = attn.to_q.weight.shape[0] // attn.heads if multihead is None: - # multihead extremely slow on old cuda gpu + # multihead extremely slow on old cuda gpu: + # torch-sdp(dim3): 30.07543110847473 vram: 23954432 + # torch-sdp(dim4): 299.3908393383026 vram: 13861888 multihead = not self.is_old_cuda if multihead: @@ -446,6 +454,12 @@ def run_attention_sdp( # [B*H, 1, S_key] -> [B, H, 1, S_key] attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) # mask alignment to 8 decreases memory consumption and increases speed + # fp16: + # torch-sdp(dim4, mask): 6.1701478958129880 vram: 7864320 + # torch-sdp(dim4, aligned mask): 3.3127212524414062 vram: 2621440 + # fp32: + # torch-sdp(dim4, mask): 23.0943229198455800 vram: 16121856 + # torch-sdp(dim4, aligned mask): 17.3104763031005860 vram: 5636096 attention_mask = self._align_attention_mask_memory(attention_mask) hidden_states = F.scaled_dot_product_attention( @@ -463,10 +477,10 @@ def run_attention_sdp( key = self._head_to_batch_dim(key, head_dim) value = self._head_to_batch_dim(value, head_dim) - if attention_mask is not None: - # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key] - # mask alignment to 8 decreases memory consumption and increases speed - attention_mask = self._align_attention_mask_memory(attention_mask) + # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key] + # and there no noticable changes from memory alignment: + # torch-sdp(dim3, mask): 9.7391905784606930 vram: 12713984 + # torch-sdp(dim3, aligned mask): 10.0090200901031500 vram: 12713984 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale From 3d19cacdc4c72bd9371db74b626bdeda5d002cd8 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 20 Aug 2024 03:05:39 +0300 Subject: [PATCH 24/25] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- .../diffusion/custom_attention.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 8b1b0d3872d..740402d8b91 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -8,6 +8,7 @@ from diffusers.models.attention_processor import Attention from diffusers.utils.import_utils import is_xformers_available +import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData @@ -77,12 +78,20 @@ def _select_attention_type(self) -> str: if device.type == "mps": return "normal" elif device.type == "cuda": + # In testing on a Tesla P40 (Pascal architecture), torch-sdp is much slower than xformers + # (8.84 s/it vs. 1.81 s/it for SDXL). We have not tested extensively to find the precise GPU architecture or + # compute capability where this performance gap begins. # Flash Attention is supported from sm80 compute capability onwards in PyTorch - # https://pytorch.org/blog/accelerated-pytorch-2/ - if self.is_old_cuda and xformers is not None: - return "xformers" - else: - return "torch-sdp" + # (https://pytorch.org/blog/accelerated-pytorch-2/). For now, we use this as the cutoff for selecting + # between xformers and torch-sdp. + if self.is_old_cuda: + if xformers is not None: + return "xformers" + logger.warning( + f"xFormers is not installed, but is recommended for best performance with GPU {torch.cuda.get_device_properties(device).name}" + ) + + return "torch-sdp" else: # cpu return "torch-sdp" @@ -478,7 +487,7 @@ def run_attention_sdp( value = self._head_to_batch_dim(value, head_dim) # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key] - # and there no noticable changes from memory alignment: + # and there no noticable changes from memory alignment in batched run: # torch-sdp(dim3, mask): 9.7391905784606930 vram: 12713984 # torch-sdp(dim3, aligned mask): 10.0090200901031500 vram: 12713984 From b947129799ecaa68880585a8183fa48446fa0f65 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 20 Aug 2024 21:28:02 +0300 Subject: [PATCH 25/25] Edit comments --- .../diffusion/custom_attention.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 740402d8b91..68d3bdc7c84 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -376,12 +376,12 @@ def run_attention_xformers( head_dim = attn.to_q.weight.shape[0] // attn.heads # batched execution on xformers slightly faster for small heads count - # 8 heads: - # xformers(dim3): 20.155955553054810 vram: 16483328 - # xformers(dim4): 17.558132648468018 vram: 16483328 - # 1 head: - # xformers(dim3): 5.660739183425903 vram: 9516032 - # xformers(dim4): 6.114191055297852 vram: 9516032 + # 8 heads, fp16 (100000 attention calls): + # xformers(dim3): 20.155955553054810s vram: 16483328b + # xformers(dim4): 17.558132648468018s vram: 16483328b + # 1 head, fp16 (100000 attention calls): + # xformers(dim3): 5.660739183425903s vram: 9516032b + # xformers(dim4): 6.114191055297852s vram: 9516032b if multihead is None: heads_count = query.shape[2] // head_dim multihead = heads_count >= 4 @@ -449,8 +449,9 @@ def run_attention_sdp( if multihead is None: # multihead extremely slow on old cuda gpu: - # torch-sdp(dim3): 30.07543110847473 vram: 23954432 - # torch-sdp(dim4): 299.3908393383026 vram: 13861888 + # fp16 (100000 attention calls): + # torch-sdp(dim3): 30.07543110847473s vram: 23954432b + # torch-sdp(dim4): 299.3908393383026s vram: 13861888b multihead = not self.is_old_cuda if multihead: @@ -463,12 +464,12 @@ def run_attention_sdp( # [B*H, 1, S_key] -> [B, H, 1, S_key] attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2]) # mask alignment to 8 decreases memory consumption and increases speed - # fp16: - # torch-sdp(dim4, mask): 6.1701478958129880 vram: 7864320 - # torch-sdp(dim4, aligned mask): 3.3127212524414062 vram: 2621440 - # fp32: - # torch-sdp(dim4, mask): 23.0943229198455800 vram: 16121856 - # torch-sdp(dim4, aligned mask): 17.3104763031005860 vram: 5636096 + # fp16 (100000 attention calls): + # torch-sdp(dim4, mask): 6.1701478958129880s vram: 7864320b + # torch-sdp(dim4, aligned mask): 3.3127212524414062s vram: 2621440b + # fp32 (100000 attention calls): + # torch-sdp(dim4, mask): 23.0943229198455800s vram: 16121856b + # torch-sdp(dim4, aligned mask): 17.3104763031005860s vram: 5636096b attention_mask = self._align_attention_mask_memory(attention_mask) hidden_states = F.scaled_dot_product_attention( @@ -488,8 +489,9 @@ def run_attention_sdp( # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key] # and there no noticable changes from memory alignment in batched run: - # torch-sdp(dim3, mask): 9.7391905784606930 vram: 12713984 - # torch-sdp(dim3, aligned mask): 10.0090200901031500 vram: 12713984 + # fp16 (100000 attention calls): + # torch-sdp(dim3, mask): 9.7391905784606930s vram: 12713984b + # torch-sdp(dim3, aligned mask): 10.0090200901031500s vram: 12713984b hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale