
Commit

Add prompt to prompt
cccntu committed Oct 16, 2022
1 parent a593a66 commit b43bc04
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/diffusers/models/attention.py
@@ -250,8 +250,8 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
     def forward(self, x, context=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
-        k = self.to_k(context)
-        v = self.to_v(context)
+        k = self.to_k(context if isinstance(context, torch.Tensor) else context[0])
+        v = self.to_v(context if isinstance(context, torch.Tensor) else context[1])

         b, _, _ = q.shape
         q, k, v = map(
@@ -331,8 +331,8 @@ def forward(self, hidden_states, context=None, mask=None):

         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
-        key = self.to_k(context)
-        value = self.to_v(context)
+        key = self.to_k(context if isinstance(context, torch.Tensor) else context[0])
+        value = self.to_v(context if isinstance(context, torch.Tensor) else context[1])

         dim = query.shape[-1]

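With this change, `context` may be either a single tensor (the usual case, where keys and values share one context) or an indexable pair whose element 0 feeds the key projection and element 1 feeds the value projection, which is the hook this commit adds for prompt-to-prompt editing. The sketch below mirrors that pattern in a self-contained module and shows both calling conventions. The class and variable names (PairedContextAttention, source_embeds, edit_embeds) are illustrative assumptions, not part of the diffusers API.

import torch
import torch.nn as nn


class PairedContextAttention(nn.Module):
    """Minimal cross-attention sketch accepting a tensor or a (key_ctx, value_ctx) pair."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = context_dim if context_dim is not None else query_dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None):
        context = context if context is not None else x
        q = self.to_q(x)
        # Same branching as the patched lines: a plain tensor is used for both
        # keys and values; a pair supplies the key context at [0] and the value
        # context at [1].
        k = self.to_k(context if isinstance(context, torch.Tensor) else context[0])
        v = self.to_v(context if isinstance(context, torch.Tensor) else context[1])

        def split_heads(t):
            b, n, _ = t.shape
            return t.reshape(b, n, self.heads, -1).transpose(1, 2)

        q, k, v = (split_heads(t) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(x.shape[0], x.shape[1], -1)
        return self.to_out(out)


# Usage: a single tensor keeps the old behaviour; a pair splits the key and value
# contexts, e.g. keys from the source prompt and values from the edited prompt
# (this particular mapping is an assumption for illustration).
attn = PairedContextAttention(query_dim=320, context_dim=768)
latents = torch.randn(1, 64, 320)
source_embeds = torch.randn(1, 77, 768)  # stand-in for text embeddings of the source prompt
edit_embeds = torch.randn(1, 77, 768)    # stand-in for text embeddings of the edited prompt
out_single = attn(latents, context=source_embeds)
out_paired = attn(latents, context=(source_embeds, edit_embeds))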
