From 75833e84a1c7e3c2fb0a9e3ce0fe3d8c1758a012 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Mon, 8 Apr 2024 19:23:02 -0400
Subject: [PATCH] Fix default LR, Add overall LoRA+ ratio, Add log

`--loraplus_lr_ratio` added for both TE and UNet
Add LR logging for LoRA+

---
 library/train_util.py |  1 +
 networks/dylora.py    | 24 ++++++-------
 networks/lora.py      | 28 ++++++++--------
 networks/lora_fa.py   | 30 ++++++++---------
 train_network.py      | 78 ++++++++++++++++++++++++++++++++-----------
 5 files changed, 101 insertions(+), 60 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 4e5ab7370..7c2bf6935 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
diff --git a/networks/dylora.py b/networks/dylora.py
index edc3e2229..dc5c7cb35 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -412,32 +412,32 @@ def prepare_optimizer_params(
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras, lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
                 for name, param in lora.named_parameters():
-                    if lora_plus_ratio is not None and "lora_B" in name:
+                    if ratio is not None and "lora_B" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

-            # assigned_param_groups = ""
-            # for group in param_groups:
-            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
-            # logger.info(assigned_param_groups)
-
             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr
@@ -452,7 +452,7 @@ def assemble_params(loras, lr, lora_plus_ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
@@ -460,7 +460,7 @@ def assemble_params(loras, lr, lora_plus_ratio):
             params = assemble_params(
                 self.unet_loras,
                 default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
diff --git a/networks/lora.py b/networks/lora.py
index e082941e5..6cb05bcb0 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -1040,32 +1040,32 @@ def prepare_optimizer_params(
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras, lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
                 for name, param in lora.named_parameters():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                    if ratio is not None and "lora_up" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

-            # assigned_param_groups = ""
-            # for group in param_groups:
-            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
-            # logger.info(assigned_param_groups)
-
             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr
@@ -1080,7 +1080,7 @@ def assemble_params(loras, lr, lora_plus_ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
@@ -1099,15 +1099,15 @@ def assemble_params(loras, lr, lora_plus_ratio):
                 params = assemble_params(
                     block_loras,
                     (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    unet_lora_plus_ratio
+                    unet_loraplus_ratio or loraplus_ratio
                 )
                 all_params.extend(params)

         else:
             params = assemble_params(
                 self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_lr if unet_lr is not None else default_lr,
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index 3f6774dd8..2eff86d6c 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -1038,32 +1038,32 @@ def prepare_optimizer_params(
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                for name, param in lora.get_trainable_named_params():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

-            # assigned_param_groups = ""
-            # for group in param_groups:
-            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
-            # logger.info(assigned_param_groups)
-
             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr
@@ -1078,7 +1078,7 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
@@ -1097,15 +1097,15 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
                 params = assemble_params(
                     block_loras,
                     (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    unet_lora_plus_ratio
+                    unet_loraplus_ratio or loraplus_ratio
                 )
                 all_params.extend(params)

         else:
             params = assemble_params(
                 self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_lr if unet_lr is not None else default_lr,
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
diff --git a/train_network.py b/train_network.py
index ba0c124d1..43226fc47 100644
--- a/train_network.py
+++ b/train_network.py
@@ -66,34 +66,69 @@ def generate_step_logs(
         lrs = lr_scheduler.get_last_lr()

-        if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
-            if args.network_train_unet_only:
-                logs["lr/unet"] = float(lrs[0])
-            elif args.network_train_text_encoder_only:
-                logs["lr/textencoder"] = float(lrs[0])
-            else:
-                logs["lr/textencoder"] = float(lrs[0])
-                logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
-
-            if (
-                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
-                )
-        else:
+        if len(lrs) > 4:
             idx = 0
             if not args.network_train_unet_only:
                 logs["lr/textencoder"] = float(lrs[0])
                 idx = 1

             for i in range(idx, len(lrs)):
-                logs[f"lr/group{i}"] = float(lrs[i])
+                lora_plus = ""
+                group_id = i
+
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    lora_plus = '_lora+' if i % 2 == 1 else ''
+                    group_id = int((i / 2) + (i % 2 + 0.5))
+
+                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
                 if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{i}"] = (
+                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )
+        else:
+            if args.network_train_text_encoder_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/textencoder"] = float(lrs[0])
+
+            elif args.network_train_unet_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    logs["lr/unet"] = float(lrs[0])
+                    logs["lr/unet_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/unet"] = float(lrs[0])
+            else:
+                if len(lrs) == 2:
+                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/textencoder_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
+                        logs["lr/unet"] = float(lrs[0])
+                        logs["lr/unet_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
+                        logs["lr/all"] = float(lrs[0])
+                        logs["lr/all_lora+"] = float(lrs[1])
+                    else:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/unet"] = float(lrs[-1])
+                elif len(lrs) == 4:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                    logs["lr/unet"] = float(lrs[2])
+                    logs["lr/unet_lora+"] = float(lrs[3])
+                else:
+                    logs["lr/all"] = float(lrs[0])
+
+            if (
+                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+                )
+
         return logs

     def assert_extra_args(self, args, train_dataset_group):
@@ -339,7 +374,7 @@ def train(self, args):
         # 後方互換性を確保するよ
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
             )
@@ -348,6 +383,11 @@ def train(self, args):
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

+        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+            assert (
+                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
+            ), "LoRA+ and Prodigy/DAdaptation is not supported"
+
         # dataloaderを準備する
         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
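For reference, a minimal standalone sketch (not part of the patch) of the parameter-group split that the updated assemble_params() performs: "lora_up"/"lora_B" weights go into a "plus" group whose learning rate is scaled by the LoRA+ ratio, and empty groups are skipped before reaching the optimizer. ToyLoRAModule, the dimensions, and the example values below are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn


class ToyLoRAModule(nn.Module):
    # Illustrative stand-in for a LoRA module whose up/down weights are named
    # "lora_up" / "lora_down" (networks/lora.py naming; networks/dylora.py keys on "lora_B").
    def __init__(self, lora_name, in_dim=8, rank=4, out_dim=8):
        super().__init__()
        self.lora_name = lora_name
        self.lora_down = nn.Linear(in_dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_dim, bias=False)


def assemble_params(loras, lr, ratio):
    # Split parameters into the regular "lora" group and the LoRA+ "plus" group
    # (the up/B matrices), then scale the plus group's learning rate by `ratio`.
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            key = "plus" if ratio is not None and "lora_up" in name else "lora"
            param_groups[key][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        if len(group) == 0:
            continue  # skip empty groups so the optimizer never sees them
        param_data = {"params": list(group.values())}
        if lr is not None:
            param_data["lr"] = lr * ratio if key == "plus" else lr
        params.append(param_data)
    return params


if __name__ == "__main__":
    loras = [ToyLoRAModule("lora_unet_block0"), ToyLoRAModule("lora_unet_block1")]
    groups = assemble_params(loras, lr=1e-4, ratio=16)  # e.g. --unet_lr=1e-4 --loraplus_lr_ratio=16
    optimizer = torch.optim.AdamW(groups)
    for g in optimizer.param_groups:
        print(len(g["params"]), g["lr"])  # 2 params at 1e-4, then 2 "plus" params at 1.6e-3

With a LoRA+ ratio set, the text encoder and U-Net each contribute up to two param groups (base and "plus"), which is what the new branches in generate_step_logs() account for when emitting the lr/..._lora+ entries.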