Skip to content

Commit

Permalink
Allow blockwise LR
Browse files · Browse the repository at this point in the history
  • Loading branch information
deepdelirious committed Dec 2, 2024
1 parent dcb8a85 commit 5114bcb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
28 changes: 26 additions & 2 deletions flux_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,31 @@ def train(args):
ae.to(accelerator.device, dtype=weight_dtype)

training_models = []
params_to_optimize = []
params_to_optimize = {}
training_models.append(flux)
name_and_params = list(flux.named_parameters())
# single param group for now
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
param_names = [[n for n, _ in name_and_params]]

def add_param_with_lr(group_name, params, lr):
    # Collect a parameter into the optimizer param group keyed by
    # `group_name`, lazily creating the group with the given learning
    # rate on first use.
    # NOTE(review): mutates `params_to_optimize` (a dict) from the
    # enclosing train() scope; `params` here is a single parameter
    # tensor being appended, not a list — confirm against callers.
    if group_name not in params_to_optimize:
        params_to_optimize[group_name] = {"params": [], "lr": lr}
    params_to_optimize[group_name]['params'].append(params)

def process_param(name, params):
    # Route one named model parameter into an optimizer group.  With
    # --blockwise_lr given as (prefix, lr) string pairs, the FIRST
    # prefix that matches the parameter's name wins and overrides the
    # base learning rate; unmatched parameters fall into the "default"
    # group at args.learning_rate.
    if args.blockwise_lr:
        for prefix, lr in args.blockwise_lr:
            # argparse (nargs=2, action="append") delivers lr as a string
            lr = float(lr)
            if name.startswith(prefix):
                logger.info(f"LR override ({lr}) for block {name}")
                add_param_with_lr(prefix, params, lr)
                return
    add_param_with_lr("default", params, args.learning_rate)

for name, params in name_and_params:
process_param(name, params)

params_to_optimize = params_to_optimize.values()

# calculate number of trainable parameters
n_params = 0
Expand Down Expand Up @@ -857,6 +876,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="append",
nargs=2
)
parser.add_argument(
"--blockwise_lr",
action="append",
nargs=2
)
parser.add_argument(
"--ema",
action="store_true"
Expand Down
2 changes: 1 addition & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5939,7 +5939,7 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
lrs = lr_scheduler.get_last_lr()

for lr_index in range(len(lrs)):
name = names[lr_index]
name = names[lr_index] if lr_index < len(names) else str(lr_index)
logs["lr/" + name] = float(lrs[lr_index])

if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
Expand Down

0 comments on commit 5114bcb

Please sign in to comment.