-
Notifications
You must be signed in to change notification settings - Fork 27.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding #22234
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -99,8 +99,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): | |||||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq) | ||||||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation | ||||||||||
emb = torch.cat((freqs, freqs), dim=-1) | ||||||||||
self.cos_cached = emb.cos()[None, None, :, :] | ||||||||||
self.sin_cached = emb.sin()[None, None, :, :] | ||||||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) | ||||||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) | ||||||||||
|
||||||||||
def forward(self, x, seq_len=None): | ||||||||||
# x: [bs, num_attention_heads, seq_len, head_size] | ||||||||||
|
@@ -111,11 +111,11 @@ def forward(self, x, seq_len=None): | |||||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq) | ||||||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation | ||||||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device) | ||||||||||
self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype) | ||||||||||
self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype) | ||||||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) | ||||||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) | ||||||||||
Comment on lines
+114
to
+115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
return ( | ||||||||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), | ||||||||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), | ||||||||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), | ||||||||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I set persistent=False in this PR because
cos_cached
andsin_cached
are not included in the model's state_dict of the original checkpoint. Setting persistent=True will induce missing key warnings when loading the llama model with from_pretrained().But if this breaks the process of model loading on the meta device, please feel free to correct them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if
persistent=True
is aboslutely needed to make the meta device loading work, you could addr".*.cos_cached"
in_keys_to_ignore_on_missing
and_keys_to_ignore_on_unexpected
(same forsin_cached
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've done some testing and this change isn't necessary: those buffers are not affected by
with init_empty_weights():
and the problem was somewhere else.Consider this suggestion invalid.
Thank you!