From b7382c1844435ecde989d168ddab60c840be2de5 Mon Sep 17 00:00:00 2001
From: fecet
Date: Wed, 23 Aug 2023 12:21:53 +0800
Subject: [PATCH] fix rope precision bug

When we use model.bfloat16().cuda(), `inv_freq` is converted to bfloat16
even though we defined it with dtype=float32. In bfloat16, positions are
not represented accurately; for example, `256.0` and `257.0` cannot be
distinguished. So when we create the position embedding, we force-convert
`inv_freq` and `t` to fp32.
---
 megatron/model/positional_embeddings.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py
index 68815075a..84aaf3844 100644
--- a/megatron/model/positional_embeddings.py
+++ b/megatron/model/positional_embeddings.py
@@ -50,8 +50,8 @@ def forward(self, x, seq_dim=1, seq_len=None):
             seq_len = x.shape[seq_dim]
         if seq_len != self.seq_len_cached:
             self.seq_len_cached = seq_len
-            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
             if self.precision == torch.bfloat16:
                 emb = emb.float()
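
Below is a minimal standalone sketch (not part of the patch) illustrating the precision issue the commit message describes: bfloat16 keeps only about 8 significant bits, so consecutive integer positions at or above 256 can collapse to the same value, while float32 keeps them distinct. It uses only standard PyTorch; the variable names are illustrative.

```python
import torch

# Position indices in float32, then cast to bfloat16
# (what effectively happens to `t` after model.bfloat16()).
t_fp32 = torch.arange(255, 260, dtype=torch.float32)
t_bf16 = t_fp32.to(torch.bfloat16)

print(t_fp32)  # tensor([255., 256., 257., 258., 259.])
print(t_bf16)  # tensor([255., 256., 256., 258., 260.], dtype=torch.bfloat16)

# 256.0 and 257.0 collapse to the same bfloat16 value, so the rotary
# embedding would assign identical angles to distinct positions.
print(t_bf16[1] == t_bf16[2])  # tensor(True)
```

Computing `t` and `freqs` in float32, as the patch does, keeps the position-dependent angles exact; the result is only cast down to the model precision afterwards.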