diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index c30be2a2da4f63..eeac16f543a01d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -103,7 +103,10 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
+        freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
+            position_ids[:, None, :].float()
+        )
+        freqs = freqs.transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
         return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
@@ -181,6 +184,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
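Below is a minimal shape-check sketch (not part of the diff) of the new batched frequency computation, assuming arbitrary example sizes for `bs`, `seq_len`, and `dim`. The expanded `inv_freq` of shape `[bs, dim/2, 1]` is batch-matmul'd with `position_ids` of shape `[bs, 1, seq_len]`, giving `[bs, dim/2, seq_len]`, which is transposed and duplicated into cos/sin tensors of shape `[bs, seq_len, dim]`; that per-batch shape is why `apply_rotary_pos_emb` now unsqueezes at dim 1 to broadcast over the heads dimension.

```python
# Shape-check sketch only; bs, seq_len, dim are hypothetical example values.
import torch

bs, seq_len, dim = 2, 5, 8  # dim plays the role of head_size here
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)     # [bs, seq_len]

# [bs, dim/2, 1] @ [bs, 1, seq_len] -> [bs, dim/2, seq_len]
freqs = inv_freq[None, :, None].float().expand(bs, -1, 1) @ position_ids[:, None, :].float()
freqs = freqs.transpose(1, 2)                  # [bs, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)        # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()

# apply_rotary_pos_emb unsqueezes at dim 1 so cos/sin broadcast against
# q/k of shape [bs, num_heads, seq_len, head_dim].
print(cos.shape, cos.unsqueeze(1).shape)  # torch.Size([2, 5, 8]) torch.Size([2, 1, 5, 8])
```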