Commit
Support base SDXL and SDXL refiner models.
Large refactor of the model detection and loading code.
comfyanonymous committed Jun 22, 2023
1 parent 9fccf4a commit f87ec10
Showing 16 changed files with 756 additions and 291 deletions.
42 changes: 39 additions & 3 deletions comfy/cldm/cldm.py
@@ -34,8 +34,10 @@ def __init__(
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@@ -51,6 +53,8 @@ def __init__(
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
):
super().__init__()
if use_spatial_transformer:
@@ -75,6 +79,10 @@ def __init__(
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@@ -97,8 +105,10 @@ def __init__(
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
@@ -111,6 +121,24 @@ def __init__(
linear(time_embed_dim, time_embed_dim),
)

if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()

self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
@@ -179,7 +207,7 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
@@ -238,7 +266,7 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
@@ -257,14 +285,22 @@ def __init__(
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

def forward(self, x, hint, timesteps, context, **kwargs):
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

guided_hint = self.input_hint_block(hint, emb, context)

outs = []

hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
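
The ControlNet in cldm.py now accepts the same class/ADM conditioning options as the UNet (num_classes="sequential" with adm_in_channels) and its forward() takes the extra y vector. A hedged sketch of the call, with illustrative tensor sizes only (2816 assumes SDXL-base-style conditioning: a 1280-dim pooled CLIP output plus six 256-dim size embeddings; control_model stands in for an instantiated ControlNet):

import torch

# Illustrative shapes, not taken from the commit.
x = torch.randn(2, 4, 128, 128)        # latent batch
hint = torch.randn(2, 3, 1024, 1024)   # control image (e.g. a canny or depth map)
timesteps = torch.tensor([10, 10])
context = torch.randn(2, 77, 2048)     # e.g. concatenated CLIP-L + CLIP-bigG token embeddings
y = torch.randn(2, 2816)               # pooled output + embedded size parameters

# outs = control_model(x, hint, timesteps, context, y=y)   # list of per-level residual tensors
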
23 changes: 23 additions & 0 deletions comfy/clip_config_bigg.json
@@ -0,0 +1,23 @@
{
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 512,
"torch_dtype": "float32",
"vocab_size": 49408
}
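
clip_config_bigg.json describes the larger (ViT-bigG) CLIP text encoder that SDXL uses alongside the usual CLIP-L one: hidden size 1280, 32 layers, 20 attention heads. A minimal sketch, assuming the transformers library is available, of how a JSON config like this can be instantiated (the path is relative to the repository root):

from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig.from_json_file("comfy/clip_config_bigg.json")
model = CLIPTextModel(config)  # randomly initialized; checkpoint weights are loaded separately
print(config.hidden_size, config.num_hidden_layers)  # 1280 32
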
28 changes: 14 additions & 14 deletions comfy/clip_vision.py
@@ -29,31 +29,31 @@ def encode_image(self, image):
outputs = self.model(**inputs)
return outputs

def convert_to_transformers(sd):
def convert_to_transformers(sd, prefix):
sd_k = sd.keys()
if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
keys_to_replace = {
"embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
"embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
"embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
"embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
"embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
"embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
"embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
"{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
"{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
"{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
"{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
"{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
"{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
"{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
}

for x in keys_to_replace:
if x in sd_k:
sd[keys_to_replace[x]] = sd.pop(x)

if "embedder.model.visual.proj" in sd_k:
sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)
if "{}proj".format(prefix) in sd_k:
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)

sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
sd = transformers_convert(sd, prefix, "vision_model.", 32)
return sd

def load_clipvision_from_sd(sd):
sd = convert_to_transformers(sd)
def load_clipvision_from_sd(sd, prefix):
sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:
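
convert_to_transformers() and load_clipvision_from_sd() now take the state-dict key prefix as an argument instead of hard-coding the unCLIP layout. A hedged usage sketch (sd is an already-loaded checkpoint state dict):

# Passing the prefix that used to be hard-coded reproduces the old unCLIP behaviour;
# other callers can pass whatever prefix their checkpoint stores the vision tower under.
clip_vision = load_clipvision_from_sd(sd, prefix="embedder.model.visual.")
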
4 changes: 2 additions & 2 deletions comfy/ldm/modules/attention.py
@@ -600,7 +600,7 @@ def __init__(self, in_channels, n_heads, d_head,
use_checkpoint=True, dtype=None):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype)
@@ -630,7 +630,7 @@ def __init__(self, in_channels, n_heads, d_head,
def forward(self, x, context=None, transformer_options={}):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
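
The SpatialTransformer change broadcasts a single context_dim (and, at run time, a single context tensor) to one entry per transformer block, which matters now that per-level depths can exceed 1. A small illustrative restatement of that broadcasting, with example values (2048 and depth 10 are typical of an SDXL-style deepest level, not taken from this diff):

depth = 10                 # number of BasicTransformerBlocks in one SpatialTransformer
context_dim = 2048         # caller supplies a single int
if not isinstance(context_dim, list):
    context_dim = [context_dim] * depth   # one context dim per block
assert len(context_dim) == depth
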
11 changes: 8 additions & 3 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -502,6 +502,7 @@ def __init__(
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
):
super().__init__()
if use_spatial_transformer:
@@ -526,6 +527,10 @@ def __init__(
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@@ -631,7 +636,7 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
)
@@ -690,7 +695,7 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
),
@@ -746,7 +751,7 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
)
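
Both the UNet and the ControlNet now normalize transformer_depth the same way: an int is broadcast to one depth per resolution level, and the middle block defaults to the deepest level's depth. A standalone restatement of that logic (the helper name is ours, not the repository's):

def normalize_transformer_depth(transformer_depth, channel_mult, transformer_depth_middle=None):
    # Broadcast a single int to one entry per resolution level.
    if isinstance(transformer_depth, int):
        transformer_depth = len(channel_mult) * [transformer_depth]
    # The middle block falls back to the deepest level's depth.
    if transformer_depth_middle is None:
        transformer_depth_middle = transformer_depth[-1]
    return transformer_depth, transformer_depth_middle

# Legacy SD1.x-style config: depth 1 at every level.
print(normalize_transformer_depth(1, (1, 2, 4, 4)))   # ([1, 1, 1, 1], 1)
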
78 changes: 75 additions & 3 deletions comfy/model_base.py
@@ -2,6 +2,7 @@
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np

class BaseModel(torch.nn.Module):
@@ -15,9 +16,9 @@ def __init__(self, unet_config, v_prediction=False):
self.parameterization = "v"
else:
self.parameterization = "eps"
if "adm_in_channels" in unet_config:
self.adm_channels = unet_config["adm_in_channels"]
else:

self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("adm", self.adm_channels)
@@ -55,6 +56,25 @@ def get_dtype(self):
def is_adm(self):
return self.adm_channels > 0

def encode_adm(self, **kwargs):
return None

def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)

if len(u) > 0:
print("unet unexpected:", u)
del to_load
return self

class SD21UNCLIP(BaseModel):
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
super().__init__(unet_config, v_prediction)
@@ -95,3 +115,55 @@ class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.concat_keys = ("mask", "masked_image")

class SDXLRefiner(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.embedder = Timestep(256)

def encode_adm(self, **kwargs):
clip_pooled = kwargs["pooled_output"]
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)

if kwargs.get("prompt_type", "") == "negative":
aesthetic_score = kwargs.get("aesthetic_score", 2.5)
else:
aesthetic_score = kwargs.get("aesthetic_score", 6)

print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
out = []
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([aesthetic_score])))
flat = torch.flatten(torch.cat(out))[None, ]
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.embedder = Timestep(256)

def encode_adm(self, **kwargs):
clip_pooled = kwargs["pooled_output"]
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)

print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
out = []
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([target_width])))
out.append(self.embedder(torch.Tensor([target_height])))
flat = torch.flatten(torch.cat(out))[None, ]
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
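
A hedged sketch of the ADM vector that SDXL.encode_adm() assembles: each of the six scalars (width, height, crop_w, crop_h, target_width, target_height) is embedded to 256 dims by Timestep(256), flattened, and concatenated after the pooled text-encoder output. Assuming a 1280-dim pooled output (as produced by the bigG encoder), that gives 1280 + 6 × 256 = 2816 channels, while the refiner's five scalars (aesthetic_score instead of the target size) give 1280 + 5 × 256 = 2560. Variable names below are illustrative:

import torch

pooled = torch.randn(1, 1280)   # pooled_output from the conditioning, assumed 1280-dim
adm = sdxl_model.encode_adm(pooled_output=pooled,
                            width=1024, height=1024,
                            crop_w=0, crop_h=0,
                            target_width=1024, target_height=1024)
print(adm.shape)   # expected: torch.Size([1, 2816]), matching the UNet's adm_in_channels
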
(Diffs for the remaining 10 changed files are not rendered here.)
