Support embeddings in flux_train_network
deepdelirious committed Oct 6, 2024
1 parent 7d88bd7 commit 0415eb8
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions flux_train_network.py
@@ -16,6 +16,7 @@

setup_logging()
import logging
from safetensors.torch import load_file

logger = logging.getLogger(__name__)

@@ -24,6 +25,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.sample_prompts_te_outputs = None
        self.tokenizer_cache = {}  # memoized tokenize strategies, keyed by model name and max token length

    def assert_extra_args(self, args, train_dataset_group):
        super().assert_extra_args(args, train_dataset_group)
@@ -82,17 +84,35 @@ def load_target_model(self, args, weight_dtype, accelerator):
            model = self.prepare_split_model(model, weight_dtype, accelerator)

        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        clip_l.eval()

        # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype

        def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
            # Load the trained embedding and pick out the tensor for this text encoder.
            embed_state_dict = load_file(embed_file)
            if embed_key not in embed_state_dict:
                raise KeyError(f"{embed_key} not found in {embed_file}")
            # Register the placeholder as a new token and grow the embedding matrix if needed.
            tokenizer.add_tokens(placeholder)
            index = tokenizer.convert_tokens_to_ids(placeholder)
            if model.get_input_embeddings().num_embeddings <= len(tokenizer):
                model.resize_token_embeddings(len(tokenizer))
                logger.info(f"Expanded model embeddings to: {model.get_input_embeddings().num_embeddings}")
            # Copy the trained vector into the new token's row of the embedding matrix.
            model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
            logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")

        # loading t5xxl to cpu takes a long time, so we should load it to gpu in the future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

        if args.additional_embedding:
            for placeholder, embed_file in args.additional_embedding:
                inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
                inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")

        clip_l.eval()
        t5xxl.eval()

        if args.fp8_base and not args.fp8_base_unet:
            # check dtype of model
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
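
For reference, the `inject_embedding` helper above expects a safetensors file holding one vector per text encoder under the keys `clip_l` and `t5xxl`. A minimal sketch of writing such a file, assuming hypothetical trained vectors and the usual hidden sizes of the two encoders (768 for CLIP-L, 4096 for T5-XXL):

import torch
from safetensors.torch import save_file

# Hypothetical trained embedding vectors; the shapes are assumed to match
# each encoder's input embedding width (768 for CLIP-L, 4096 for T5-XXL).
embeds = {
    "clip_l": torch.randn(768),
    "t5xxl": torch.randn(4096),
}
save_file(embeds, "my_token.safetensors")
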
@@ -142,7 +162,11 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
        return flux_lower

    def get_tokenize_strategy(self, args):
        name = self.get_flux_model_name(args)
        # Return the cached strategy when one exists; inject_embedding adds tokens to
        # these tokenizers, and a freshly built tokenizer would not have them.
        cache_key = name + "_" + str(args.t5xxl_max_token_length)
        if cache_key in self.tokenizer_cache:
            return self.tokenizer_cache[cache_key]

        if args.t5xxl_max_token_length is None:
            if name == "schnell":
@@ -153,7 +177,9 @@ def get_tokenize_strategy(self, args):
            t5xxl_max_token_length = args.t5xxl_max_token_length

        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
        tokenizer = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
        self.tokenizer_cache[cache_key] = tokenizer
        return tokenizer

    def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
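
The memoization above matters because `inject_embedding` mutates the tokenizers obtained from `get_tokenize_strategy`; without the cache, later calls would build fresh tokenizers that lack the placeholder token. A sketch of the failure mode the cache avoids, assuming the underlying tokenizer is the Hugging Face CLIPTokenizer:

from transformers import CLIPTokenizer

tok_a = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tok_a.add_tokens("<my-token>")  # what inject_embedding does to the cached instance
tok_b = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # a fresh, uncached tokenizer

# Only the mutated instance knows the new token; the fresh one maps it to the
# unk token, so get_tokenize_strategy must keep returning the cached strategy.
print(tok_a.convert_tokens_to_ids("<my-token>"))  # a new id past the original vocab
print(tok_b.convert_tokens_to_ids("<my-token>"))  # the unk token id
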
@@ -508,6 +534,11 @@ def setup_parser() -> argparse.ArgumentParser:
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
)
parser.add_argument(
"--additional_embedding",
action="append",
nargs=2
)
return parser
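
With `action="append"` and `nargs=2`, each occurrence of the flag contributes one `[placeholder, embed_file]` pair, which is what the unpacking loop in `load_target_model` iterates over. A minimal parsing sketch with hypothetical token names and file paths:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--additional_embedding", action="append", nargs=2)
args = parser.parse_args(
    ["--additional_embedding", "<my-token>", "my_token.safetensors",
     "--additional_embedding", "<style>", "style.safetensors"]
)
print(args.additional_embedding)
# [['<my-token>', 'my_token.safetensors'], ['<style>', 'style.safetensors']]
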

