[WIP] Generation backend rewrite #6577

Draft: StAlKeR7779 wants to merge 25 commits into main from stalker7779/gen-backend-rewrite3
Commits (25, all by StAlKeR7779):

8298110  New backend base (Jul 3, 2024)
9324b18  Restore generation preview (Jul 3, 2024)
dccb9f1  Restore rescale cfg (Jul 3, 2024)
f545e60  Restore inpaint support (Jul 3, 2024)
44f4f25  Restore regional prompts (Jul 3, 2024)
d71896f  Restore t2i adapter (Jul 3, 2024)
118b54b  Restore controlnet (Jul 3, 2024)
e013899  Restore ip adapters (Jul 3, 2024)
4aae90d  Change modifier/override handler args handling (Jul 5, 2024)
c82dba0  Implement tiled denoise test version (Jul 5, 2024)
f5f6dea  Move new backend logic to node (Jul 6, 2024)
9970af2  Update logic in denoise node(try) (Jul 7, 2024)
8cfb712  A bit optimize ip adapter loading (Jul 7, 2024)
b9434e0  Optmize extensions patching methods (Jul 7, 2024)
61529f5  Add seamless support (Jul 7, 2024)
0e8b434  Add FreeU support (Jul 7, 2024)
7e465e5  Redo lora patcher as extension class (Jul 7, 2024)
2c64974  Merge lora patcher to extension class, call lora patcher extension cl… (Jul 7, 2024)
d3b1b2f  Merge seamless patcher to extension class (Jul 7, 2024)
dc58274  Rewrite tiled denoise node (Jul 7, 2024)
d69ec3a  Add t2i and ip adapter to tiled generation (Jul 8, 2024)
d42f257  Merge branch 'main' into stalker7779/gen-backend-rewrite3 (Jul 8, 2024)
bd4de46  Clean up code and imports, remove old backend code (Jul 8, 2024)
d4d5684  hotfix inpaint gradient mask (Jul 8, 2024)
b7c91a2  Ruff format/fixes (Jul 8, 2024)

invokeai/app/invocations/compel.py (6 additions, 4 deletions)

@@ -26,6 +26,7 @@
     ConditioningFieldData,
     SDXLConditioningInfo,
 )
+from invokeai.backend.stable_diffusion.extensions import LoRAPatcherExt
 from invokeai.backend.util.devices import TorchDevice

 # unconditioned: Optional[torch.Tensor]
@@ -82,9 +83,10 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
+            LoRAPatcherExt.static_patch_model(
+                model=text_encoder,
                 loras=_lora_loader(),
+                prefix="lora_te_",
                 model_state_dict=model_state_dict,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -177,8 +179,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
-                text_encoder,
+            LoRAPatcherExt.static_patch_model(
+                model=text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
                 model_state_dict=state_dict,

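Both hunks converge on one entry point: `LoRAPatcherExt.static_patch_model` replaces both `ModelPatcher.apply_lora_text_encoder` and `ModelPatcher.apply_lora`, with the target selected by the `prefix` argument. A minimal sketch of the new call pattern follows; the signature is inferred from the call sites above, and the `encode_with_loras` wrapper is hypothetical:

```python
# Sketch only: static_patch_model's signature is inferred from the diff's
# call sites; this wrapper function is hypothetical, not part of the PR.
from typing import Iterator, Tuple

from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.extensions import LoRAPatcherExt


def encode_with_loras(text_encoder_info, loras: Iterator[Tuple[LoRAModelRaw, float]], encode):
    """Run `encode(text_encoder)` with LoRA weights applied for the duration."""
    with (
        text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
        LoRAPatcherExt.static_patch_model(
            model=text_encoder,
            loras=loras,  # iterator of (LoRAModelRaw, weight) pairs
            prefix="lora_te_",  # key prefix selecting text-encoder layers
            model_state_dict=model_state_dict,  # original weights, restored on exit
        ),
    ):
        return encode(text_encoder)
```
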
invokeai/app/invocations/denoise_latents.py (180 additions, 315 deletions)

Large diffs are not rendered by default.

invokeai/app/invocations/latents_to_image.py (2 additions, 2 deletions)

@@ -24,7 +24,7 @@
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.extensions import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice

@@ -59,7 +59,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:

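The same static-helper pattern replaces the free function `set_seamless`. A sketch of the new call, assuming (as `set_seamless` did before) that the patch enables wrap-around padding along the chosen axes only inside the `with` block; the `decode_seamless` wrapper is hypothetical:

```python
# Hypothetical wrapper around the call shown in the diff. Assumes
# SeamlessExt.static_patch_model restores the VAE's padding mode on exit.
from diffusers import AutoencoderKL

from invokeai.backend.stable_diffusion.extensions import SeamlessExt


def decode_seamless(vae_info, latents, seamless_axes: list[str]):
    """Decode latents with seamless wrap-around on the given axes ("x", "y")."""
    with SeamlessExt.static_patch_model(vae_info.model, seamless_axes), vae_info as vae:
        assert isinstance(vae, AutoencoderKL)
        return vae.decode(latents, return_dict=False)[0]
```
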
invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py (125 additions, 133 deletions)

@@ -1,10 +1,8 @@
-import copy
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Optional, Union

 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from pydantic import field_validator

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@@ -19,38 +17,28 @@
     LatentsField,
     UIType,
 )
+from invokeai.app.invocations.ip_adapter import IPAdapterField
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
-from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
-    MultiDiffusionPipeline,
-    MultiDiffusionRegionConditioning,
-)
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+from invokeai.backend.stable_diffusion.diffusers_pipeline import StableDiffusionBackend
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.extensions import (
+    FreeUExt,
+    LoRAPatcherExt,
+    PipelineIntermediateState,
+    PreviewExt,
+    RescaleCFGExt,
+    SeamlessExt,
+    TiledDenoiseExt,
+)
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
-from invokeai.backend.tiles.tiles import (
-    calc_tiles_min_overlap,
-)
-from invokeai.backend.tiles.utils import TBLR
 from invokeai.backend.util.devices import TorchDevice


-def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
-    """Crop a ControlNetData object to a region."""
-    # Create a shallow copy of the control_data object.
-    control_data_copy = copy.copy(control_data)
-    # The ControlNet reference image is the only attribute that needs to be cropped.
-    control_data_copy.image_tensor = control_data.image_tensor[
-        :,
-        :,
-        latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
-        latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
-    ]
-    return control_data_copy
-
-
 @invocation(
     "tiled_multi_diffusion_denoise_latents",
     title="Tiled Multi-Diffusion Denoise Latents",

@@ -126,6 +114,18 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         default=None,
         input=Input.Connection,
     )
+    t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
+        description=FieldDescriptions.t2i_adapter,
+        title="T2I-Adapter",
+        default=None,
+        input=Input.Connection,
+    )
+    ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
+        description=FieldDescriptions.ip_adapter,
+        title="IP-Adapter",
+        default=None,
+        input=Input.Connection,
+    )

     @field_validator("cfg_scale")
     def ge_one(cls, v: list[float] | float) -> list[float] | float:

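Both new inputs use the `Optional[Union[X, list[X]]]` shape, so a workflow can connect either a single adapter or a list of them. A self-contained illustration of that field pattern with a stand-in model (the real `T2IAdapterField`/`IPAdapterField` are not reproduced here):

```python
# Stand-in demo of the Optional[Union[X, list[X]]] input pattern used above.
# DummyAdapterField replaces the real adapter fields; validation is the point.
from typing import Optional, Union

from pydantic import BaseModel


class DummyAdapterField(BaseModel):
    weight: float = 1.0


class DummyInvocation(BaseModel):
    t2i_adapter: Optional[Union[DummyAdapterField, list[DummyAdapterField]]] = None


print(DummyInvocation().t2i_adapter)                                   # None
print(DummyInvocation(t2i_adapter=DummyAdapterField()).t2i_adapter)    # single field
print(DummyInvocation(t2i_adapter=[DummyAdapterField()]).t2i_adapter)  # list of fields
```
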
@@ -139,141 +139,133 @@ def ge_one(cls, v: list[float] | float) -> list[float] | float:
                 raise ValueError("cfg_scale must be greater than 1")
         return v

-    @staticmethod
-    def create_pipeline(
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-    ) -> MultiDiffusionPipeline:
-        # TODO(ryand): Get rid of this FakeVae hack.
-        class FakeVae:
-            class FakeVaeConfig:
-                def __init__(self) -> None:
-                    self.block_out_channels = [0]
-
-            def __init__(self) -> None:
-                self.config = FakeVae.FakeVaeConfig()
-
-        return MultiDiffusionPipeline(
-            vae=FakeVae(),
-            text_encoder=None,
-            tokenizer=None,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        )
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        # Convert tile image-space dimensions to latent-space dimensions.
-        latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
-        latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
-        latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
-
-        seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
-        _, _, latent_height, latent_width = latents.shape
-
-        # Calculate the tile locations to cover the latent-space image.
-        tiles = calc_tiles_min_overlap(
-            image_height=latent_height,
-            image_width=latent_width,
-            tile_height=latent_tile_height,
-            tile_width=latent_tile_width,
-            min_overlap=latent_tile_overlap,
-        )
-
-        # Get the unet's config so that we can pass the base to sd_step_callback().
-        unet_config = context.models.get_config(self.unet.unet.key)
-
-        def step_callback(state: PipelineIntermediateState) -> None:
-            context.util.sd_step_callback(state, unet_config.base)
-
-        # Prepare an iterator that yields the UNet's LoRA models and their weights.
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-            for lora in self.unet.loras:
-                lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
-                yield (lora_info.model, lora.weight)
-                del lora_info
-
-        # Load the UNet model.
-        unet_info = context.models.load(self.unet.unet)
-
-        with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
-            assert isinstance(unet, UNet2DConditionModel)
-            latents = latents.to(device=unet.device, dtype=unet.dtype)
+        with ExitStack() as exit_stack:
+            ext_manager = ExtensionsManager()
+
+            device = TorchDevice.choose_torch_device()
+            dtype = TorchDevice.choose_torch_dtype()
+
+            seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
+            latents = latents.to(device=device, dtype=dtype)
             if noise is not None:
-                noise = noise.to(device=unet.device, dtype=unet.dtype)
-            scheduler = get_scheduler(
-                context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
-                seed=seed,
-            )
-            pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
+                noise = noise.to(device=device, dtype=dtype)
+
+            _, _, latent_height, latent_width = latents.shape

             # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
             conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                unet=unet,
-                latent_height=latent_tile_height,
-                latent_width=latent_tile_width,
                 cfg_scale=self.cfg_scale,
                 steps=self.steps,
-                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                device=device,
+                dtype=dtype,
             )

-            controlnet_data = DenoiseLatentsInvocation.prep_control_data(
+            scheduler = get_scheduler(
                 context=context,
-                control_input=self.control,
-                latents_shape=list(latents.shape),
-                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                do_classifier_free_guidance=True,
-                exit_stack=exit_stack,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
             )

-            # Split the controlnet_data into tiles.
-            # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
-            controlnet_data_tiles: list[list[ControlNetData]] = []
-            for tile in tiles:
-                tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
-                controlnet_data_tiles.append(tile_controlnet_data)
-
-            # Prepare the MultiDiffusionRegionConditioning list.
-            multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
-            for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
-                multi_diffusion_conditioning.append(
-                    MultiDiffusionRegionConditioning(
-                        region=tile,
-                        text_conditioning_data=conditioning_data,
-                        control_data=tile_controlnet_data,
-                    )
-                )
-
             timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
                 scheduler,
-                device=unet.device,
+                seed=seed,
+                device=device,
                 steps=self.steps,
                 denoising_start=self.denoising_start,
                 denoising_end=self.denoising_end,
-                seed=seed,
             )

-            # Run Multi-Diffusion denoising.
-            result_latents = pipeline.multi_diffusion_denoise(
-                multi_diffusion_conditioning=multi_diffusion_conditioning,
-                target_overlap=latent_tile_overlap,
+            denoise_ctx = DenoiseContext(
                 latents=latents,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                noise=noise,
                 timesteps=timesteps,
                 init_timestep=init_timestep,
-                callback=step_callback,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                unet=None,
+                scheduler=scheduler,
             )

+            # get the unet's config so that we can pass the base to sd_step_callback()
+            unet_config = context.models.get_config(self.unet.unet.key)
+
+            ### inpaint
+            # mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+            # if mask is not None or unet_config.variant == "inpaint":  # ModelVariantType.Inpaint: # is_inpainting_model(unet):
+            #     ext_manager.add_extension(InpaintExt(mask, masked_latents, is_gradient_mask, priority=200))
+
+            ### preview
+            def step_callback(state: PipelineIntermediateState) -> None:
+                context.util.sd_step_callback(state, unet_config.base)
+
+            ext_manager.add_extension(PreviewExt(step_callback, priority=99999))
+
+            ### cfg rescale
+            if self.cfg_rescale_multiplier > 0:
+                ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier, priority=100))
+
+            ### seamless
+            if self.unet.seamless_axes:
+                ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes, priority=100))
+
+            ### freeu
+            if self.unet.freeu_config:
+                ext_manager.add_extension(FreeUExt(self.unet.freeu_config, priority=100))
+
+            ### lora
+            if self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAPatcherExt(
+                        node_context=context,
+                        loras=self.unet.loras,
+                        prefix="lora_unet_",
+                        priority=100,
+                    )
+                )
+
+            ### tiled denoise
+            ext_manager.add_extension(
+                TiledDenoiseExt(
+                    tile_width=self.tile_width,
+                    tile_height=self.tile_height,
+                    tile_overlap=self.tile_overlap,
+                    priority=100,
+                )
+            )
+
+            # later will be like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context)
+            #     ext_manager.add_extension(ext)
+            DenoiseLatentsInvocation.parse_t2i_field(exit_stack, context, self.t2i_adapter, ext_manager)
+            DenoiseLatentsInvocation.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            # TODO: works fine with tiled too?
+            DenoiseLatentsInvocation.parse_ip_adapter_field(exit_stack, context, self.ip_adapter, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (model_state_dict, unet),
+                # ext: controlnet
+                ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(model_state_dict, unet),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

         result_latents = result_latents.to("cpu")
         # TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
         TorchDevice.empty_cache()

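One detail worth noting in the hunk above: `DenoiseContext` is constructed with `unet=None`, and the model is attached only inside the patched `with` block, after `pre_unet_load` has run, so extensions can influence loading before the UNet exists. A toy of that two-phase pattern (the dataclass shape is an assumption; only the field names come from the diff):

```python
# Toy of the two-phase DenoiseContext pattern above: the context is built
# with unet=None, and the model is attached only after loading and patching.
# The dataclass shape is assumed; only the field names come from the diff.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyDenoiseContext:
    latents: Any
    unet: Optional[Any] = None  # attached later, inside the patched with-block


ctx = ToyDenoiseContext(latents="latents-tensor")
# ... load the UNet, apply extension patches ...
ctx.unet = "patched-unet"  # mirrors `denoise_ctx.unet = unet` in the diff
print(ctx)
```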