Skip to content

Commit

Permalink
Copy EMA weights into the model in place rather than moving the model between devices
Browse files Browse the repository at this point in the history
  • Loading branch information
deepdelirious committed Dec 31, 2024
1 parent 080ae73 commit 2fe7628
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ def sample_images(
controlnet
)
if ema:
device = flux.device
flux.to("cpu")
ema.to(device)
model_params = [param.detach().cpu().clone() for param in ema.get_params_iter(flux)]
for model_param, ema_param in zip(ema.get_params_iter(flux), zip(ema.get_params_iter(ema.ema_model))):
ema_param = ema_param.to(model_param.device)
model_param.copy_(ema_param)
for prompt_dict in prompts:
sample_image_inference(
accelerator,
Expand All @@ -125,10 +126,10 @@ def sample_images(
controlnet,
file_suffix = "_ema"
)
ema.to("cpu")
flux.to(device)
with torch.cuda.device(device):
torch.cuda.empty_cache()
for model_param, original_model_param in zip(ema.get_params_iter(flux), model_params):
original_model_param = original_model_param.to(model_param.device)
model_param.data.copy_(original_model_param.data)
model_params = None
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
Expand All @@ -154,9 +155,10 @@ def sample_images(
controlnet
)
if ema:
device = flux.device
flux.to("cpu")
ema.to(device)
model_params = [param.detach().cpu().clone() for param in ema.get_params_iter(flux)]
for model_param, ema_param in zip(ema.get_params_iter(flux), zip(ema.get_params_iter(ema.ema_model))):
ema_param = ema_param.to(model_param.device)
model_param.copy_(ema_param)
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
Expand All @@ -173,10 +175,10 @@ def sample_images(
controlnet,
file_suffix = "_ema"
)
ema.to("cpu")
flux.to(device)
with torch.cuda.device(device):
torch.cuda.empty_cache()
for model_param, original_model_param in zip(ema.get_params_iter(flux), model_params):
original_model_param = original_model_param.to(model_param.device)
model_param.data.copy_(original_model_param.data)
model_params = None

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
Expand Down

0 comments on commit 2fe7628

Please sign in to comment.