Skip to content

Commit

Permalink
alternate implementation for unet forward replacement that does not d…
Browse files Browse the repository at this point in the history
…epend on hijack being applied
  • Loading branch information
AUTOMATIC1111 committed Dec 2, 2023
1 parent af5f073 commit ac02216
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
7 changes: 5 additions & 2 deletions modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None

ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)

sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)


def list_optimizers():
Expand Down
14 changes: 8 additions & 6 deletions modules/sd_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
unet_options = []
current_unet_option = None
current_unet = None
original_forward = None

original_forward = None # not used, only left temporarily for compatibility

def list_unets():
new_unets = script_callbacks.list_unets_callback()
Expand Down Expand Up @@ -84,9 +83,12 @@ def deactivate(self):
pass


def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
def create_unet_forward(original_forward):
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)

return original_forward(self, x, timesteps, context, *args, **kwargs)

return original_forward(self, x, timesteps, context, *args, **kwargs)
return UNetModel_forward

0 comments on commit ac02216

Please sign in to comment.