Commit ca3dcaf

batched llama
gante committed Feb 19, 2024
1 parent 593230f · commit ca3dcaf
Showing 1 changed file with 6 additions and 1 deletion.
src/transformers/models/llama/modeling_llama.py (7 changes: 6 additions & 1 deletion)
@@ -103,7 +103,10 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
+        freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
+            position_ids[:, None, :].float()
+        )
+        freqs = freqs.transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
         return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
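
For intuition, here is a minimal, self-contained sketch of the shape flow in the new batched frequency computation. The sizes (bs=2, seq_len=5, dim=8) and the base value are illustrative assumptions, not values taken from any LLaMA config:

# Shape walk-through of the batched rotary frequencies (sketch; sizes are illustrative assumptions).
import torch

bs, seq_len, dim, base = 2, 5, 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2] = [4]
position_ids = torch.arange(seq_len).repeat(bs, 1)                  # [bs, seq_len]; rows may differ per sample

# inv_freq[None, :, None] -> [1, dim/2, 1], expanded to [bs, dim/2, 1];
# position_ids[:, None, :] -> [bs, 1, seq_len]; batched matmul gives [bs, dim/2, seq_len]
freqs = inv_freq[None, :, None].float().expand(bs, -1, 1) @ position_ids[:, None, :].float()
freqs = freqs.transpose(1, 2)            # [bs, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)  # [bs, seq_len, dim]
print(emb.cos().shape)                   # torch.Size([2, 5, 8]); one rotary table per sequence

The leading batch dimension lets each row of position_ids carry its own positions, which appears to be the point of the "batched llama" change.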

@@ -181,6 +184,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
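
To see how the two new unsqueeze calls line up with the query/key layout, a small sketch assuming cos and sin arrive with shape [bs, seq_len, head_dim] from the rotary module above; the random tensors, the sizes, and the local copy of rotate_half are illustrative stand-ins:

# Sketch: broadcasting cos/sin over the head axis with unsqueeze_dim=1 (illustrative sizes).
import torch

def rotate_half(x):
    # Mirrors the rotate_half helper defined earlier in modeling_llama.py.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
cos = torch.randn(bs, seq_len, head_dim)   # placeholder for the cos table returned above
sin = torch.randn(bs, seq_len, head_dim)   # placeholder for the sin table returned above

cos = cos.unsqueeze(1)                     # unsqueeze_dim=1 -> [bs, 1, seq_len, head_dim]
sin = sin.unsqueeze(1)                     # broadcasts against [bs, num_heads, seq_len, head_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)        # both torch.Size([2, 4, 5, 8])

Since position_ids is already baked into cos/sin by the rotary forward shown above, apply_rotary_pos_emb only needs to align the head axis before broadcasting.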
