Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't save incomplete images #12338

Merged
merged 1 commit into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)


def save_images_if_interrupted():
return not (opts.dont_save_interrupted_images and (state.interrupted or state.skipped))


class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
Expand Down Expand Up @@ -821,14 +825,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(index=0, use_main_prompt=False):
return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

save_images_if_interrupt = save_images_if_interrupted()

for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i

x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration and save_images_if_interrupt:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

devices.torch_gc()
Expand All @@ -842,25 +848,23 @@ def infotext(index=0, use_main_prompt=False):
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image

if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction and save_images_if_interrupt:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)

image = apply_overlay(image, p.paste_to, i, p.overlay_images)

if opts.samples_save and not p.do_not_save_samples:
if opts.samples_save and not p.do_not_save_samples and save_images_if_interrupt:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

text = infotext(i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)

if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]) and save_images_if_interrupt:
image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

Expand Down Expand Up @@ -896,7 +900,6 @@ def infotext(index=0, use_main_prompt=False):
grid.info["parameters"] = text
output_images.insert(0, grid)
index_of_first_image = 1

if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

Expand Down Expand Up @@ -1091,7 +1094,7 @@ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_stre
def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix or not save_images_if_interrupted():
return

if not isinstance(image, Image.Image):
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def list_samplers():
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),

"dont_save_interrupted_images": OptionInfo(False, "Don't save incomplete images").info("Don't save images that has been interrupted in mid-generation, they will still show up in webui output."),
}))

options_templates.update(options_section(('saving-paths', "Paths for saving"), {
Expand Down