diff --git a/flux_train.py b/flux_train.py index 645e04c0d..5daf148ee 100644 --- a/flux_train.py +++ b/flux_train.py @@ -760,7 +760,7 @@ def grad_hook(parameter: torch.Tensor): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, ema=ema + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, ema=(ema if not args.no_ema_sampling else None) ) if ema: @@ -779,7 +779,7 @@ def grad_hook(parameter: torch.Tensor): num_train_epochs, global_step, accelerator.unwrap_model(flux), - ema if not args.no_ema_sampling else None + ema ) optimizer_train_fn() @@ -823,7 +823,7 @@ def grad_hook(parameter: torch.Tensor): num_train_epochs, global_step, accelerator.unwrap_model(flux), - ema if not args.no_ema_sampling else None + ema ) flux_train_utils.sample_images(