Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_lr_ratio` added as an overall LoRA+ ratio applied to both the text encoder and the UNet (the per-module `--loraplus_text_encoder_lr_ratio` and `--loraplus_unet_lr_ratio` override it when set)
Add logging of the LoRA+ learning rates
rockerBOO committed Apr 8, 2024
1 parent 1933ab4 commit 75833e8
Showing 5 changed files with 101 additions and 60 deletions.
1 change: 1 addition & 0 deletions library/train_util.py
@@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
)
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")

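For context, a minimal sketch of how the three ratio flags relate (the flag names and defaults are taken from the hunk above; the parsing and resolution lines are an illustration, not code from this commit): the per-module ratios take precedence, and the new overall `--loraplus_lr_ratio` acts as the fallback for both the text encoder and the UNet.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")

# Only the overall ratio is given here, so both modules fall back to it,
# mirroring the `unet_loraplus_ratio or loraplus_ratio` expressions further down.
args = parser.parse_args(["--loraplus_lr_ratio", "16"])
unet_ratio = args.loraplus_unet_lr_ratio or args.loraplus_lr_ratio        # -> 16.0
te_ratio = args.loraplus_text_encoder_lr_ratio or args.loraplus_lr_ratio  # -> 16.0
print(unet_ratio, te_ratio)
```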
24 changes: 12 additions & 12 deletions networks/dylora.py
@@ -412,32 +412,32 @@ def prepare_optimizer_params(
text_encoder_lr,
unet_lr,
default_lr,
unet_lora_plus_ratio=None,
text_encoder_lora_plus_ratio=None
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def assemble_params(loras, lr, lora_plus_ratio):
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if lora_plus_ratio is not None and "lora_up" in name:
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

# assigned_param_groups = ""
# for group in param_groups:
# assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
# logger.info(assigned_param_groups)

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * lora_plus_ratio
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

@@ -452,15 +452,15 @@ def assemble_params(loras, lr, lora_plus_ratio):
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_lora_plus_ratio
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_lora_plus_ratio
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

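As a reading aid, here is a standalone sketch of the grouping that `assemble_params` performs (the helper name `build_param_groups` and the dummy parameters are illustrative assumptions, not part of the commit): LoRA+ trains the "up"/B matrices at `lr * ratio` while every other LoRA parameter keeps the base learning rate.

```python
import torch

def build_param_groups(named_params, base_lr, loraplus_ratio, plus_key="lora_up"):
    # Split parameters into the base "lora" group and the LoRA+ "plus" group.
    groups = {"lora": [], "plus": []}
    for name, param in named_params:
        if loraplus_ratio is not None and plus_key in name:
            groups["plus"].append(param)
        else:
            groups["lora"].append(param)

    # Build optimizer param groups; only the "plus" group gets the scaled learning rate.
    param_groups = []
    for key, params in groups.items():
        if not params:
            continue
        lr = base_lr * loraplus_ratio if key == "plus" else base_lr
        param_groups.append({"params": params, "lr": lr})
    return param_groups

# Dummy parameters named like LoRA weights; dylora.py matches "lora_B" instead of "lora_up".
named = [
    ("lora_unet_block.lora_down.weight", torch.nn.Parameter(torch.zeros(4, 320))),
    ("lora_unet_block.lora_up.weight", torch.nn.Parameter(torch.zeros(320, 4))),
]
optimizer = torch.optim.AdamW(build_param_groups(named, base_lr=1e-4, loraplus_ratio=16))
```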
28 changes: 14 additions & 14 deletions networks/lora.py
@@ -1040,32 +1040,32 @@ def prepare_optimizer_params(
text_encoder_lr,
unet_lr,
default_lr,
unet_lora_plus_ratio=None,
text_encoder_lora_plus_ratio=None
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def assemble_params(loras, lr, lora_plus_ratio):
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if lora_plus_ratio is not None and "lora_up" in name:
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

# assigned_param_groups = ""
# for group in param_groups:
# assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
# logger.info(assigned_param_groups)

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * lora_plus_ratio
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

@@ -1080,7 +1080,7 @@ def assemble_params(loras, lr, lora_plus_ratio):
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_lora_plus_ratio
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

@@ -1099,15 +1099,15 @@ def assemble_params(loras, lr, lora_plus_ratio):
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_lora_plus_ratio
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

else:
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_lora_plus_ratio
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

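Two small semantic points in the hunks above, shown with illustrative values (the concrete numbers are assumptions, not project defaults): the UNet learning rate now falls back to `default_lr` only when `unet_lr` is None (the "fix default LR" part of the commit title), and the per-module LoRA+ ratio is preferred over the overall one via `or`.

```python
unet_lr, default_lr = None, 1e-4
lr = unet_lr if unet_lr is not None else default_lr   # -> 1e-4

loraplus_unet_lr_ratio, loraplus_lr_ratio = None, 16.0
ratio = loraplus_unet_lr_ratio or loraplus_lr_ratio   # -> 16.0 (per-module value wins when set)
# Note: `or` also falls back to the overall ratio when the per-module ratio is 0.0.
```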
30 changes: 15 additions & 15 deletions networks/lora_fa.py
@@ -1038,32 +1038,32 @@ def prepare_optimizer_params(
text_encoder_lr,
unet_lr,
default_lr,
unet_lora_plus_ratio=None,
text_encoder_lora_plus_ratio=None
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.get_trainable_named_params():
if lora_plus_ratio is not None and "lora_up" in name:
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

# assigned_param_groups = ""
# for group in param_groups:
# assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
# logger.info(assigned_param_groups)

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * lora_plus_ratio
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

@@ -1078,7 +1078,7 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_lora_plus_ratio
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

@@ -1097,15 +1097,15 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_lora_plus_ratio
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

else:
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_lora_plus_ratio
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

78 changes: 59 additions & 19 deletions train_network.py
@@ -66,34 +66,69 @@ def generate_step_logs(

lrs = lr_scheduler.get_last_lr()

if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
if args.network_train_unet_only:
logs["lr/unet"] = float(lrs[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = float(lrs[0])
else:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder

if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
)
else:
if len(lrs) > 4:
idx = 0
if not args.network_train_unet_only:
logs["lr/textencoder"] = float(lrs[0])
idx = 1

for i in range(idx, len(lrs)):
logs[f"lr/group{i}"] = float(lrs[i])
lora_plus = ""
group_id = i

if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
lora_plus = '_lora+' if i % 2 == 1 else ''
group_id = int((i / 2) + (i % 2 + 0.5))

logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
logs[f"lr/d*lr/group{i}"] = (
logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)

else:
if args.network_train_text_encoder_only:
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
else:
logs["lr/textencoder"] = float(lrs[0])

elif args.network_train_unet_only:
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
logs["lr/unet"] = float(lrs[0])
logs["lr/unet_lora+"] = float(lrs[1])
else:
logs["lr/unet"] = float(lrs[0])
else:
if len(lrs) == 2:
if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
logs["lr/unet"] = float(lrs[0])
logs["lr/unet_lora+"] = float(lrs[1])
elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
logs["lr/all"] = float(lrs[0])
logs["lr/all_lora+"] = float(lrs[1])
else:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1])
elif len(lrs) == 4:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
logs["lr/unet"] = float(lrs[2])
logs["lr/unet_lora+"] = float(lrs[3])
else:
logs["lr/all"] = float(lrs[0])

if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
)

return logs

def assert_extra_args(self, args, train_dataset_group):
@@ -339,7 +374,7 @@ def train(self, args):

# 後方互換性を確保するよ (keep backward compatibility)
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
except TypeError:
accelerator.print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
@@ -348,6 +383,11 @@ def train(self, args):

optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
assert (
(optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
), "LoRA+ and Prodigy/DAdaptation is not supported"

# dataloaderを準備する (prepare the dataloader)
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 (note: persistent_workers cannot be used when the DataLoader worker count is 0)
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
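The try/except around `prepare_optimizer_params` keeps older network modules working after the new arguments are added. A minimal sketch of that backward-compatibility pattern (the helper name and the exact fallback call are assumptions for illustration; this is not the file's code):

```python
def get_trainable_params(network, args):
    try:
        # Newer network modules accept the three LoRA+ ratio arguments.
        return network.prepare_optimizer_params(
            args.text_encoder_lr,
            args.unet_lr,
            args.learning_rate,
            args.loraplus_text_encoder_lr_ratio,
            args.loraplus_unet_lr_ratio,
            args.loraplus_lr_ratio,
        )
    except TypeError:
        # Older modules only know the original three-argument signature.
        return network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
```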
