Add DualCLIPLoader to load CLIP models for SDXL.
Update CLIPLoader to load CLIP models for the SDXL refiner.
comfyanonymous committed Jun 25, 2023
1 parent b793396 commit 20f579d
Showing 4 changed files with 67 additions and 11 deletions.
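The heart of the change: load_clip now takes a list of checkpoint paths instead of a single path, so SDXL's two text encoders (CLIP ViT-L plus OpenCLIP ViT-bigG) can be loaded together. A minimal sketch of the new call path, with placeholder file names that are not part of the commit:

import comfy.sd
import folder_paths

# "clip_l.safetensors" and "clip_g.safetensors" are hypothetical file names.
clip = comfy.sd.load_clip(ckpt_paths=["clip_l.safetensors", "clip_g.safetensors"],
                          embedding_directory=folder_paths.get_folder_paths("embeddings"))
cond = clip.encode("a photo of a cat")  # tokenizes, then encodes with both models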
41 changes: 32 additions & 9 deletions comfy/sd.py
@@ -19,6 +19,7 @@
 
 from . import sd1_clip
 from . import sd2_clip
+from . import sdxl_clip
 
 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -524,7 +525,7 @@ def clone(self):
         return n
 
     def load_from_state_dict(self, sd):
-        self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
+        self.cond_stage_model.load_sd(sd)
 
     def add_patches(self, patches, strength=1.0):
         return self.patcher.add_patches(patches, strength)
@@ -555,6 +556,8 @@ def encode(self, text):
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)
 
+    def load_sd(self, sd):
+        return self.cond_stage_model.load_sd(sd)
 
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
@@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
     return StyleModel(model)
 
 
-def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
+def load_clip(ckpt_paths, embedding_directory=None):
+    clip_data = []
+    for p in ckpt_paths:
+        clip_data.append(utils.load_torch_file(p, safe_load=True))
 
     class EmptyClass:
         pass
 
+    for i in range(len(clip_data)):
+        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
+            clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+
     clip_target = EmptyClass()
     clip_target.params = {}
-    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-        clip_target.clip = sd2_clip.SD2ClipModel
-        clip_target.tokenizer = sd2_clip.SD2Tokenizer
+    if len(clip_data) == 1:
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sd2_clip.SD2ClipModel
+            clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        else:
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
     else:
-        clip_target.clip = sd1_clip.SD1ClipModel
-        clip_target.tokenizer = sd1_clip.SD1Tokenizer
+        clip_target.clip = sdxl_clip.SDXLClipModel
+        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
 
     clip = CLIP(clip_target, embedding_directory=embedding_directory)
-    clip.load_from_state_dict(clip_data)
+    for c in clip_data:
+        m, u = clip.load_sd(c)
+        if len(m) > 0:
+            print("clip missing:", m)
+
+        if len(u) > 0:
+            print("clip unexpected:", u)
     return clip
 
 def load_gligen(ckpt_path):
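The branching above works because the three text encoders differ in depth: OpenCLIP ViT-bigG (the SDXL refiner's encoder) has 32 transformer layers, so a layers.30 weight exists; OpenCLIP ViT-H (SD2) has 24 layers, so layers.22 exists but layers.30 does not; CLIP ViT-L (SD1) has only 12 layers, so neither key is present. A hypothetical helper restating the same checks:

def detect_clip_arch(sd):
    # Not part of the commit; mirrors the key checks in load_clip above.
    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
        return "ViT-bigG (SDXL refiner)"  # 32-layer text encoder
    if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
        return "ViT-H (SD2)"  # 24-layer text encoder
    return "ViT-L (SD1)"  # 12-layer text encoder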
3 changes: 3 additions & 0 deletions comfy/sd1_clip.py
@@ -128,6 +128,9 @@ def forward(self, tokens):
     def encode(self, tokens):
         return self(tokens)
 
+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
 def parse_parentheses(string):
     result = []
     current_item = ""
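Note that torch's load_state_dict(strict=False) returns a (missing_keys, unexpected_keys) pair, so this one-line load_sd is what ultimately feeds the "clip missing:" and "clip unexpected:" prints in load_clip above.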
13 changes: 13 additions & 0 deletions comfy/sdxl_clip.py
@@ -31,6 +31,11 @@ def clip_layer(self, layer_idx):
         self.layer = "hidden"
         self.layer_idx = layer_idx
 
+    def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        return super().load_sd(sd)
+
 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
@@ -68,6 +73,12 @@ def encode_token_weights(self, token_weight_pairs):
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return torch.cat([l_out, g_out], dim=-1), g_pooled
 
+    def load_sd(self, sd):
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+            return self.clip_g.load_sd(sd)
+        else:
+            return self.clip_l.load_sd(sd)
+
 class SDXLRefinerClipModel(torch.nn.Module):
     def __init__(self, device="cpu"):
         super().__init__()
@@ -81,3 +92,5 @@ def encode_token_weights(self, token_weight_pairs):
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         return g_out, g_pooled
 
+    def load_sd(self, sd):
+        return self.clip_g.load_sd(sd)
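SDXLClipModel.load_sd applies the same depth check per state dict, so the two checkpoints can arrive in either order, while the G encoder's load_sd pops text_projection before delegating so the in-place copy is not reported as an unexpected key. A rough sketch of the routing, assuming clip_g_sd and clip_l_sd are state dicts already loaded via utils.load_torch_file:

model = SDXLClipModel()
model.load_sd(clip_g_sd)  # contains the layers.30 key, so it goes to self.clip_g
model.load_sd(clip_l_sd)  # lacks that key, so it goes to self.clip_l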
21 changes: 19 additions & 2 deletions nodes.py
@@ -520,11 +520,27 @@ def INPUT_TYPES(s):
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
 
-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"
 
     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)
 
+class DualCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
+                             }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_clip(self, clip_name1, clip_name2):
+        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
+        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
 class CLIPVisionLoader:
@@ -1315,6 +1331,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,
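Outside the graph UI, the new node reduces to a single load_clip call; a hypothetical direct invocation, with placeholder file names that must exist in the configured "clip" models folder:

loader = DualCLIPLoader()
(clip,) = loader.load_clip("clip_l.safetensors", "clip_g.safetensors")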
