diff --git a/flux_train_network.py b/flux_train_network.py
index 65b121e7c..fce62ddc3 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -16,6 +16,7 @@
 setup_logging()
 import logging
+from safetensors.torch import load_file
 
 logger = logging.getLogger(__name__)
 
@@ -24,6 +25,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
+        self.tokenizer_cache = {}
 
     def assert_extra_args(self, args, train_dataset_group):
         super().assert_extra_args(args, train_dataset_group)
@@ -82,17 +84,35 @@ def load_target_model(self, args, weight_dtype, accelerator):
             model = self.prepare_split_model(model, weight_dtype, accelerator)
 
         clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
-        clip_l.eval()
-
         # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
         if args.fp8_base and not args.fp8_base_unet:
             loading_dtype = None  # as is
         else:
             loading_dtype = weight_dtype
 
+        def inject_embedding(model, tokenizer, placeholder, embed_file, embed_key):
+            embed_state_dict = load_file(embed_file)
+            if embed_key not in embed_state_dict:
+                raise Exception(f"{embed_key} not found in {embed_file}")
+            tokenizer.add_tokens(placeholder)
+            index = tokenizer.convert_tokens_to_ids(placeholder)
+            if model.get_input_embeddings().num_embeddings <= len(tokenizer):
+                model.resize_token_embeddings(len(tokenizer))
+                logger.info(f"Expanded model embeddings to: {model.get_input_embeddings().num_embeddings}")
+            model.get_input_embeddings().weight.data[index] = embed_state_dict[embed_key]
+            logger.info(f"Added custom embedding for {placeholder} to {embed_key} as token {index}")
+
         # loading t5xxl to cpu takes a long time, so we should load to gpu in future
         t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+
+        if args.additional_embedding:
+            for placeholder, embed_file in args.additional_embedding:
+                inject_embedding(clip_l, self.get_tokenize_strategy(args).clip_l, placeholder, embed_file, "clip_l")
+                inject_embedding(t5xxl, self.get_tokenize_strategy(args).t5xxl, placeholder, embed_file, "t5xxl")
+
+        clip_l.eval()
         t5xxl.eval()
+
         if args.fp8_base and not args.fp8_base_unet:
             # check dtype of model
             if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
@@ -142,7 +162,11 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
         return flux_lower
 
     def get_tokenize_strategy(self, args):
+        name = self.get_flux_model_name(args)
+        cache_key = name + "_" + str(args.t5xxl_max_token_length)
+        if cache_key in self.tokenizer_cache:
+            return self.tokenizer_cache[cache_key]
 
         if args.t5xxl_max_token_length is None:
             if name == "schnell":
@@ -153,7 +177,9 @@ def get_tokenize_strategy(self, args):
             t5xxl_max_token_length = args.t5xxl_max_token_length
 
         logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
-        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
+        tokenizer = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
+        self.tokenizer_cache[cache_key] = tokenizer
+        return tokenizer
 
     def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
         return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
@@ -508,6 +534,11 @@ def setup_parser() -> argparse.ArgumentParser:
         help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
         + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
     )
+    parser.add_argument(
+        "--additional_embedding",
+        action="append",
+        nargs=2
+    )
     return parser
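
Usage notes (not part of the patch): inject_embedding copies the tensor stored under the key "clip_l" or "t5xxl" in the given safetensors file into one row of the corresponding text encoder's input-embedding matrix, so each key should hold a single vector whose length matches that encoder's hidden size. The tokenize strategy is cached in self.tokenizer_cache presumably so that the tokenizers mutated by tokenizer.add_tokens(placeholder) during model loading are the same objects handed to the rest of the training pipeline later. Below is a minimal sketch of how a compatible embedding file might be created; the hidden sizes (768 for CLIP-L, 4096 for T5-XXL), the file name, and the random vectors are illustrative assumptions, not values fixed by this change.

    # Sketch: write a safetensors file usable with --additional_embedding.
    # Hidden sizes are assumptions; match them to your actual clip_l / t5xxl checkpoints.
    import torch
    from safetensors.torch import save_file

    embedding = {
        "clip_l": torch.randn(768),   # copied into clip_l.get_input_embeddings().weight[index]
        "t5xxl": torch.randn(4096),   # copied into t5xxl.get_input_embeddings().weight[index]
    }
    save_file(embedding, "my_concept.safetensors")  # hypothetical file name

The new flag takes the placeholder token and the file path as a pair and may be repeated because of action="append", e.g. --additional_embedding "<my-concept>" my_concept.safetensors.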