From 0866d2d4397278bf605eeee17a7df31576edb744 Mon Sep 17 00:00:00 2001
From: Delirious <36864043+deepdelirious@users.noreply.github.com>
Date: Mon, 28 Oct 2024 14:45:27 -0400
Subject: [PATCH] Support embeddings

---
 library/sd3_train_utils.py |  2 +-
 sd3_train.py               | 26 ++++++++++++++++++++++++++
 sd3_train_network.py       | 33 ++++++++++++++++++++++++++++++++-
 3 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py
index 054d1b4a1..4f18db87d 100644
--- a/library/sd3_train_utils.py
+++ b/library/sd3_train_utils.py
@@ -550,7 +550,7 @@ def encode_prompt(prpt):
         if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
             text_encoder_conds = sample_prompts_te_outputs[prpt]
             print(f"Using cached text encoder outputs for prompt: {prpt}")
-        if text_encoders is not None:
+        elif text_encoders is not None:
             print(f"Encoding prompt: {prpt}")
             tokens_and_masks = tokenize_strategy.tokenize(prpt)
             # strategy has apply_t5_attn_mask option
diff --git a/sd3_train.py b/sd3_train.py
index cdac945e6..d4604b040 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -42,6 +42,8 @@
 )
 from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
 
+from safetensors.torch import load_file
+
 # from library.custom_train_functions import (
 #     apply_snr_weight,
 #     prepare_scheduler_for_custom_training,
@@ -227,6 +229,18 @@ def train(args):
     # load tokenizer and prepare tokenize strategy
     sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
     strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)
+
+    def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
+        embed_state_dict = load_file(embed_file)
+        if not embed_key in embed_state_dict:
+            raise Exception(f"{embed_key} not found in {embed_file}")
+        tokenizer.add_tokens(placeholder)
+        index = tokenizer.convert_tokens_to_ids(placeholder)
+        if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
+            model.resize_token_embeddings(len(tokenizer))
+            logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
+        model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
+        logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")
 
     # load clip_l, clip_g, t5xxl for caching text encoder outputs
     # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
@@ -234,6 +248,13 @@ def train(args):
     clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
     clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
     t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
+
+    if args.additional_embedding:
+        for placeholder, embed_file in args.additional_embedding:
+            inject_embedding(clip_l, sd3_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
+            inject_embedding(clip_g, sd3_tokenize_strategy.clip_g, placeholder, embed_file, "clip_g")
+            inject_embedding(t5xxl, sd3_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")
+
     assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified"
 
     # prepare text encoding strategy
@@ -1178,6 +1199,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する",
     )
+    parser.add_argument(
+        "--additional_embedding",
+        action="append",
+        nargs=2
+    )
 
     return parser
 
diff --git a/sd3_train_network.py b/sd3_train_network.py
index 3506404ae..20aeea28a 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -17,6 +17,7 @@
 setup_logging()
 
 import logging
+from safetensors.torch import load_file
 
 logger = logging.getLogger(__name__)
 
@@ -25,6 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
+        self.tokenizer_cache = {}
 
     def assert_extra_args(self, args, train_dataset_group):
         super().assert_extra_args(args, train_dataset_group)
@@ -87,12 +89,31 @@ def load_target_model(self, args, weight_dtype, accelerator):
             loading_dtype = None  # as is
         else:
             loading_dtype = weight_dtype
+
+        def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
+            embed_state_dict = load_file(embed_file)
+            if not embed_key in embed_state_dict:
+                raise Exception(f"{embed_key} not found in {embed_file}")
+            tokenizer.add_tokens(placeholder)
+            index = tokenizer.convert_tokens_to_ids(placeholder)
+            if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
+                model.resize_token_embeddings(len(tokenizer))
+                logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
+            model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
+            logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")
 
         # loading t5xxl to cpu takes a long time, so we should load to gpu in future
         t5xxl = sd3_utils.load_t5xxl(
             args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
         )
         t5xxl.eval()
+
+        if args.additional_embedding:
+            for placeholder, embed_file in args.additional_embedding:
+                inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
+                inject_embedding(clip_g, self.get_tokenize_strategy(args).clip_g, placeholder, embed_file, "clip_g")
+                inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")
+
         if args.fp8_base and not args.fp8_base_unet:
             # check dtype of model
             if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
@@ -107,8 +128,13 @@ def load_target_model(self, args, weight_dtype, accelerator):
         return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit
 
     def get_tokenize_strategy(self, args):
+        cache_key = "sd3_" + str(args.t5xxl_max_token_length)
+        if cache_key in self.tokenizer_cache:
+            return self.tokenizer_cache[cache_key]
         logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
-        return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
+        tokenizer = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
+        self.tokenizer_cache[cache_key] = tokenizer
+        return tokenizer
 
     def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
         return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
@@ -426,6 +452,11 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch,
 def setup_parser() -> argparse.ArgumentParser:
     parser = train_network.setup_parser()
     sd3_train_utils.add_sd3_training_arguments(parser)
+    parser.add_argument(
+        "--additional_embedding",
+        action="append",
+        nargs=2
+    )
     return parser
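
Note (not part of the patch): a minimal sketch of an embedding file that the new --additional_embedding PLACEHOLDER FILE option could consume. inject_embedding loads the file with safetensors.torch.load_file and copies one tensor per key (clip_l, clip_g, t5xxl) into a single new row of each text encoder's input-embedding matrix, so each key must hold a 1-D vector of that encoder's hidden size; multi-vector embeddings are not handled by this helper. The file name, placeholder token, and hidden sizes below are illustrative assumptions, not values taken from the patch.

# Illustrative sketch, not part of the patch: build a single-token embedding file
# compatible with inject_embedding(). Hidden sizes are assumed to be
# CLIP-L 768, CLIP-G 1280, T5-XXL 4096.
import torch
from safetensors.torch import save_file

# One 1-D vector per text encoder; inject_embedding() writes each tensor into a
# single row of the corresponding input-embedding matrix.
embedding = {
    "clip_l": torch.randn(768),
    "clip_g": torch.randn(1280),
    "t5xxl": torch.randn(4096),
}
save_file(embedding, "my_concept.safetensors")

# At train time the placeholder and file are passed as a pair (hypothetical values):
#   --additional_embedding "<my-concept>" my_concept.safetensors
# Because the argument uses action="append" with nargs=2, the flag can be
# repeated once per additional embedding file.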