Skip to content

Commit

Permalink
Copy EMA weights in place rather than moving the model between devices
Browse files Browse the repository at this point in the history
  • Loading branch information
deepdelirious committed Jan 1, 2025
1 parent 080ae73 commit 65c8b62
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,15 @@ def sample_images(
controlnet
)
if ema:
device = flux.device
flux.to("cpu")
ema.to(device)
model_params = [param.detach().cpu().clone() for _, param in ema.get_params_iter(flux)]
for (_, model_param), (_, ema_param) in zip(ema.get_params_iter(flux), ema.get_params_iter(ema.ema_model)):
ema_param = ema_param.to(model_param.device)
model_param.copy_(ema_param)
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
ema,
flux,
text_encoders,
ae,
save_dir,
Expand All @@ -125,10 +126,10 @@ def sample_images(
controlnet,
file_suffix = "_ema"
)
ema.to("cpu")
flux.to(device)
with torch.cuda.device(device):
torch.cuda.empty_cache()
for (_, model_param), original_model_param in zip(ema.get_params_iter(flux), model_params):
original_model_param = original_model_param.to(model_param.device)
model_param.data.copy_(original_model_param.data)
model_params = None
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
Expand All @@ -154,14 +155,15 @@ def sample_images(
controlnet
)
if ema:
device = flux.device
flux.to("cpu")
ema.to(device)
model_params = [param.detach().cpu().clone() for _, param in ema.get_params_iter(flux)]
for (_, model_param), (_, ema_param) in zip(ema.get_params_iter(flux), ema.get_params_iter(ema.ema_model)):
ema_param = ema_param.to(model_param.device)
model_param.copy_(ema_param)
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
ema,
flux,
text_encoders,
ae,
save_dir,
Expand All @@ -173,10 +175,10 @@ def sample_images(
controlnet,
file_suffix = "_ema"
)
ema.to("cpu")
flux.to(device)
with torch.cuda.device(device):
torch.cuda.empty_cache()
for (_, model_param), original_model_param in zip(ema.get_params_iter(flux), model_params):
original_model_param = original_model_param.to(model_param.device)
model_param.data.copy_(original_model_param.data)
model_params = None

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
Expand Down

0 comments on commit 65c8b62

Please sign in to comment.