Skip to content

Commit

Permalink
feature: sdxl inpaint support
Browse files Browse the repository at this point in the history
  • Loading branch information
brycedrennan committed Jan 13, 2024
1 parent 700cb45 commit af0876f
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 15 deletions.
5 changes: 4 additions & 1 deletion imaginairy/api/generate_refiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ def latent_logger(latents):
comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
init_latent = sd.lda.encode(comp_image_t)
compose_control_inputs: list[ControlInput]
if prompt.model_weights.architecture.primary_alias == "sdxl":
if prompt.model_weights.architecture.primary_alias in (
"sdxl",
"sdxlinpaint",
):
compose_control_inputs = []
else:
compose_control_inputs = [
Expand Down
18 changes: 17 additions & 1 deletion imaginairy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def primary_alias(self):
output_modality="image",
defaults={"size": "1024"},
),
ModelArchitecture(
name="Stable Diffusion XL",
aliases=["sdxlinpaint", "sd-xlinpaint", "sdxl-inpaint"],
output_modality="image",
defaults={"size": "1024"},
),
ModelArchitecture(
name="Stable Video Diffusion",
aliases=["svd", "stablevideo"],
Expand Down Expand Up @@ -162,7 +168,7 @@ def __post_init__(self):
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
),
ModelWeightsConfig(
name="Modern Disney",
name="Redshift Diffusion",
aliases=["redshift-diffusion", "red", "redshift-diffusion-15", "red15"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/nitrosocke/redshift-diffusion/tree/80837fe18df05807861ab91c3bad3693c9342e4c/",
Expand All @@ -179,6 +185,16 @@ def __post_init__(self):
},
weights_location="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/462165984030d82259a11f4367a4eed129e94a7b/",
),
ModelWeightsConfig(
name="Stable Diffusion XL - Inpainting",
aliases=MODEL_ARCHITECTURE_LOOKUP["sdxl-inpaint"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["sdxl-inpaint"],
defaults={
"negative_prompt": DEFAULT_NEGATIVE_PROMPT,
"composition_strength": 0.6,
},
weights_location="https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1/tree/115134f363124c53c7d878647567d04daf26e41e/",
),
ModelWeightsConfig(
name="OpenDalle V1.1",
aliases=["opendalle11", "odv11", "opendalle11", "opendalle", "od"],
Expand Down
128 changes: 128 additions & 0 deletions imaginairy/modules/refiners_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from functools import lru_cache
from typing import Any, List, Literal

import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F

Expand All @@ -16,6 +18,7 @@
ScaledDotProductAttention,
)
from imaginairy.vendored.refiners.fluxion.layers.chain import ChainError
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
Expand Down Expand Up @@ -395,6 +398,131 @@ def prompts_to_embeddings(self, prompts: List[WeightedPrompt]) -> Tensor:
return conditioning


class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
def __init__(
self,
unet: SDXLUNet | None = None,
lda: SDXLAutoencoder | None = None,
clip_text_encoder: DoubleTextEncoder | None = None,
scheduler: Scheduler | None = None,
device: Device | str | None = "cpu",
dtype: DType | None = None,
) -> None:
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)

def forward(
self,
x: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor | None = None,
condition_scale: float = 5.0,
**_: Tensor,
) -> Tensor:
assert self.mask_latents is not None
assert self.target_image_latents is not None
x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
return super().forward(
x=x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=condition_scale,
)

def set_inpainting_conditions(
self,
target_image: Image.Image,
mask: Image.Image,
latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
target_image = target_image.convert(mode="RGB")
mask = mask.convert(mode="L")

mask_tensor = torch.tensor(
data=np.array(object=mask).astype(dtype=np.float32) / 255.0
).to(device=self.device)
mask_tensor = (
(mask_tensor > 0.5)
.unsqueeze(dim=0)
.unsqueeze(dim=0)
.to(dtype=self.unet.dtype)
)

self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))

init_image_tensor = (
image_to_tensor(
image=target_image, device=self.device, dtype=self.unet.dtype
)
* 2
- 1
)
masked_init_image = init_image_tensor * (1 - mask_tensor)
self.target_image_latents = self.lda.encode(
x=masked_init_image.to(dtype=self.lda.dtype)
)
assert self.target_image_latents is not None
self.target_image_latents = self.target_image_latents.to(dtype=self.unet.dtype)

return self.mask_latents, self.target_image_latents # type: ignore

def compute_self_attention_guidance(
self,
x: Tensor,
noise: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
**kwargs: Tensor,
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
assert self.mask_latents is not None
assert self.target_image_latents is not None

degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)

