
Commit

Allow blockwise LR
deepdelirious committed Dec 2, 2024
1 parent dcb8a85 commit 0e93f59
Showing 1 changed file with 26 additions and 2 deletions.
flux_train.py (28 changes: 26 additions & 2 deletions)

@@ -321,12 +321,31 @@ def train(args):
     ae.to(accelerator.device, dtype=weight_dtype)
 
     training_models = []
-    params_to_optimize = []
+    params_to_optimize = {}
     training_models.append(flux)
     name_and_params = list(flux.named_parameters())
     # single param group for now
-    params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
     param_names = [[n for n, _ in name_and_params]]
+
+    def add_param_with_lr(group_name, params, lr):
+        if group_name not in params_to_optimize:
+            params_to_optimize[group_name] = {"params": [], "lr": lr}
+        params_to_optimize[group_name]['params'].append(params)
+
+    def process_param(name, params):
+        if args.blockwise_lr:
+            for prefix, lr in args.blockwise_lr:
+                lr = float(lr)
+                if name.startswith(prefix):
+                    logger.info(f"LR override ({lr}) for block {name}")
+                    add_param_with_lr(prefix, params, lr)
+                    return
+        add_param_with_lr("default", params, args.learning_rate)
+
+    for name, params in name_and_params:
+        process_param(name, params)
+
+    params_to_optimize = params_to_optimize.values()
 
     # calculate number of trainable parameters
     n_params = 0
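
Note: the dict built above is keyed by block prefix, and params_to_optimize.values() yields dicts of the form {"params": [...], "lr": ...}, which is exactly PyTorch's per-parameter-group format, so any torch.optim optimizer will honor the per-group learning rates. Below is a minimal standalone sketch of the same grouping scheme; the two-block model and the prefixes "double_blocks"/"single_blocks" are hypothetical placeholders, not necessarily what flux.named_parameters() actually yields.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the FLUX model: two named blocks whose
    # parameter names start with "double_blocks." and "single_blocks.".
    model = nn.ModuleDict({
        "double_blocks": nn.Linear(4, 4),
        "single_blocks": nn.Linear(4, 4),
    })

    blockwise_lr = [("double_blocks", "1e-5")]  # (prefix, lr) pairs; LRs arrive as strings
    default_lr = 1e-4

    # Same scheme as the patch: the first matching prefix wins; everything
    # else falls into the "default" group at the base learning rate.
    groups = {}
    for name, param in model.named_parameters():
        for prefix, lr in blockwise_lr:
            if name.startswith(prefix):
                groups.setdefault(prefix, {"params": [], "lr": float(lr)})["params"].append(param)
                break
        else:
            groups.setdefault("default", {"params": [], "lr": default_lr})["params"].append(param)

    # An iterable of {"params", "lr"} dicts is the standard PyTorch
    # per-group format, so each group keeps its own learning rate.
    optimizer = torch.optim.AdamW(groups.values())
    for group in optimizer.param_groups:
        print(len(group["params"]), group["lr"])  # 2 params at 1e-05, then 2 at 0.0001
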
@@ -857,6 +876,11 @@ def setup_parser() -> argparse.ArgumentParser:
         action="append",
         nargs=2
     )
+    parser.add_argument(
+        "--blockwise_lr",
+        action="append",
+        nargs=2
+    )
     parser.add_argument(
         "--ema",
         action="store_true"
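
Note: because the new flag is declared with action="append" and nargs=2, each occurrence contributes one [prefix, lr] pair and the option can be repeated; the values stay strings, which is why process_param casts lr with float(). A quick illustration of the parsed shape, with placeholder prefixes:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--blockwise_lr", action="append", nargs=2)

    # e.g. flux_train.py ... --blockwise_lr double_blocks 1e-5 --blockwise_lr single_blocks 5e-6
    args = parser.parse_args(
        ["--blockwise_lr", "double_blocks", "1e-5",
         "--blockwise_lr", "single_blocks", "5e-6"]
    )
    print(args.blockwise_lr)  # [['double_blocks', '1e-5'], ['single_blocks', '5e-6']]
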
