Commit: EMA previews
deepdelirious committed Nov 25, 2024
1 parent 9c32653 commit 34bdae5
Showing 2 changed files with 53 additions and 5 deletions.
4 changes: 2 additions & 2 deletions flux_train.py
@@ -717,7 +717,7 @@ def grad_hook(parameter: torch.Tensor):
 
             optimizer_eval_fn()
             flux_train_utils.sample_images(
-                accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
+                accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, ema=ema
             )
 
             if ema:
@@ -794,7 +794,7 @@ def grad_hook(parameter: torch.Tensor):
     del accelerator  # this is deleted here because memory is needed after this point
 
     if is_main_process:
-        flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
+        flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ema)
         logger.info("model saved.")
 
 
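The training script now passes the EMA state into preview sampling (`ema=ema`) and into the final save (`flux, ema`). The EMA class itself is not part of this commit (it comes from an earlier change in this fork), but the way it is used in the hunks below pins down its interface: it can be moved between devices with .to(), it stands in for the Flux model during sampling, and it exposes the averaged weights as .ema_model. A minimal sketch of a compatible helper, where everything beyond those three points is an assumption:

# Sketch only -- not this fork's actual implementation. It just satisfies the
# interface visible in this commit: .to(device), .ema_model, and being callable
# in place of the Flux model during preview sampling.
import copy
import torch

class EMA:
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):  # decay value is an assumption
        self.decay = decay
        self.ema_model = copy.deepcopy(model).eval()  # averaged copy, kept out of the graph
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # ema_param <- decay * ema_param + (1 - decay) * train_param
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p.detach().to(ema_p.device), alpha=1.0 - self.decay)

    def to(self, device):
        self.ema_model.to(device)
        return self

    def __call__(self, *args, **kwargs):
        # Delegate to the averaged model; note this wrapper has no
        # prepare_block_swap_before_forward(), which is why denoise() below
        # gains a hasattr() guard.
        return self.ema_model(*args, **kwargs)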
54 changes: 51 additions & 3 deletions library/flux_train_utils.py
@@ -41,6 +41,7 @@ def sample_images(
     text_encoders,
     sample_prompts_te_outputs,
     prompt_replacement=None,
+    ema=None,
 ):
     if steps == 0:
         if not args.sample_at_first:
@@ -100,6 +101,26 @@
                 sample_prompts_te_outputs,
                 prompt_replacement,
             )
+        if ema:
+            device = flux.device
+            flux.to("cpu")
+            ema.to(device)
+            for prompt_dict in prompts:
+                sample_image_inference(
+                    accelerator,
+                    args,
+                    ema,
+                    text_encoders,
+                    ae,
+                    save_dir,
+                    prompt_dict,
+                    epoch,
+                    steps,
+                    sample_prompts_te_outputs,
+                    prompt_replacement,
+                )
+            ema.to("cpu")
+            flux.to(device)
     else:
         # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
         # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
@@ -123,6 +144,26 @@
                         sample_prompts_te_outputs,
                         prompt_replacement,
                     )
+        if ema:
+            device = flux.device
+            flux.to("cpu")
+            ema.to(device)
+            for prompt_dict in prompt_dict_lists[0]:
+                sample_image_inference(
+                    accelerator,
+                    args,
+                    ema,
+                    text_encoders,
+                    ae,
+                    save_dir,
+                    prompt_dict,
+                    epoch,
+                    steps,
+                    sample_prompts_te_outputs,
+                    prompt_replacement,
+                )
+            ema.to("cpu")
+            flux.to(device)
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
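Both branches above use the same device-swap pattern for the extra EMA preview pass: park the trained Flux model on the CPU, move the EMA copy onto the accelerator, render the same prompts again, then restore the original placement so only one full copy of the weights sits in VRAM at a time. Purely as an illustration of that pattern (this helper does not exist in the repository), the two duplicated blocks could be written as a context manager:

# Illustrative restatement of the swap pattern above; not part of the codebase.
from contextlib import contextmanager

@contextmanager
def ema_on_device(ema, flux, device):
    flux.to("cpu")   # free the accelerator memory held by the training copy
    ema.to(device)   # bring the averaged weights in for sampling
    try:
        yield ema
    finally:
        ema.to("cpu")    # park the averaged weights again
        flux.to(device)  # restore the training copy

With such a helper, each block would reduce to wrapping the EMA sampling loop in `with ema_on_device(ema, flux, device):`.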
@@ -307,7 +348,8 @@ def denoise(
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
-        model.prepare_block_swap_before_forward()
+        if hasattr(model, "prepare_block_swap_before_forward"):
+            model.prepare_block_swap_before_forward()
         pred = model(
             img=img,
             img_ids=img_ids,
@@ -320,8 +362,10 @@
         )
 
         img = img + (t_prev - t_curr) * pred
 
-    model.prepare_block_swap_before_forward()
+    if hasattr(model, "prepare_block_swap_before_forward"):
+        model.prepare_block_swap_before_forward()
+
     return img
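The hasattr() guard lets denoise() accept either the Flux model, which implements block swapping, or a substitute such as the EMA copy that does not. An equivalent way to express the same guard, shown only for illustration:

# Fetch the block-swap hook if the model provides one, otherwise use a no-op,
# so objects without prepare_block_swap_before_forward() pass through unchanged.
prepare = getattr(model, "prepare_block_swap_before_forward", lambda: None)
prepare()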


@@ -479,11 +523,15 @@ def update_sd(prefix, sd):
 
 
 def save_flux_model_on_train_end(
-    args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
+    args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux, ema: EMA = None
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
         save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
+        if ema:
+            filename, extension = os.path.splitext(ckpt_file)
+            ema_file = filename + "_ema" + extension
+            save_models(ema_file, ema.ema_model, sai_metadata, save_dtype, args.mem_eff_save)
 
     train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)

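When an EMA object is passed, the end-of-training save writes a second checkpoint containing the averaged weights, named by inserting "_ema" before the extension of the regular checkpoint. For example (the path here is hypothetical):

import os

ckpt_file = "output/flux-finetune.safetensors"     # hypothetical checkpoint path
filename, extension = os.path.splitext(ckpt_file)
ema_file = filename + "_ema" + extension
print(ema_file)  # output/flux-finetune_ema.safetensors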