negative_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)
self.set_unet_context(
timestep=timestep,
clip_text_embedding=negative_embedding,
pooled_text_embedding=negative_pooled_embedding,
time_ids=time_ids,
**kwargs,
)
x = torch.cat(
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
degraded_noise = self.unet(x)

return sag.scale * (noise - degraded_noise)


class SlicedEncoderMixin(nn.Module):
max_chunk_size = 2048
min_chunk_size = 32
Expand Down
51 changes: 39 additions & 12 deletions imaginairy/utils/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from imaginairy import config as iconfig
from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
from imaginairy.modules import attention
from imaginairy.modules.refiners_sd import SDXLAutoencoderSliced, StableDiffusion_XL
from imaginairy.modules.refiners_sd import (
SDXLAutoencoderSliced,
StableDiffusion_XL,
StableDiffusion_XL_Inpainting,
)
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
Expand Down Expand Up @@ -268,8 +272,10 @@ def _get_diffusion_model_refiners(
device=device,
dtype=dtype,
)
elif architecture.primary_alias == "sdxl":
sd = load_sdxl_pipeline(base_url=weights_location, device=device)
elif architecture.primary_alias in ("sdxl", "sdxlinpaint"):
sd = load_sdxl_pipeline(
base_url=weights_location, device=device, for_inpainting=for_inpainting
)
else:
msg = f"Invalid architecture {architecture.primary_alias}"
raise ValueError(msg)
Expand Down Expand Up @@ -734,7 +740,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):


def load_sdxl_pipeline_from_diffusers_weights(
base_url: str, device=None, dtype=torch.float16
base_url: str, for_inpainting=False, device=None, dtype=torch.float16
):
from imaginairy.utils import get_device

Expand Down Expand Up @@ -764,7 +770,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
source_path=unet_weights_path,
device="cpu",
)
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
if for_inpainting:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
else:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
unet.load_state_dict(unet_weights, assign=True)
del unet_weights

Expand All @@ -789,15 +798,20 @@ def load_sdxl_pipeline_from_diffusers_weights(
lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device)
text_encoder = text_encoder.to(device=device)
sd = StableDiffusion_XL(
if for_inpainting:
StableDiffusionCls = StableDiffusion_XL_Inpainting
else:
StableDiffusionCls = StableDiffusion_XL

sd = StableDiffusionCls(
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
)

return sd


def load_sdxl_pipeline_from_compvis_weights(
base_url: str, device=None, dtype=torch.float16
base_url: str, for_inpainting=False, device=None, dtype=torch.float16
):
from imaginairy.utils import get_device

Expand All @@ -809,7 +823,10 @@ def load_sdxl_pipeline_from_compvis_weights(
lda.load_state_dict(vae_weights, assign=True)
del vae_weights

unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
if for_inpainting:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
else:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
unet.load_state_dict(unet_weights, assign=True)
del unet_weights

Expand All @@ -819,21 +836,31 @@ def load_sdxl_pipeline_from_compvis_weights(
lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device)
text_encoder = text_encoder.to(device=device)
sd = StableDiffusion_XL(

if for_inpainting:
StableDiffusionCls = StableDiffusion_XL_Inpainting
else:
StableDiffusionCls = StableDiffusion_XL
sd = StableDiffusionCls(
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
)

return sd


def load_sdxl_pipeline(base_url, device=None):
def load_sdxl_pipeline(base_url, device=None, for_inpainting=False):
logger.info(f"Loading SDXL weights from {base_url}")
device = device or get_device()

with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
if is_diffusers_repo_url(base_url):
sd = load_sdxl_pipeline_from_diffusers_weights(base_url, device=device)
sd = load_sdxl_pipeline_from_diffusers_weights(
base_url, for_inpainting=for_inpainting, device=device
)
else:
sd = load_sdxl_pipeline_from_compvis_weights(base_url, device=device)
sd = load_sdxl_pipeline_from_compvis_weights(
base_url, for_inpainting=for_inpainting, device=device
)
return sd


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"text_model.embeddings.token_embedding": "Sum.TokenEncoder",
"text_model.embeddings.position_embedding": "Sum.PositionalEncoder.Embedding",
"text_model.final_layer_norm": "LayerNorm",
"text_projection": "Linear"
"text_projection": "Linear",
"text_model.embeddings.position_ids": null
},
"regex_map": {
"text_model\\.encoder\\.layers\\.(?P<layer>\\d+)\\.layer_norm(?P<norm>\\d+)": "TransformerLayer_{int(layer) + 1}.Residual_{norm}.LayerNorm",
Expand Down

0 comments on commit af0876f

Please sign in to comment.