add grad_hook after restore state closes kohya-ss#1344
kohya-ss committed Jun 12, 2024
1 parent 22413a5 commit 56bb81c
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions sdxl_train.py
@@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# when caching the Text Encoder outputs, move the Text Encoders to the CPU
if args.cache_text_encoder_outputs:
# move Text Encoders to CPU so they stay available for sampling images; they don't run on CPU with fp16, so cast to float32
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)

# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
# -> But we think it's OK to patch the accelerator even if DeepSpeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resume training state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
@@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor):
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1

# when caching the Text Encoder outputs, move the Text Encoders to the CPU
if args.cache_text_encoder_outputs:
# move Text Encoders to CPU so they stay available for sampling images; they don't run on CPU with fp16, so cast to float32
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)

# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
# -> But we think it's OK to patch the accelerator even if DeepSpeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resume training state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
Expand Down Expand Up @@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)

# For --sample_at_first
sdxl_train_util.sample_images(
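Note on the change: with --fused_backward_pass, sdxl_train.py steps the optimizer from per-parameter hooks registered via register_post_accumulate_grad_hook, and this commit moves the state restore (resume) ahead of that registration, presumably so the hooks attach after the checkpoint has finished mutating parameters and optimizer state. Below is a minimal sketch of that ordering, not the script's actual code: it assumes PyTorch >= 2.1, uses SGD and a toy model in place of Adafactor and the SDXL U-Net, and maybe_resume is a hypothetical stand-in for train_util.resume_from_local_or_hf_if_specified / accelerator.load_state.

import torch

# toy stand-ins; the real script uses Adafactor and an SDXL U-Net
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def maybe_resume(opt: torch.optim.Optimizer) -> None:
    # hypothetical placeholder for restoring saved training state
    # (accelerator.load_state in the real script); loading a state
    # dict mutates opt.state, which is why it must happen first
    opt.load_state_dict(opt.state_dict())

maybe_resume(optimizer)  # 1) restore state first, the order this commit enforces

# 2) only then register the per-parameter grad hooks (PyTorch >= 2.1)
for group in optimizer.param_groups:
    for param in group["params"]:
        if param.requires_grad:

            def optimizer_hook(parameter: torch.Tensor) -> None:
                # the real hook runs a fused Adafactor step for this one
                # parameter, then frees its gradient to save memory
                parameter.grad = None

            param.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # hooks fire as each parameter's grad finishes accumulating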
