From 6e4aeb2da78ba48c519367608a61bf47ea6249b4 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 27 Aug 2024 15:32:55 -0400 Subject: [PATCH] Fix channel padding with new Comfy core API (#106) * Fix channel padding with new Comfy core API * nit --- layered_diffusion.py | 81 +++++++------------------------------------- 1 file changed, 12 insertions(+), 69 deletions(-) diff --git a/layered_diffusion.py b/layered_diffusion.py index f18a158..5ffc79b 100644 --- a/layered_diffusion.py +++ b/layered_diffusion.py @@ -1,7 +1,6 @@ import os from enum import Enum import torch -import functools import copy from typing import Optional, List from dataclasses import dataclass @@ -31,73 +30,6 @@ load_layer_model_state_dict = load_torch_file -# ------------ Start patching ComfyUI ------------ -def calculate_weight_adjust_channel(func): - """Patches ComfyUI's LoRA weight application to accept multi-channel inputs.""" - - @functools.wraps(func) - def calculate_weight( - self: ModelPatcher, patches, weight: torch.Tensor, key: str - ) -> torch.Tensor: - weight = func(self, patches, weight, key) - - for p in patches: - alpha = p[0] - v = p[1] - - # The recursion call should be handled in the main func call. - if isinstance(v, list): - continue - - if len(v) == 1: - patch_type = "diff" - elif len(v) == 2: - patch_type = v[0] - v = v[1] - - if patch_type == "diff": - w1 = v[0] - if all( - ( - alpha != 0.0, - w1.shape != weight.shape, - w1.ndim == weight.ndim == 4, - ) - ): - new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)] - print( - f"Merged with {key} channel changed from {weight.shape} to {new_shape}" - ) - new_diff = alpha * comfy.model_management.cast_to_device( - w1, weight.device, weight.dtype - ) - new_weight = torch.zeros(size=new_shape).to(weight) - new_weight[ - : weight.shape[0], - : weight.shape[1], - : weight.shape[2], - : weight.shape[3], - ] = weight - new_weight[ - : new_diff.shape[0], - : new_diff.shape[1], - : new_diff.shape[2], - : new_diff.shape[3], - ] += new_diff - new_weight = new_weight.contiguous().clone() - weight = new_weight - return weight - - return calculate_weight - - -ModelPatcher.calculate_weight = calculate_weight_adjust_channel( - ModelPatcher.calculate_weight -) - -# ------------ End patching ComfyUI ------------ - - class LayeredDiffusionDecode: """ Decode alpha channel value from pixel value. @@ -323,8 +255,19 @@ def apply_layered_diffusion( model_dir=layer_model_root, file_name=self.model_file_name, ) + def pad_diff_weight(v): + if len(v) == 1: + return ("diff", [v[0], {"pad_weight": True}]) + elif len(v) == 2 and v[0] == "diff": + return ("diff", [v[1][0], {"pad_weight": True}]) + else: + return v + layer_lora_state_dict = load_layer_model_state_dict(model_path) - layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict) + layer_lora_patch_dict = { + k: pad_diff_weight(v) + for k, v in to_lora_patch_dict(layer_lora_state_dict).items() + } work_model = model.clone() work_model.add_patches(layer_lora_patch_dict, weight) return (work_model,)