Commit 606a537
Support SDXL embedding format with 2 CLIP.
comfyanonymous committed Jul 10, 2023
1 parent 6ad0a6d commit 606a537
Showing 3 changed files with 9 additions and 6 deletions.

comfy/latent_formats.py (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@ def __init__(self, scale_factor=0.18215):
 class SDXL(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.13025
-        self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
+        self.latent_rgb_factors = [
             #   R        G        B
             [ 0.3920,  0.4054,  0.4549],
             [-0.2634, -0.0196,  0.0653],
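Note: latent_rgb_factors is a 4x3 matrix mapping SDXL's four latent channels to approximate RGB for fast previews; only its first two rows are visible in this hunk. A minimal sketch of how such a factor matrix is typically applied, with placeholder zeros standing in for the truncated rows (do not treat them as the real values):

import torch

# Hypothetical 4x3 factor matrix: the first two rows come from the hunk
# above, the last two are placeholders for the rows truncated in this diff.
latent_rgb_factors = torch.tensor([
    #    R        G        B
    [ 0.3920,  0.4054,  0.4549],
    [-0.2634, -0.0196,  0.0653],
    [ 0.0000,  0.0000,  0.0000],  # placeholder
    [ 0.0000,  0.0000,  0.0000],  # placeholder
])

def latent_to_rgb_preview(latent):
    # latent: [batch, 4, height, width] -> preview: [batch, height, width, 3]
    # Each output pixel is a linear combination of the four latent channels.
    return torch.einsum("bchw,cr->bhwr", latent, latent_rgb_factors)

preview = latent_to_rgb_preview(torch.randn(1, 4, 16, 16))
print(preview.shape)  # torch.Size([1, 16, 16, 3])

A linear map like this is far cheaper than a full VAE decode, which is presumably why it is used for per-step previews.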

comfy/sd1_clip.py (11 changes: 7 additions & 4 deletions)

@@ -233,7 +233,7 @@ def expand_directory_list(directories):
             dirs.add(root)
     return list(dirs)

-def load_embed(embedding_name, embedding_directory, embedding_size):
+def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]

@@ -292,13 +292,15 @@ def load_embed(embedding_name, embedding_directory, embedding_size):
                     continue
                 out_list.append(t.reshape(-1, t.shape[-1]))
         embed_out = torch.cat(out_list, dim=0)
+    elif embed_key is not None and embed_key in embed:
+        embed_out = embed[embed_key]
     else:
         values = embed.values()
         embed_out = next(iter(values))
     return embed_out

 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
@@ -315,17 +317,18 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
         self.embedding_size = embedding_size
+        self.embedding_key = embedding_key

     def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
+        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
         if embed is None:
             stripped = embedding_name.strip(',')
             if len(stripped) < len(embedding_name):
-                embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
+                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
             return (embed, embedding_name[len(stripped):])
         return (embed, "")

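To see what the new embed_key branch buys: an SDXL-format embedding file stores one tensor per text encoder under the keys 'clip_l' and 'clip_g'. Without the key, load_embed would fall through to next(iter(values)) and hand both tokenizers the same, possibly wrong-width, tensor. A minimal sketch of the dispatch in isolation (file loading and the other branches omitted; shapes are illustrative):

import torch

# A hypothetical two-CLIP embedding as it would sit in the loaded file:
# one 768-wide tensor for CLIP-L and one 1280-wide tensor for CLIP-G.
embed = {
    "clip_l": torch.randn(2, 768),
    "clip_g": torch.randn(2, 1280),
}

def pick_embedding(embed, embed_key=None):
    # Mirrors the new elif branch: prefer the requested key when present,
    # otherwise fall back to the first value (the old behavior).
    if embed_key is not None and embed_key in embed:
        return embed[embed_key]
    return next(iter(embed.values()))

print(pick_embedding(embed, "clip_l").shape)  # torch.Size([2, 768])
print(pick_embedding(embed, "clip_g").shape)  # torch.Size([2, 1280])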

comfy/sdxl_clip.py (2 changes: 1 addition & 1 deletion)

@@ -41,7 +41,7 @@ def load_sd(self, sd):

 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')


 class SDXLTokenizer(sd1_clip.SD1Tokenizer):
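With both tokenizers keyed (SD1Tokenizer defaults to embedding_size=768 and 'clip_l'; SDXLClipGTokenizer passes 1280 and 'clip_g'), one embedding file can serve both SDXL text encoders. A hedged sketch of producing a file in this two-key format; the filename and vector count are illustrative:

import torch

# Hypothetical trained vectors, one row per token, for each SDXL text encoder.
num_vectors = 4
embedding = {
    "clip_l": torch.randn(num_vectors, 768),   # matches SD1Tokenizer's embedding_size=768
    "clip_g": torch.randn(num_vectors, 1280),  # matches SDXLClipGTokenizer's embedding_size=1280
}

# Saved like this, load_embed(..., embed_key='clip_l') and
# load_embed(..., embed_key='clip_g') each resolve to the matching tensor.
torch.save(embedding, "example_sdxl_embedding.pt")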
