From e6402a8295664153a7cc1ea3f187ef4ea6c05ed1 Mon Sep 17 00:00:00 2001 From: Yur Name Date: Sat, 14 Oct 2023 16:28:28 -0700 Subject: [PATCH] Reverts changes to restore old DIMM --- modules/processing.py | 1 + modules/sd_samplers.py | 3 +- modules/sd_samplers_compvis.py | 232 +++++++++++++++++++++++++++++++ modules/sd_samplers_timesteps.py | 12 +- 4 files changed, 239 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..699f6361f43 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -217,6 +217,7 @@ def __post_init__(self): self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf') self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise + self.ddim_discretize = self.ddim_discretize or opts.ddim_discretize self.extra_generation_params = self.extra_generation_params or {} self.override_settings = self.override_settings or {} diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 45faae62821..ec0699d09e9 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,10 +1,11 @@ -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, + *sd_samplers_compvis.samplers_data_compvis, *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index e69de29bb2d..c3752215ac0 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -0,0 +1,232 @@ +import math +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms + +import numpy as np +import torch + +from modules.shared import state +from modules import sd_samplers_common, prompt_parser, shared +import modules.models.diffusion.uni_pc + + +samplers_data_compvis = [ + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), +] + + +class VanillaStableDiffusionSampler: + def __init__(self, constructor, sd_model): + self.p = None + self.sampler = constructor(shared.sd_model) + self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') + self.is_plms = hasattr(self.sampler, 'p_sample_plms') + self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) + self.orig_p_sample_ddim = None + if self.is_plms: + self.orig_p_sample_ddim = self.sampler.p_sample_plms + elif self.is_ddim: + self.orig_p_sample_ddim = self.sampler.p_sample_ddim + self.mask = None + self.nmask = None + self.init_latent = None + self.sampler_noises = None + self.steps = None + self.step = 0 + self.stop_at = None + self.eta = None + self.config = None + self.last_latent = None + + self.conditioning_key = sd_model.model.conditioning_key + + def number_of_needed_noises(self, p): + return 0 + + def launch_sampling(self, steps, func): + self.steps = steps + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except sd_samplers_common.InterruptedException: + return self.last_latent + + def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) + + res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs) + + x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) + + return res + + def update_inner_model(self): + self.sampler.model = shared.sd_model + + def before_sample(self, x, ts, cond, unconditional_conditioning): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + if self.stop_at is not None and self.step > self.stop_at: + raise sd_samplers_common.InterruptedException + + # Have to unwrap the inpainting conditioning here to perform pre-processing + image_conditioning = None + uc_image_conditioning = None + if isinstance(cond, dict): + if self.conditioning_key == "crossattn-adm": + image_conditioning = cond["c_adm"] + uc_image_conditioning = unconditional_conditioning["c_adm"] + else: + image_conditioning = cond["c_concat"][0] + cond = cond["c_crossattn"][0] + unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + + assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers' + cond = tensor + + # for DDIM, shapes must match, we can't just process cond and uncond independently; + # filling unconditional_conditioning with repeats of the last vector to match length is + # not 100% correct but should work well enough + if unconditional_conditioning.shape[1] < cond.shape[1]: + last_vector = unconditional_conditioning[:, -1:] + last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) + unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) + elif unconditional_conditioning.shape[1] > cond.shape[1]: + unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] + + if self.mask is not None: + img_orig = self.sampler.model.q_sample(self.init_latent, ts) + x = img_orig * self.mask + self.nmask * x + + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. + if image_conditioning is not None: + if self.conditioning_key == "crossattn-adm": + cond = {"c_adm": image_conditioning, "c_crossattn": [cond]} + unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]} + else: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + return x, ts, cond, unconditional_conditioning + + def update_step(self, last_latent): + if self.mask is not None: + self.last_latent = self.init_latent * self.mask + self.nmask * last_latent + else: + self.last_latent = last_latent + + sd_samplers_common.store_latent(self.last_latent) + + self.step += 1 + state.sampling_step = self.step + shared.total_tqdm.update() + + def after_sample(self, x, ts, cond, uncond, res): + if not self.is_unipc: + self.update_step(res[1]) + + return x, ts, cond, uncond, res + + def unipc_after_update(self, x, model_x): + self.update_step(x) + + def initialize(self, p): + self.p = p + + if self.is_ddim: + self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim + else: + self.eta = 0.0 + + if self.eta != 0.0: + p.extra_generation_params["Eta DDIM"] = self.eta + + if self.is_unipc: + keys = [ + ('UniPC variant', 'uni_pc_variant'), + ('UniPC skip type', 'uni_pc_skip_type'), + ('UniPC order', 'uni_pc_order'), + ('UniPC lower order final', 'uni_pc_lower_order_final'), + ] + + for name, key in keys: + v = getattr(shared.opts, key) + if v != shared.opts.get_default(key): + p.extra_generation_params[name] = v + + for fieldname in ['p_sample_ddim', 'p_sample_plms']: + if hasattr(self.sampler, fieldname): + setattr(self.sampler, fieldname, self.p_sample_ddim_hook) + if self.is_unipc: + self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) + + self.mask = p.mask if hasattr(p, 'mask') else None + self.nmask = p.nmask if hasattr(p, 'nmask') else None + + + def adjust_steps_if_invalid(self, p, num_steps): + if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): + if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: + num_steps = shared.opts.uni_pc_order + valid_step = 999 / (1000 // num_steps) + if valid_step == math.floor(valid_step): + return int(valid_step) + 1 + + return num_steps + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + steps = self.adjust_steps_if_invalid(p, steps) + self.initialize(p) + + self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) + x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) + + self.init_latent = x + self.last_latent = x + self.step = 0 + + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + if self.conditioning_key == "crossattn-adm": + conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]} + else: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + self.initialize(p) + + self.init_latent = None + self.last_latent = x + self.step = 0 + + steps = self.adjust_steps_if_invalid(p, steps or p.steps) + + # Wrap the conditioning models with additional image conditioning for inpainting model + # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape + if image_conditioning is not None: + if self.conditioning_key == "crossattn-adm": + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)} + else: + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} + + samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) + + return samples_ddim \ No newline at end of file diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c2b..f0fb93d972a 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -9,9 +9,9 @@ import modules.shared as shared samplers_timesteps = [ - ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), - ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), - ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), + ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), + ('k_PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), + ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), ] @@ -160,8 +160,4 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True - return samples - - -sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__] -VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions + return samples \ No newline at end of file