From 0e93f5975105ee60225014bee87f02a8525a5d08 Mon Sep 17 00:00:00 2001
From: Delirious <36864043+deepdelirious@users.noreply.github.com>
Date: Mon, 2 Dec 2024 13:38:30 -0500
Subject: [PATCH] Allow blockwise LR

---
 flux_train.py | 38 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/flux_train.py b/flux_train.py
index 71750b725..772581904 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -321,12 +321,39 @@ def train(args):
         ae.to(accelerator.device, dtype=weight_dtype)
 
     training_models = []
-    params_to_optimize = []
+    params_to_optimize = {}
     training_models.append(flux)
     name_and_params = list(flux.named_parameters())
     # single param group for now
-    params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
     param_names = [[n for n, _ in name_and_params]]
+    # NOTE(review): param_names is still one flat group; confirm nothing pairs it with params_to_optimize
+
+    # Optional per-block learning rates: --blockwise_lr PREFIX LR (repeatable).
+    # Parse each LR once, up front, so a malformed value fails before grouping.
+    blockwise_lrs = [(prefix, float(lr)) for prefix, lr in (args.blockwise_lr or [])]
+    override_counts = {prefix: 0 for prefix, _ in blockwise_lrs}
+
+    for name, param in name_and_params:
+        # None keys the default group, so it cannot collide with a user prefix
+        group_key, group_lr = None, args.learning_rate
+        for prefix, lr in blockwise_lrs:
+            # first matching prefix wins, in command-line order
+            if name.startswith(prefix):
+                group_key, group_lr = prefix, lr
+                override_counts[prefix] += 1
+                break
+        group = params_to_optimize.setdefault(group_key, {"params": [], "lr": group_lr})
+        group["params"].append(param)
+
+    # one summary line per prefix instead of one log line per parameter
+    for prefix, count in override_counts.items():
+        if count:
+            logger.info(f"LR override ({params_to_optimize[prefix]['lr']}) for {count} parameters with prefix {prefix}")
+        else:
+            logger.warning(f"--blockwise_lr prefix {prefix} matched no parameters")
+
+    # the optimizer setup gets a real list of groups, as it did before this change
+    params_to_optimize = list(params_to_optimize.values())
 
     # calculate number of trainable parameters
     n_params = 0
@@ -857,6 +884,13 @@ def setup_parser() -> argparse.ArgumentParser:
         action="append",
         nargs=2
     )
+    parser.add_argument(
+        "--blockwise_lr",
+        action="append",
+        nargs=2,
+        metavar=("PREFIX", "LR"),
+        help="learning rate LR for parameters whose name starts with PREFIX; may be given multiple times, first match wins",
+    )
     parser.add_argument(
         "--ema",
         action="store_true"