Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding #22234

Merged
merged 1 commit into from
Mar 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
Comment on lines +102 to +103
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I set persistent=False in this PR because cos_cached and sin_cached are not included in the model's state_dict of the original checkpoint. Setting persistent=True will induce missing key warnings when loading the llama model with from_pretrained().
But if this breaks the process of model loading on the meta device, please feel free to correct them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if persistent=True is aboslutely needed to make the meta device loading work, you could add r".*.cos_cached" in _keys_to_ignore_on_missing and _keys_to_ignore_on_unexpected (same for sin_cached)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done some testing and this change isn't necessary: those buffers are not affected by with init_empty_weights(): and the problem was somewhere else.
Consider this suggestion invalid.
Thank you!


def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
Expand All @@ -111,11 +111,11 @@ def forward(self, x, seq_len=None):
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype)
self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
Comment on lines +114 to +115
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


Expand Down