From 2fe7628910b1ff6107670435ec20b6f2ea80b167 Mon Sep 17 00:00:00 2001 From: Delirious <36864043+deepdelirious@users.noreply.github.com> Date: Tue, 31 Dec 2024 17:12:04 -0500 Subject: [PATCH] In place copy EMA rather than moving it --- library/flux_train_utils.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index cb79513aa..e3fda2a31 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -106,9 +106,10 @@ def sample_images( controlnet ) if ema: - device = flux.device - flux.to("cpu") - ema.to(device) + model_params = [param.detach().cpu().clone() for param in ema.get_params_iter(flux)] + for model_param, ema_param in zip(ema.get_params_iter(flux), ema.get_params_iter(ema.ema_model)): + ema_param = ema_param.to(model_param.device) + model_param.data.copy_(ema_param) for prompt_dict in prompts: sample_image_inference( accelerator, @@ -125,10 +126,10 @@ def sample_images( controlnet, file_suffix = "_ema" ) - ema.to("cpu") - flux.to(device) - with torch.cuda.device(device): - torch.cuda.empty_cache() + for model_param, original_model_param in zip(ema.get_params_iter(flux), model_params): + original_model_param = original_model_param.to(model_param.device) + model_param.data.copy_(original_model_param.data) + model_params = None else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. 
@@ -154,9 +155,10 @@ def sample_images( controlnet ) if ema: - device = flux.device - flux.to("cpu") - ema.to(device) + model_params = [param.detach().cpu().clone() for param in ema.get_params_iter(flux)] + for model_param, ema_param in zip(ema.get_params_iter(flux), ema.get_params_iter(ema.ema_model)): + ema_param = ema_param.to(model_param.device) + model_param.data.copy_(ema_param) for prompt_dict in prompt_dict_lists[0]: sample_image_inference( accelerator, @@ -173,10 +175,10 @@ def sample_images( controlnet, file_suffix = "_ema" ) - ema.to("cpu") - flux.to(device) - with torch.cuda.device(device): - torch.cuda.empty_cache() + for model_param, original_model_param in zip(ema.get_params_iter(flux), model_params): + original_model_param = original_model_param.to(model_param.device) + model_param.data.copy_(original_model_param.data) + model_params = None torch.set_rng_state(rng_state) if cuda_rng_state is not None: