diff --git a/nodes.py b/nodes.py index d9c26946230..b22cdfc9ba5 100644 --- a/nodes.py +++ b/nodes.py @@ -82,15 +82,23 @@ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_streng print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") cond_from = conditioning_from[0][0] + pooled_output_from = conditioning_from[0][1].get("pooled_output", None) for i in range(len(conditioning_to)): t1 = conditioning_to[i][0] + pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from) t0 = cond_from[:,:t1.shape[1]] if t0.shape[1] < t1.shape[1]: t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1) tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength)) - n = [tw, conditioning_to[i][1].copy()] + t_to = conditioning_to[i][1].copy() + if pooled_output_from is not None and pooled_output_to is not None: + t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength)) + elif pooled_output_from is not None: + t_to["pooled_output"] = pooled_output_from + + n = [tw, t_to] out.append(n) return (out, )