Skip to content

Commit

Permalink
Add a way to set patches that modify the attn2 output.
Browse files Browse the repository at this point in the history
Change the transformer patches function format to be more future-proof.
  • Loading branch information
comfyanonymous committed Jun 19, 2023
1 parent cd930d4 commit 8883cb0
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion comfy/gligen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ def func_lowvram(key, x):
return r
return func_lowvram
else:
def func(key, x):
def func(x, extra_options):
    # New patch signature: the transformer-layer index now arrives via
    # extra_options["transformer_index"] instead of a positional argument.
    key = extra_options["transformer_index"]
    # Pick the gligen module matching this transformer layer.
    module = self.module_list[key]
    # `objs` is captured from the enclosing scope (not visible here) —
    # presumably the prepared gligen conditioning; confirm in caller.
    return module(x, objs)
return func
Expand Down
18 changes: 13 additions & 5 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,11 @@ def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

def _forward(self, x, context=None, transformer_options={}):
current_index = None
extra_options = {}
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
extra_options["block_index"] = transformer_options["block_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
Expand All @@ -545,7 +547,7 @@ def _forward(self, x, context=None, transformer_options={}):
context_attn1 = n
value_attn1 = context_attn1
for p in patch:
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
Expand All @@ -557,7 +559,7 @@ def _forward(self, x, context=None, transformer_options={}):
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)
x = p(x, extra_options)

n = self.norm2(x)

Expand All @@ -567,10 +569,15 @@ def _forward(self, x, context=None, transformer_options={}):
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

n = self.attn2(n, context=context_attn2, value=value_attn2)

if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
for p in patch:
n = p(n, extra_options)

x += n
x = self.ff(self.norm3(x)) + x
return x
Expand Down Expand Up @@ -631,6 +638,7 @@ def forward(self, x, context=None, transformer_options={}):
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)
Expand Down
3 changes: 3 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def set_model_attn1_patch(self, patch):
def set_model_attn2_patch(self, patch):
    """Register *patch* to modify the attn2 (cross-attention) inputs.

    Stored under the "attn2_patch" key via set_model_patch; per the
    attention.py change in this commit, each such patch is called as
    p(n, context_attn2, value_attn2, extra_options) before attn2 runs.
    """
    self.set_model_patch(patch, "attn2_patch")

def set_model_attn2_output_patch(self, patch):
    """Register *patch* to post-process the attn2 output (new in this commit).

    Stored under the "attn2_output_patch" key via set_model_patch; per the
    attention.py change, each such patch is called as p(n, extra_options)
    on the attn2 result before it is added back to the residual stream.
    """
    self.set_model_patch(patch, "attn2_output_patch")

def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_hypernetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class hypernetwork_patch:
def __init__(self, hypernet, strength):
    """Wrap a loaded hypernetwork for use as an attention patch.

    hypernet: presumably a mapping from attention head dimension to
              hypernetwork modules — __call__ looks up k.shape[-1] in it;
              confirm against the loader.
    strength: blend strength applied when the patch runs (usage is in
              the part of __call__ not shown here).
    """
    self.hypernet = hypernet
    self.strength = strength
def __call__(self, current_index, q, k, v):
def __call__(self, q, k, v, extra_options):
dim = k.shape[-1]
if dim in self.hypernet:
hn = self.hypernet[dim]
Expand Down

0 comments on commit 8883cb0

Please sign in to comment.