From d875cda565171407e1e2dc087fb5c5140359c6ec Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sat, 8 Jun 2024 22:11:11 -0400 Subject: [PATCH 1/2] Fix sdxl inpaint --- modules/processing.py | 4 ++-- modules/sd_models.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 0ff6a45c0c5..dc538272116 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,7 +115,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - if getattr(sd_model.model, "is_sdxl_inpaint", False): + if sd_model.is_sdxl_inpaint: # The "masked-image" in this case will just be all 0.5 since the entire image is masked. image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 image_conditioning = images_tensor_to_samples(image_conditioning, @@ -389,7 +389,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False): + if self.sampler.model_wrap.inner_model.is_sdxl_inpaint: return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. diff --git a/modules/sd_models.py b/modules/sd_models.py index 61bd15d8f05..93ff6c5fe9e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -386,13 +386,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() - # Set is_sdxl_inpaint flag. - diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None) - model.is_sdxl_inpaint = ( - model.is_sdxl and - diffusion_model_input is not None and - diffusion_model_input.shape[1] == 9 - ) if model.is_sdxl: sd_models_xl.extend_sdxl(model) @@ -408,6 +401,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer del state_dict + # Set is_sdxl_inpaint flag. + # Perform this check after model initialization to make sure state_dict + # structure is already known. + diffusion_model_input = model.model.state_dict().get( + 'diffusion_model.input_blocks.0.0.weight' + ) + model.is_sdxl_inpaint = ( + model.is_sdxl and + diffusion_model_input is not None and + diffusion_model_input.shape[1] == 9 + ) + if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) timer.record("apply channels_last") From f89b5dbbd282091fd6b3318f3ef20cf23cf9ea3a Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sat, 8 Jun 2024 22:15:37 -0400 Subject: [PATCH 2/2] nit --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 93ff6c5fe9e..af35187cdb0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -402,8 +402,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer del state_dict # Set is_sdxl_inpaint flag. - # Perform this check after model initialization to make sure state_dict - # structure is already known. + # Checks Unet structure to detect inpaint model. The inpaint model's + # checkpoint state_dict does not contain the key + # 'diffusion_model.input_blocks.0.0.weight'. diffusion_model_input = model.model.state_dict().get( 'diffusion_model.input_blocks.0.0.weight' )