diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 96dd7e9484d..8b59cfbdc14 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -21,7 +21,7 @@ def __init__(self, scale_factor=0.18215):
 class SDXL(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.13025
-        self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
+        self.latent_rgb_factors = [
             #   R        G        B
             [ 0.3920,  0.4054,  0.4549],
             [-0.2634, -0.0196,  0.0653],
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 11f26bb8f02..fbd87c5696c 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -233,7 +233,7 @@ def expand_directory_list(directories):
             dirs.add(root)
     return list(dirs)

-def load_embed(embedding_name, embedding_directory, embedding_size):
+def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]

@@ -292,13 +292,15 @@ def load_embed(embedding_name, embedding_directory, embedding_size):
                 continue
             out_list.append(t.reshape(-1, t.shape[-1]))
         embed_out = torch.cat(out_list, dim=0)
+    elif embed_key is not None and embed_key in embed:
+        embed_out = embed[embed_key]
     else:
         values = embed.values()
         embed_out = next(iter(values))
     return embed_out

 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
@@ -315,17 +317,18 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
         self.embedding_size = embedding_size
+        self.embedding_key = embedding_key

     def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
+        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
         if embed is None:
             stripped = embedding_name.strip(',')
             if len(stripped) < len(embedding_name):
-                embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
+                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
             return (embed, embedding_name[len(stripped):])
         return (embed, "")

diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index d9298b20545..d0803b10bf0 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -41,7 +41,7 @@ def load_sd(self, sd):

 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')

 class SDXLTokenizer(sd1_clip.SD1Tokenizer):
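
The `embed_key` dispatch added to `load_embed` is what lets one textual-inversion file serve both SDXL text encoders: an SDXL-style embedding can be saved as a dict holding one tensor per encoder, and each tokenizer pulls out its own key (`'clip_l'` for the 768-wide CLIP-L path, `'clip_g'` for the 1280-wide CLIP-G path). A minimal standalone sketch of that selection order follows; `pick_embedding_tensor` and the `sdxl_embed` dict are hypothetical names for illustration, not part of the patch, and the `string_to_param` branch that `load_embed` checks first is elided here.

```python
# Hypothetical sketch: mimics the key-selection order that load_embed
# follows once an embedding file has been loaded into a plain dict.
# Not ComfyUI code; names below are invented for illustration.
import torch

def pick_embedding_tensor(embed: dict, embed_key=None):
    # Explicit key (e.g. 'clip_l' or 'clip_g') wins when present,
    # matching the new `elif embed_key is not None and embed_key in embed`
    # branch in the patch.
    if embed_key is not None and embed_key in embed:
        return embed[embed_key]
    # Fallback, unchanged from before the patch: take the first value
    # stored in the dict, whatever its key is.
    return next(iter(embed.values()))

# An SDXL-style embedding file carries one tensor per text encoder.
sdxl_embed = {
    "clip_l": torch.zeros(4, 768),   # matches SD1Tokenizer's embedding_size=768
    "clip_g": torch.zeros(4, 1280),  # matches SDXLClipGTokenizer's embedding_size=1280
}

assert pick_embedding_tensor(sdxl_embed, "clip_l").shape[-1] == 768
assert pick_embedding_tensor(sdxl_embed, "clip_g").shape[-1] == 1280
# A single-tensor SD1.5-style file still resolves via the fallback.
assert pick_embedding_tensor({"emb": torch.zeros(4, 768)}).shape[-1] == 768
```

Because `embed_key` defaults to `None` in `load_embed` and `embedding_key` defaults to `'clip_l'` on `SD1Tokenizer`, existing single-tensor SD1.5 embeddings keep resolving through the old fallback path unchanged.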