Skip to content

Commit

Permalink
Fix channel padding with new Comfy core API (#106)
Browse files Browse the repository at this point in the history
* Fix channel padding with new Comfy core API

* nit
  • Loading branch information
huchenlei authored Aug 27, 2024
1 parent 2cbfe39 commit 6e4aeb2
Showing 1 changed file with 12 additions and 69 deletions.
81 changes: 12 additions & 69 deletions layered_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from enum import Enum
import torch
import functools
import copy
from typing import Optional, List
from dataclasses import dataclass
Expand Down Expand Up @@ -31,73 +30,6 @@
load_layer_model_state_dict = load_torch_file


# ------------ Start patching ComfyUI ------------
def calculate_weight_adjust_channel(func):
"""Patches ComfyUI's LoRA weight application to accept multi-channel inputs."""

@functools.wraps(func)
def calculate_weight(
self: ModelPatcher, patches, weight: torch.Tensor, key: str
) -> torch.Tensor:
weight = func(self, patches, weight, key)

for p in patches:
alpha = p[0]
v = p[1]

# The recursion call should be handled in the main func call.
if isinstance(v, list):
continue

if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]

if patch_type == "diff":
w1 = v[0]
if all(
(
alpha != 0.0,
w1.shape != weight.shape,
w1.ndim == weight.ndim == 4,
)
):
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
print(
f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
)
new_diff = alpha * comfy.model_management.cast_to_device(
w1, weight.device, weight.dtype
)
new_weight = torch.zeros(size=new_shape).to(weight)
new_weight[
: weight.shape[0],
: weight.shape[1],
: weight.shape[2],
: weight.shape[3],
] = weight
new_weight[
: new_diff.shape[0],
: new_diff.shape[1],
: new_diff.shape[2],
: new_diff.shape[3],
] += new_diff
new_weight = new_weight.contiguous().clone()
weight = new_weight
return weight

return calculate_weight


ModelPatcher.calculate_weight = calculate_weight_adjust_channel(
ModelPatcher.calculate_weight
)

# ------------ End patching ComfyUI ------------


class LayeredDiffusionDecode:
"""
Decode alpha channel value from pixel value.
Expand Down Expand Up @@ -323,8 +255,19 @@ def apply_layered_diffusion(
model_dir=layer_model_root,
file_name=self.model_file_name,
)
def pad_diff_weight(v):
if len(v) == 1:
return ("diff", [v[0], {"pad_weight": True}])
elif len(v) == 2 and v[0] == "diff":
return ("diff", [v[1][0], {"pad_weight": True}])
else:
return v

layer_lora_state_dict = load_layer_model_state_dict(model_path)
layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
layer_lora_patch_dict = {
k: pad_diff_weight(v)
for k, v in to_lora_patch_dict(layer_lora_state_dict).items()
}
work_model = model.clone()
work_model.add_patches(layer_lora_patch_dict, weight)
return (work_model,)
Expand Down

0 comments on commit 6e4aeb2

Please sign in to comment.