Skip to content

Commit

Permalink
Add some more transformer hooks and move tomesd to comfy_extras.
Browse files Browse the repository at this point in the history
Tomesd now uses q instead of x to decide which tokens to merge because
it seems to give better results.
  • Loading branch information
comfyanonymous committed Jun 24, 2023
1 parent fa28d73 commit 0567694
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 28 deletions.
59 changes: 52 additions & 7 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from comfy import model_management
import comfy.ops

from . import tomesd

if model_management.xformers_enabled():
import xformers
import xformers.ops
Expand Down Expand Up @@ -519,23 +517,39 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head

def forward(self, x, context=None, transformer_options={}):
    """Run ``_forward`` through the ``checkpoint`` helper.

    ``x``, ``context`` and ``transformer_options`` are bundled into a tuple
    and handed to ``checkpoint`` together with this module's parameters and
    the ``self.checkpoint`` flag stored by ``__init__``.  Whatever
    ``checkpoint`` returns is returned unchanged.

    NOTE(review): ``checkpoint`` is defined outside this block; presumably it
    enables gradient checkpointing when the flag is truthy -- confirm against
    the helper.  The ``{}`` default is only forwarded here, never modified.
    """
    forward_args = (x, context, transformer_options)
    return checkpoint(self._forward, forward_args, self.parameters(), self.checkpoint)

def _forward(self, x, context=None, transformer_options={}):
extra_options = {}
block = None
block_index = 0
if "current_index" in transformer_options:
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
extra_options["block_index"] = transformer_options["block_index"]
block_index = transformer_options["block_index"]
extra_options["block_index"] = block_index
if "original_shape" in transformer_options:
extra_options["original_shape"] = transformer_options["original_shape"]
if "block" in transformer_options:
block = transformer_options["block"]
extra_options["block"] = block
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}

extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head

if "patches_replace" in transformer_options:
transformer_patches_replace = transformer_options["patches_replace"]
else:
transformer_patches_replace = {}

n = self.norm1(x)
if self.disable_self_attn:
context_attn1 = context
Expand All @@ -551,12 +565,29 @@ def _forward(self, x, context=None, transformer_options={}):
for p in patch:
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
transformer_block = (block[0], block[1], block_index)
attn1_replace_patch = transformer_patches_replace.get("attn1", {})
block_attn1 = transformer_block
if block_attn1 not in attn1_replace_patch:
block_attn1 = block

if block_attn1 in attn1_replace_patch:
if context_attn1 is None:
context_attn1 = n
value_attn1 = n
n = self.attn1.to_q(n)
context_attn1 = self.attn1.to_k(context_attn1)
value_attn1 = self.attn1.to_v(value_attn1)
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
n = self.attn1(n, context=context_attn1, value=value_attn1)

if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
n = p(n, extra_options)

x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
Expand All @@ -573,7 +604,21 @@ def _forward(self, x, context=None, transformer_options={}):
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

n = self.attn2(n, context=context_attn2, value=value_attn2)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block

if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)

if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
Expand Down
5 changes: 4 additions & 1 deletion comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,17 +830,20 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo

h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop()

for module in self.output_blocks:
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
if control is not None and 'output' in control and len(control['output']) > 0:
ctrl = control['output'].pop()
Expand Down
27 changes: 24 additions & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,6 @@ def clone(self):
n.model_keys = self.model_keys
return n

def set_model_tomesd(self, ratio):
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}

def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
Expand All @@ -330,12 +327,29 @@ def set_model_patch(self, patch, name):
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]

def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
    """Register ``patch`` as the replacement for attention ``name`` in one block.

    The patch is stored under
    ``model_options["transformer_options"]["patches_replace"][name]``; the
    intermediate dicts are created on first use.  Registering again under the
    same key overwrites the earlier patch.

    Args:
        patch: The replacement object stored for the block.
        name: Replacement group, e.g. ``"attn1"`` or ``"attn2"``.
        block_name: First element of the block key (e.g. ``"input"``).
        number: Second element of the block key.
        transformer_index: Optional third key element.  When omitted the key
            is ``(block_name, number)`` exactly as before.  When given, the key
            is ``(block_name, number, transformer_index)``, which is the
            3-tuple form the transformer block tries first (built from its
            ``block`` pair and ``block_index``) before it falls back to the
            2-tuple.
    """
    to = self.model_options["transformer_options"]
    replacements = to.setdefault("patches_replace", {}).setdefault(name, {})
    if transformer_index is None:
        key = (block_name, number)
    else:
        key = (block_name, number, transformer_index)
    replacements[key] = patch

