
Commit

Support multi-embeddings
deepdelirious committed Nov 18, 2024
1 parent fdadf4d commit 6b3134e
Showing 8 changed files with 52 additions and 64 deletions.
20 changes: 20 additions & 0 deletions embedding_utils.py
@@ -0,0 +1,20 @@
import logging
from safetensors.torch import load_file

logger = logging.getLogger(__name__)

def inject_embedding(model, strategy, tokenizer, placeholder, embed_file, embed_key):
embed_state_dict = load_file(embed_file)
if not embed_key in embed_state_dict:
raise Exception(f"{embed_key} not found in {embed_file}")
embed = embed_state_dict[embed_key]
placeholders = [f"{placeholder.replace(' ', '_')}_{i}" for i in range(0, len(embed))]
tokenizer.add_tokens(placeholders)
indexes = tokenizer.convert_tokens_to_ids(placeholders)
if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
for e, i in zip(embed, indexes):
model.get_input_embeddings().weight.data[i] = e
logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token(s) {indexes}")
strategy.add_replacement(placeholder, " ".join(placeholders))
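For context, the new helper expects a safetensors file with one tensor per text encoder, keyed by the name passed as embed_key ("clip_l", "t5xxl", plus "clip_g" for SD3); since the function iterates over that tensor, its first dimension is the number of embedding vectors. A minimal sketch of producing such a file, assuming a placeholder of my_concept and two vectors per encoder (file name, shapes, and vector count are illustrative, not taken from this commit):

import torch
from safetensors.torch import save_file

# Two vectors per encoder; 768 and 4096 are the CLIP-L and T5-XXL hidden sizes.
save_file(
    {"clip_l": torch.randn(2, 768), "t5xxl": torch.randn(2, 4096)},
    "my_concept.safetensors",
)

Fed to inject_embedding, this would add the tokens my_concept_0 and my_concept_1 to the tokenizer, copy each row into the (resized) input-embedding matrix, and register the replacement "my_concept" -> "my_concept_0 my_concept_1" on the tokenize strategy.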
18 changes: 3 additions & 15 deletions flux_train.py
@@ -51,7 +51,7 @@
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments

from safetensors.torch import load_file
import embedding_utils

def train(args):
train_util.verify_training_args(args)
@@ -218,26 +218,14 @@ def train(args):
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)

def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
embed_state_dict = load_file(embed_file)
if not embed_key in embed_state_dict:
raise Exception(f"{embed_key} not found in {embed_file}")
tokenizer.add_tokens(placeholder)
index = tokenizer.convert_tokens_to_ids(placeholder)
if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")

# load clip_l, t5xxl for caching text encoder outputs
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)

if args.additional_embedding:
for placeholder, embed_file in args.additional_embedding:
inject_embedding(clip_l, flux_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
inject_embedding(t5xxl, flux_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")
embedding_utils.inject_embedding(clip_l, flux_tokenize_strategy, flux_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
embedding_utils.inject_embedding(t5xxl, flux_tokenize_strategy, flux_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")

clip_l.eval()
t5xxl.eval()
19 changes: 5 additions & 14 deletions flux_train_network.py
@@ -16,7 +16,7 @@

setup_logging()
import logging
from safetensors.torch import load_file
import embedding_utils

logger = logging.getLogger(__name__)

@@ -113,25 +113,16 @@ def load_target_model(self, args, weight_dtype, accelerator):
else:
loading_dtype = weight_dtype

def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
embed_state_dict = load_file(embed_file)
if not embed_key in embed_state_dict:
raise Exception(f"{embed_key} not found in {embed_file}")
tokenizer.add_tokens(placeholder)
index = tokenizer.convert_tokens_to_ids(placeholder)
if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")


# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

strategy = self.get_tokenize_strategy(args)
if args.additional_embedding:
for placeholder, embed_file in args.additional_embedding:
inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")
embedding_utils.inject_embedding(clip_l, strategy, strategy.clip_l, placeholder, embed_file, "clip_l")
embedding_utils.inject_embedding(t5xxl, strategy, strategy.t5xxl, placeholder, embed_file, "t5xxl")

clip_l.eval()
t5xxl.eval()
11 changes: 11 additions & 0 deletions library/strategy_base.py
@@ -50,6 +50,17 @@ def set_strategy(cls, strategy):
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy

def __init__(self):
self.replacements = []

def add_replacement(self, original, replacement):
self.replacements.append((original, replacement))

def _process_replacements(self, text):
for original, replacement in self.replacements:
text = [t.replace(original, replacement) for t in text]
return text

def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
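These hooks are what let a single caption placeholder stand in for the several tokens created by inject_embedding: the tokenize strategies below run _process_replacements over the caption list before tokenizing. A small illustration of the intended behaviour, given any TokenizeStrategy instance strategy and the hypothetical placeholder from above:

strategy.add_replacement("my_concept", "my_concept_0 my_concept_1")
print(strategy._process_replacements(["a photo of my_concept"]))
# ['a photo of my_concept_0 my_concept_1']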
4 changes: 2 additions & 2 deletions library/strategy_flux.py
@@ -22,20 +22,20 @@

class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
super().__init__()
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text

text = self._process_replacements(text)
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]

return [l_tokens, t5_tokens, t5_attn_mask]


3 changes: 2 additions & 1 deletion library/strategy_sd3.py
@@ -25,6 +25,7 @@

class Sd3TokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
super().__init__()
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
@@ -33,7 +34,7 @@ def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[st

def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text

text = self._process_replacements(text)
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
20 changes: 4 additions & 16 deletions sd3_train.py
@@ -42,7 +42,7 @@
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments

from safetensors.torch import load_file
import embedding_utils

# from library.custom_train_functions import (
# apply_snr_weight,
@@ -215,18 +215,6 @@ def train(args):
sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)

def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
embed_state_dict = load_file(embed_file)
if not embed_key in embed_state_dict:
raise Exception(f"{embed_key} not found in {embed_file}")
tokenizer.add_tokens(placeholder)
index = tokenizer.convert_tokens_to_ids(placeholder)
if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")

# load clip_l, clip_g, t5xxl for caching text encoder outputs
# clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
# clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
@@ -236,9 +224,9 @@ def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):

if args.additional_embedding:
for placeholder, embed_file in args.additional_embedding:
inject_embedding(clip_l, sd3_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
inject_embedding(clip_g, sd3_tokenize_strategy.clip_g, placeholder, embed_file, "clip_g")
inject_embedding(t5xxl, sd3_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")
embedding_utils.inject_embedding(clip_l, sd3_tokenize_strategy, sd3_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
embedding_utils.inject_embedding(clip_g, sd3_tokenize_strategy, sd3_tokenize_strategy.clip_g, placeholder, embed_file, "clip_g")
embedding_utils.inject_embedding(t5xxl, sd3_tokenize_strategy, sd3_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")

assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified"

21 changes: 5 additions & 16 deletions sd3_train_network.py
@@ -17,7 +17,7 @@

setup_logging()
import logging
from safetensors.torch import load_file
import embedding_utils

logger = logging.getLogger(__name__)

@@ -115,30 +115,19 @@ def load_target_model(self, args, weight_dtype, accelerator):
loading_dtype = None # as is
else:
loading_dtype = weight_dtype

def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
embed_state_dict = load_file(embed_file)
if not embed_key in embed_state_dict:
raise Exception(f"{embed_key} not found in {embed_file}")
tokenizer.add_tokens(placeholder)
index = tokenizer.convert_tokens_to_ids(placeholder)
if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")

# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = sd3_utils.load_t5xxl(
args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
t5xxl.eval()

strategy = self.get_tokenize_strategy(args)
if args.additional_embedding:
for placeholder, embed_file in args.additional_embedding:
inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
inject_embedding(clip_g, self.get_tokenize_strategy(args).clip_g, placeholder, embed_file, "clip_g")
inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")
embedding_utils.inject_embedding(clip_l, strategy, strategy.clip_l, placeholder, embed_file, "clip_l")
embedding_utils.inject_embedding(clip_g, strategy, strategy.clip_g, placeholder, embed_file, "clip_g")
embedding_utils.inject_embedding(t5xxl, strategy, strategy.t5xxl, placeholder, embed_file, "t5xxl")

if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
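With the per-trainer copies of inject_embedding removed, both network trainers now fetch the tokenize strategy once and pass it to the shared helper, so the placeholder replacement is registered on the same strategy instance that later tokenizes the captions. A quick sanity check one might run after injection, reusing the clip_l and strategy objects from the trainer scope and the hypothetical embedding file from above (not code from this commit):

import torch
from safetensors.torch import load_file

embed = load_file("my_concept.safetensors")["clip_l"]
ids = strategy.clip_l.convert_tokens_to_ids([f"my_concept_{i}" for i in range(len(embed))])
rows = clip_l.get_input_embeddings().weight.data[ids]
assert torch.allclose(rows.to(embed.dtype).cpu(), embed)  # injected rows match the file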
