Skip to content

Commit

Permalink
ConditioningAverage now also averages the pooled output.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 4, 2023
1 parent d94ddd8 commit 3a09fac
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,23 @@ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_streng
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")

cond_from = conditioning_from[0][0]
pooled_output_from = conditioning_from[0][1].get("pooled_output", None)

for i in range(len(conditioning_to)):
t1 = conditioning_to[i][0]
pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
t0 = cond_from[:,:t1.shape[1]]
if t0.shape[1] < t1.shape[1]:
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)

tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
n = [tw, conditioning_to[i][1].copy()]
t_to = conditioning_to[i][1].copy()
if pooled_output_from is not None and pooled_output_to is not None:
t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
elif pooled_output_from is not None:
t_to["pooled_output"] = pooled_output_from

n = [tw, t_to]
out.append(n)
return (out, )

Expand Down

0 comments on commit 3a09fac

Please sign in to comment.