From b56d5f7801dea45cdbbba8498544e8d2853ad6d6 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 6 May 2024 21:35:39 +0900
Subject: [PATCH] add experimental option to fuse params to optimizer groups

---
 sdxl_train.py | 114 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 104 insertions(+), 10 deletions(-)

diff --git a/sdxl_train.py b/sdxl_train.py
index 3b28575ed..c7eea2224 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # calculate number of trainable parameters
     n_params = 0
-    for params in params_to_optimize:
-        for p in params["params"]:
+    for group in params_to_optimize:
+        for p in group["params"]:
             n_params += p.numel()

     accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    if args.fused_optimizer_groups:
+        # calculate total number of parameters
+        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
+        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
+
+        # split params into groups
+        grouped_params = []
+        param_group = []
+        param_group_lr = -1
+        for group in params_to_optimize:
+            lr = group["lr"]
+            for p in group["params"]:
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+                param_group.append(p)
+                if len(param_group) == params_per_group:
+                    grouped_params.append({"params": param_group, "lr": param_group_lr})
+                    param_group = []
+                    param_group_lr = -1
+        if param_group:
+            grouped_params.append({"params": param_group, "lr": param_group_lr})
+
+        # prepare optimizers for each group
+        optimizers = []
+        for group in grouped_params:
+            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+            optimizers.append(optimizer)
+        optimizer = optimizers[0]  # avoid error in the following code
+
+        print(len(grouped_params))
+        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+
+    else:
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

     # dataloaderを準備する
     # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
@@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         train_dataset_group.set_max_train_steps(args.max_train_steps)

     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    if args.fused_optimizer_groups:
+        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
+        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
+    else:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
     if args.full_fp16:
@@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     if args.fused_backward_pass:
         import library.adafactor_fused
+
         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
+
                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

                     parameter.register_post_accumulate_grad_hook(__grad_hook)

+    elif args.fused_optimizer_groups:
+        for i in range(1, len(optimizers)):
+            optimizers[i] = accelerator.prepare(optimizers[i])
+            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+        global optimizer_hooked_count
+        global num_parameters_per_group
+        global parameter_optimizer_map
+        optimizer_hooked_count = {}
+        num_parameters_per_group = [0] * len(optimizers)
+        parameter_optimizer_map = {}
+        for opt_idx, optimizer in enumerate(optimizers):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
+                    if parameter.requires_grad:
+
+                        def optimizer_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad()
+
+                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        parameter_optimizer_map[parameter] = opt_idx
+                        num_parameters_per_group[opt_idx] += 1
+
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
+
+            if args.fused_optimizer_groups:
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -630,11 +711,13 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     loss = loss.mean()  # mean over batch dimension

                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )

                 accelerator.backward(loss)

-                if not args.fused_backward_pass:
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = []
                         for m in training_models:
@@ -642,9 +725,14 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                     optimizer.step()
+                elif args.fused_optimizer_groups:
+                    for i in range(1, len(optimizers)):
+                        lr_schedulers[i].step()

                 lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    optimizer.zero_grad(set_to_none=True)

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

     accelerator.end_training()

-    if args.save_state or args.save_state_on_train_end:
+    if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)

     del accelerator  # この後メモリを使うのでこれは消す
@@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser:
         help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
         + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
     )
+    parser.add_argument(
+        "--fused_optimizer_groups",
+        type=int,
+        default=None,
+        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
+    )

     return parser
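
Note (not part of the patch): the mechanism above is easiest to see in isolation. The patch splits the trainable parameters into N groups, builds one optimizer per group, and uses torch.Tensor.register_post_accumulate_grad_hook so that each group's optimizer steps as soon as every parameter in that group has received its gradient during backward, instead of waiting for a single global optimizer.step(). The following is a minimal standalone sketch of that idea, assuming PyTorch >= 2.1; the toy model, AdamW choice, and group count are illustrative assumptions and are not taken from sdxl_train.py, and gradient clipping, LR schedulers, and Accelerate are omitted for brevity.

# Standalone sketch: one optimizer per parameter group, stepped from a
# post-accumulate-grad hook once every parameter in the group has its gradient.
import math

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
params = [p for p in model.parameters() if p.requires_grad]

num_groups = 2  # plays the role of --fused_optimizer_groups
params_per_group = math.ceil(len(params) / num_groups)
grouped = [params[i : i + params_per_group] for i in range(0, len(params), params_per_group)]

optimizers = [torch.optim.AdamW(group, lr=1e-3) for group in grouped]
num_parameters_per_group = [len(group) for group in grouped]
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
parameter_optimizer_map = {}  # parameter -> group index, keyed by tensor identity as in the patch


def optimizer_hook(parameter: torch.Tensor):
    # Called by autograd right after parameter.grad has been accumulated.
    i = parameter_optimizer_map[parameter]
    optimizer_hooked_count[i] += 1
    if optimizer_hooked_count[i] == num_parameters_per_group[i]:
        # All gradients of this group exist: step now and free them immediately.
        optimizers[i].step()
        optimizers[i].zero_grad(set_to_none=True)


for opt_idx, group in enumerate(grouped):
    for parameter in group:
        parameter_optimizer_map[parameter] = opt_idx
        parameter.register_post_accumulate_grad_hook(optimizer_hook)  # requires PyTorch >= 2.1

x, y = torch.randn(16, 8), torch.randn(16, 4)
for step in range(3):
    optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # per-step reset, as at the top of the dataloader loop
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()  # optimizers step inside the hooks; no explicit optimizer.step() afterwards
    print(f"step {step}: loss={loss.item():.4f}")

In the patch itself the same bookkeeping (parameter_optimizer_map, optimizer_hooked_count, num_parameters_per_group) lives as globals in the training script, the counter is reset at the top of each dataloader step, and each additional group's optimizer and LR scheduler are also passed through accelerator.prepare and stepped alongside the first one.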