Commit

Support embeddings in sd3
deepdelirious committed Nov 18, 2024
1 parent 1e8c39a commit fdadf4d
Showing 5 changed files with 63 additions and 4 deletions.

flux_train_network.py (2 changes: 1 addition & 1 deletion)
@@ -148,7 +148,7 @@ def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
 
     def get_tokenize_strategy(self, args):
-        name = self.get_flux_model_name(args)
+        name = "flux"
         cache_key = name + "_" + str(args.t5xxl_max_token_length)
         if cache_key in self.tokenizer_cache:
             return self.tokenizer_cache[cache_key]

library/sd3_train_utils.py (2 changes: 1 addition & 1 deletion)
@@ -531,7 +531,7 @@ def encode_prompt(prpt):
         if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
             text_encoder_conds = sample_prompts_te_outputs[prpt]
             print(f"Using cached text encoder outputs for prompt: {prpt}")
-        if text_encoders is not None:
+        elif text_encoders is not None:
             print(f"Encoding prompt: {prpt}")
             tokens_and_masks = tokenize_strategy.tokenize(prpt)
             # strategy has apply_t5_attn_mask option
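
The one-word change above ("if" becomes "elif") makes the sampling prompt cache authoritative: previously, a prompt whose text-encoder outputs were already cached was still pushed through the text encoders whenever they happened to be loaded. A runnable sketch of the resulting control flow, with dummy data and an encode() stand-in for the tokenize/encode calls shown in the hunk:

    # Cache-first flow established by the elif (dummy data; encode() is a stand-in).
    sample_prompts_te_outputs = {"a cat": ["cached conds"]}
    text_encoders = ["clip_l", "clip_g", "t5xxl"]

    def encode(prompt):
        return [f"freshly encoded conds for {prompt}"]

    for prpt in ["a cat", "a dog"]:
        text_encoder_conds = []
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            text_encoder_conds = sample_prompts_te_outputs[prpt]  # cache hit: reuse
        elif text_encoders is not None:
            text_encoder_conds = encode(prpt)  # cache miss: encode now
        print(prpt, "->", text_encoder_conds)
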
sd3_train.py (26 changes: 26 additions & 0 deletions)
@@ -42,6 +42,8 @@
 )
 from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
 
+from safetensors.torch import load_file
+
 # from library.custom_train_functions import (
 #     apply_snr_weight,
 #     prepare_scheduler_for_custom_training,
@@ -212,13 +214,32 @@ def train(args):
     # load tokenizer and prepare tokenize strategy
     sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
     strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)
 
+    def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
+        embed_state_dict = load_file(embed_file)
+        if not embed_key in embed_state_dict:
+            raise Exception(f"{embed_key} not found in {embed_file}")
+        tokenizer.add_tokens(placeholder)
+        index = tokenizer.convert_tokens_to_ids(placeholder)
+        if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
+            model.resize_token_embeddings(len(tokenizer))
+            logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
+        model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
+        logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")
+
     # load clip_l, clip_g, t5xxl for caching text encoder outputs
     # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
     # clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
     clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
     clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
     t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict)
+
+    if args.additional_embedding:
+        for placeholder, embed_file in args.additional_embedding:
+            inject_embedding(clip_l, sd3_tokenize_strategy.clip_l, placeholder, embed_file, "clip_l")
+            inject_embedding(clip_g, sd3_tokenize_strategy.clip_g, placeholder, embed_file, "clip_g")
+            inject_embedding(t5xxl, sd3_tokenize_strategy.t5xxl, placeholder, embed_file, "t5xxl")
+
     assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified"
 
     # prepare text encoding strategy
@@ -1061,6 +1082,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する",
     )
+    parser.add_argument(
+        "--additional_embedding",
+        action="append",
+        nargs=2
+    )
     return parser
 
 
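
For reference, inject_embedding above expects the safetensors file to hold one vector per text encoder under the keys "clip_l", "clip_g", and "t5xxl", each matching that encoder's hidden size (a single row is written into the embedding matrix at the new token's index). A minimal sketch of producing such a file; the 768/1280/4096 dimensions are assumed for SD3's CLIP-L, CLIP-G, and T5-XXL encoders, and the file name is only illustrative:

    import torch
    from safetensors.torch import save_file

    # Keys must match what inject_embedding looks up; the values are random here purely
    # to illustrate the expected shapes. A real embedding would come from textual
    # inversion training or a similar process.
    embedding = {
        "clip_l": torch.randn(768),   # assumed CLIP-L hidden size
        "clip_g": torch.randn(1280),  # assumed CLIP-G hidden size
        "t5xxl": torch.randn(4096),   # assumed T5-XXL hidden size
    }
    save_file(embedding, "my_token.safetensors")

Training would then be launched with --additional_embedding "<my-token>" my_token.safetensors; because the argument is declared with action="append" and nargs=2, the flag can be repeated to register several placeholder/file pairs.
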
sd3_train_network.py (35 changes: 34 additions & 1 deletion)
@@ -17,6 +17,7 @@
 
 setup_logging()
 import logging
+from safetensors.torch import load_file
 
 logger = logging.getLogger(__name__)
 
@@ -25,6 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
+        self.tokenizer_cache = {}
 
     def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
         # super().assert_extra_args(args, train_dataset_group)
@@ -113,12 +115,31 @@ def load_target_model(self, args, weight_dtype, accelerator):
             loading_dtype = None  # as is
         else:
             loading_dtype = weight_dtype
 
+        def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
+            embed_state_dict = load_file(embed_file)
+            if not embed_key in embed_state_dict:
+                raise Exception(f"{embed_key} not found in {embed_file}")
+            tokenizer.add_tokens(placeholder)
+            index = tokenizer.convert_tokens_to_ids(placeholder)
+            if (model.get_input_embeddings().num_embeddings <= len(tokenizer)):
+                model.resize_token_embeddings(len(tokenizer))
+                logger.info(f"Expanded model embeddings to : {model.get_input_embeddings().num_embeddings}")
+            model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
+            logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")
+
         # loading t5xxl to cpu takes a long time, so we should load to gpu in future
         t5xxl = sd3_utils.load_t5xxl(
             args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
         )
         t5xxl.eval()
+
+        if args.additional_embedding:
+            for placeholder, embed_file in args.additional_embedding:
+                inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
+                inject_embedding(clip_g, self.get_tokenize_strategy(args).clip_g, placeholder, embed_file, "clip_g")
+                inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")
+
         if args.fp8_base and not args.fp8_base_unet:
             # check dtype of model
             if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
@@ -133,8 +154,15 @@ def load_target_model(self, args, weight_dtype, accelerator):
         return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit
 
     def get_tokenize_strategy(self, args):
+        cache_key = "sd3_" + str(args.t5xxl_max_token_length)
+        if cache_key in self.tokenizer_cache:
+            logger.info("Reusing tokenize strategy from cache")
+            return self.tokenizer_cache[cache_key]
+        logger.warning("Generating new tokenize strategy")
         logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
-        return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
+        tokenizer = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
+        self.tokenizer_cache[cache_key] = tokenizer
+        return tokenizer
 
     def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
         return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
@@ -466,6 +494,11 @@ def setup_parser() -> argparse.ArgumentParser:
     parser = train_network.setup_parser()
     train_util.add_dit_training_arguments(parser)
     sd3_train_utils.add_sd3_training_arguments(parser)
+    parser.add_argument(
+        "--additional_embedding",
+        action="append",
+        nargs=2
+    )
     return parser
 
 
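
The tokenizer_cache added here is what keeps the injected tokens usable for the rest of the run: inject_embedding mutates the tokenizers held by the Sd3TokenizeStrategy returned from get_tokenize_strategy, so later callers must receive that same instance. If the strategy were rebuilt from scratch, the fresh tokenizer would not contain the placeholder and the row written into the embedding matrix would no longer correspond to it. A standalone illustration of that dependency using the transformers tokenizer API (the model id and placeholder string are only examples):

    from transformers import CLIPTokenizer

    tok_a = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tok_a.add_tokens("<my-token>")                    # same call inject_embedding makes
    print(tok_a.convert_tokens_to_ids("<my-token>"))  # id of the newly appended vocabulary entry

    tok_b = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    print(tok_b.convert_tokens_to_ids("<my-token>"))  # unknown-token id: a rebuilt tokenizer has no "<my-token>"
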
tools/convert_diffusers_to_flux.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def convert(args):
     save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
 
     # make reverse map from diffusers map
-    diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
+    diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(num_double_blocks=flux_utils.NUM_DOUBLE_BLOCKS, num_single_blocks=flux_utils.NUM_SINGLE_BLOCKS)
 
     # iterate over three safetensors files to reduce memory usage
     flux_sd = {}
