Fix for EMA saving
deepdelirious committed Jan 5, 2025
1 parent d507597 commit 1e7b27a
Showing 1 changed file with 3 additions and 3 deletions.
flux_train.py: 3 additions & 3 deletions
@@ -760,7 +760,7 @@ def grad_hook(parameter: torch.Tensor):
 
             optimizer_eval_fn()
             flux_train_utils.sample_images(
-                accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, ema=ema
+                accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, ema=(ema if not args.no_ema_sampling else None)
             )
 
             if ema:
@@ -779,7 +779,7 @@ def grad_hook(parameter: torch.Tensor):
                     num_train_epochs,
                     global_step,
                     accelerator.unwrap_model(flux),
-                    ema if not args.no_ema_sampling else None
+                    ema
                 )
             optimizer_train_fn()
 
@@ -823,7 +823,7 @@ def grad_hook(parameter: torch.Tensor):
                 num_train_epochs,
                 global_step,
                 accelerator.unwrap_model(flux),
-                ema if not args.no_ema_sampling else None
+                ema
            )
 
             flux_train_utils.sample_images(
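What changed, in short: the commit moves the `args.no_ema_sampling` check from the checkpoint-save call sites to the sampling call site. Previously the conditional `ema if not args.no_ema_sampling else None` was applied when saving, so disabling EMA sampling also silently dropped EMA weights from saved checkpoints; now the conditional gates only `flux_train_utils.sample_images`, and the save paths always receive `ema`. Below is a minimal runnable sketch of the corrected control flow, using hypothetical stand-in helpers (the real functions in flux_train.py take many more arguments):

from types import SimpleNamespace

def sample_images(model, ema=None):
    # Stand-in for flux_train_utils.sample_images: would render sample
    # images from EMA weights when an EMA object is supplied.
    print("sampling with", "EMA weights" if ema is not None else "raw weights")

def save_checkpoint(model, ema=None):
    # Stand-in for the save helpers: would write the checkpoint,
    # including EMA state when an EMA object is supplied.
    print("saving", "with EMA state" if ema is not None else "without EMA state")

args = SimpleNamespace(no_ema_sampling=True)  # user opted out of EMA sampling
flux, ema = object(), object()                # placeholders for the model and its EMA

# Sampling honors --no_ema_sampling and may receive None ...
sample_images(flux, ema=(ema if not args.no_ema_sampling else None))

# ... but saving always receives the EMA object, so EMA weights still
# reach the checkpoint (the behavior this commit fixes).
save_checkpoint(flux, ema=ema)

With no_ema_sampling=True this prints that sampling uses raw weights while the checkpoint is still saved with EMA state, matching the intent of the commit message.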
