Support embeddings
deepdelirious committed Oct 28, 2024
1 parent c5f8bcc commit 0866d2d
Showing 3 changed files with 59 additions and 2 deletions.
2 changes: 1 addition & 1 deletion library/sd3_train_utils.py
@@ -550,7 +550,7 @@ def encode_prompt(prpt):
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
    text_encoder_conds = sample_prompts_te_outputs[prpt]
    print(f"Using cached text encoder outputs for prompt: {prpt}")
- if text_encoders is not None:
+ elif text_encoders is not None:
    print(f"Encoding prompt: {prpt}")
    tokens_and_masks = tokenize_strategy.tokenize(prpt)
    # strategy has apply_t5_attn_mask option
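
Note on the hunk above: the re-encoding branch becomes an elif, so a prompt whose text-encoder outputs are already cached is no longer encoded a second time (which would have overwritten the cached conditioning). A minimal sketch of the resulting cache-or-encode flow; the helper names are illustrative, not part of the repository:

# Hedged sketch of the cache-or-encode flow after the fix; names are hypothetical.
def get_prompt_conds(prompt, cache, text_encoders, encode_fn):
    if cache and prompt in cache:
        return cache[prompt]          # cache hit: reuse the stored conditioning
    elif text_encoders is not None:
        return encode_fn(prompt)      # cache miss: encode with the live text encoders
    raise ValueError(f"no cached outputs and no text encoders for prompt: {prompt}")
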
26 changes: 26 additions & 0 deletions sd3_train.py
@@ -42,6 +42,8 @@
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments

from safetensors.torch import load_file

# from library.custom_train_functions import (
# apply_snr_weight,
# prepare_scheduler_for_custom_training,
@@ -227,13 +229,32 @@ def train(args):
# load tokenizer and prepare tokenize strategy
sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)

def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
    # Register `placeholder` as a new token and copy the vector stored under
    # `embed_key` in `embed_file` into the text encoder's input embedding matrix.
    embed_state_dict = load_file(embed_file)
    if embed_key not in embed_state_dict:
        raise Exception(f"{embed_key} not found in {embed_file}")
    tokenizer.add_tokens(placeholder)
    index = tokenizer.convert_tokens_to_ids(placeholder)
    if model.get_input_embeddings().num_embeddings <= len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))
        logger.info(f"Expanded model embeddings to: {model.get_input_embeddings().num_embeddings}")
    model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
    logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")

# load clip_l, clip_g, t5xxl for caching text encoder outputs
# clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
# clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)

if args.additional_embedding:
    for placeholder, embed_file in args.additional_embedding:
        inject_embedding(clip_l, sd3_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
        inject_embedding(clip_g, sd3_tokenize_strategy.clip_g, placeholder, embed_file, "clip_g")
        inject_embedding(t5xxl, sd3_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")

assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified"

# prepare text encoding strategy
@@ -1178,6 +1199,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する",
)
# takes a placeholder token and a path to a safetensors file holding clip_l / clip_g / t5xxl embedding tensors; may be given multiple times
parser.add_argument(
    "--additional_embedding",
    action="append",
    nargs=2,
)
return parser


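
For context on the new flag: --additional_embedding takes two values per occurrence (a placeholder token and a safetensors file), and inject_embedding expects that file to hold one vector per text encoder under the keys clip_l, clip_g and t5xxl. Below is a minimal sketch of producing such a file; the tensor widths (768 / 1280 / 4096) are assumed hidden sizes for the three SD3 text encoders, and the file and token names are hypothetical:

# Hedged sketch: build an embedding file that inject_embedding() can consume.
import torch
from safetensors.torch import save_file

embedding = {
    "clip_l": torch.randn(768),    # assumed CLIP-L hidden size
    "clip_g": torch.randn(1280),   # assumed CLIP-G hidden size
    "t5xxl": torch.randn(4096),    # assumed T5-XXL hidden size
}
save_file(embedding, "my_concept.safetensors")

# Hypothetical invocation (the flag can be repeated for multiple embeddings):
#   python sd3_train.py ... --additional_embedding "<my-concept>" my_concept.safetensors
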
33 changes: 32 additions & 1 deletion sd3_train_network.py
@@ -17,6 +17,7 @@

setup_logging()
import logging
from safetensors.torch import load_file

logger = logging.getLogger(__name__)

@@ -25,6 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.tokenizer_cache = {}

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -87,12 +89,31 @@ def load_target_model(self, args, weight_dtype, accelerator):
loading_dtype = None # as is
else:
loading_dtype = weight_dtype

def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
    # Register `placeholder` as a new token and copy the vector stored under
    # `embed_key` in `embed_file` into the text encoder's input embedding matrix.
    embed_state_dict = load_file(embed_file)
    if embed_key not in embed_state_dict:
        raise Exception(f"{embed_key} not found in {embed_file}")
    tokenizer.add_tokens(placeholder)
    index = tokenizer.convert_tokens_to_ids(placeholder)
    if model.get_input_embeddings().num_embeddings <= len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))
        logger.info(f"Expanded model embeddings to: {model.get_input_embeddings().num_embeddings}")
    model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
    logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")

# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = sd3_utils.load_t5xxl(
args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
t5xxl.eval()

if args.additional_embedding:
    for placeholder, embed_file in args.additional_embedding:
        inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
        inject_embedding(clip_g, self.get_tokenize_strategy(args).clip_g, placeholder, embed_file, "clip_g")
        inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")

if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
@@ -107,8 +128,13 @@ def load_target_model(self, args, weight_dtype, accelerator):
return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit

def get_tokenize_strategy(self, args):
    cache_key = "sd3_" + str(args.t5xxl_max_token_length)
    if cache_key in self.tokenizer_cache:
        return self.tokenizer_cache[cache_key]
    logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
-   return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
+   tokenizer = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
    self.tokenizer_cache[cache_key] = tokenizer
    return tokenizer

def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
@@ -426,6 +452,11 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch,
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
sd3_train_utils.add_sd3_training_arguments(parser)
# takes a placeholder token and a path to a safetensors file holding clip_l / clip_g / t5xxl embedding tensors; may be given multiple times
parser.add_argument(
    "--additional_embedding",
    action="append",
    nargs=2,
)
return parser


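
A note on the tokenizer cache introduced above: inject_embedding mutates both the tokenizer (add_tokens) and the text encoder's embedding matrix, so get_tokenize_strategy must return the same Sd3TokenizeStrategy instance on every call, otherwise the placeholder token registered during load_target_model would be missing when captions are tokenized later. The same injection pattern, reduced to a self-contained sketch against a generic Hugging Face CLIP text encoder; the model ID, file name and placeholder below are illustrative only, not what this commit uses:

# Hedged, standalone sketch of the token-injection pattern used by inject_embedding().
import torch
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

placeholder = "<my-concept>"                            # hypothetical placeholder token
embed = load_file("my_concept.safetensors")["clip_l"]   # 1D tensor, hidden-size wide

tokenizer.add_tokens(placeholder)                       # register the new token
index = tokenizer.convert_tokens_to_ids(placeholder)
if text_model.get_input_embeddings().num_embeddings < len(tokenizer):
    text_model.resize_token_embeddings(len(tokenizer))  # grow the embedding matrix
with torch.no_grad():
    text_model.get_input_embeddings().weight[index] = embed  # write the vector into the new row
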