def set_model_attn1_patch(self, patch):
    """Append ``patch`` to the ``"attn1_patch"`` list via ``set_model_patch``."""
    self.set_model_patch(patch, "attn1_patch")

def set_model_attn2_patch(self, patch):
    """Append ``patch`` to the ``"attn2_patch"`` list via ``set_model_patch``."""
    self.set_model_patch(patch, "attn2_patch")

def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
    """Register ``patch`` as the ``"attn1"`` replacement for the given block.

    ``transformer_index`` is optional and is passed straight through to
    ``set_model_patch_replace``.
    """
    self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
    """Register ``patch`` as the ``"attn2"`` replacement for the given block.

    ``transformer_index`` is optional and is passed straight through to
    ``set_model_patch_replace``.
    """
    self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

def set_model_attn1_output_patch(self, patch):
    """Append ``patch`` to the ``"attn1_output_patch"`` list via ``set_model_patch``."""
    self.set_model_patch(patch, "attn1_output_patch")

def set_model_attn2_output_patch(self, patch):
    """Append ``patch`` to the ``"attn2_output_patch"`` list via ``set_model_patch``."""
    self.set_model_patch(patch, "attn2_output_patch")

Expand All @@ -348,6 +362,13 @@ def model_patches_to(self, device):
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)

def model_dtype(self):
return self.model.get_dtype()
Expand Down
33 changes: 33 additions & 0 deletions comfy/ldm/modules/tomesd.py → comfy_extras/nodes_tomesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,36 @@ def get_functions(x, ratio, original_shape):

nothing = lambda y: y
return nothing, nothing



class TomePatchModel:
    """Node that returns a clone of a model with two attn1 hooks attached.

    The hooks use the ``m``/``u`` function pair returned by ``get_functions``:
    the attn1 patch applies the first function to ``q``, and the attn1 output
    patch applies the second function to the attention output.
    """

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a ``MODEL`` and a ``FLOAT`` ratio."""
        ratio_spec = {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}
        return {"required": {"model": ("MODEL",), "ratio": ("FLOAT", ratio_spec)}}

    def patch(self, model, ratio):
        """Clone ``model``, register both hooks on the clone, return ``(clone,)``."""
        # Holds the second function from get_functions; it is written by
        # tomesd_m and read by tomesd_u, so it starts out empty.
        self.u = None

        def tomesd_m(q, k, v, extra_options):
            # NOTE: the reference code passes x (the transformer block input)
            # to get_functions rather than q; basic testing suggested that q
            # gives better results, so q is used here.
            merge, self.u = get_functions(q, ratio, extra_options["original_shape"])
            return merge(q), k, v

        def tomesd_u(n, extra_options):
            return self.u(n)

        patched = model.clone()
        patched.set_model_attn1_patch(tomesd_m)
        patched.set_model_attn1_output_patch(tomesd_u)
        return (patched,)


# Node name -> node class mapping exposed by this module.
NODE_CLASS_MAPPINGS = {"TomePatchModel": TomePatchModel}
18 changes: 1 addition & 17 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,22 +437,6 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
return (model_lora, clip_lora)

class TomePatchModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "_for_testing"

def patch(self, model, ratio):
m = model.clone()
m.set_model_tomesd(ratio)
return (m, )

class VAELoader:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -1341,7 +1325,6 @@ def expand_image(self, image, left, top, right, bottom, feathering):
"CLIPVisionLoader": CLIPVisionLoader,
"VAEDecodeTiled": VAEDecodeTiled,
"VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
Expand Down Expand Up @@ -1466,4 +1449,5 @@ def init_custom_nodes():
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
load_custom_nodes()

0 comments on commit 0567694

Please sign in to comment.