Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding #22234
Conversation
Thanks a lot for the fix and for the benchmarks! This looks very good to me.
LGTM - thanks for this fix and all the details!
I'd like @sgugger to also review this before we merge.
That's the right way to do it, thanks a lot for the fix!
Did you accidentally break meta loading?
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) | ||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) |
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) | |
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) | |
self.register_buffer("cos_cached", emb.cos()[None, None, :, :]) | |
self.register_buffer("sin_cached", emb.sin()[None, None, :, :]) |
Hi, I set `persistent=False` in this PR because `cos_cached` and `sin_cached` are not included in the model's state_dict of the original checkpoint. Setting `persistent=True` would induce missing-key warnings when loading the LLaMA model with `from_pretrained()`. But if this breaks model loading on the meta device, please feel free to correct it.
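As a small standalone illustration of that behaviour (a sketch, not transformers code): a buffer registered with `persistent=False` follows the module across devices but never appears in the state_dict, so `load_state_dict()` / `from_pretrained()` has no key to miss.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Saved with the model: appears in state_dict and is expected in checkpoints.
        self.register_buffer("persistent_buf", torch.zeros(2))
        # Not saved: still moves with .to()/.cuda(), but is absent from state_dict.
        self.register_buffer("transient_buf", torch.zeros(2), persistent=False)

m = Demo()
print(list(m.state_dict().keys()))  # ['persistent_buf'] only
```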
If `persistent=True` is absolutely needed to make meta-device loading work, you could add `r".*.cos_cached"` to `_keys_to_ignore_on_missing` and `_keys_to_ignore_on_unexpected` (same for `sin_cached`).
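For reference, a hypothetical sketch of what that could look like on the model class; note that in transformers these class attributes are usually named `_keys_to_ignore_on_load_missing` / `_keys_to_ignore_on_load_unexpected`, so the exact names should be checked against the version in use:

```python
from transformers.models.llama import modeling_llama

# Hypothetical sketch (not part of this PR): regex patterns matched against
# state_dict keys, silencing both missing-key and unexpected-key warnings for
# the rotary caches.
class PatchedLlamaForCausalLM(modeling_llama.LlamaForCausalLM):
    _keys_to_ignore_on_load_missing = [r".*.cos_cached", r".*.sin_cached"]
    _keys_to_ignore_on_load_unexpected = [r".*.cos_cached", r".*.sin_cached"]
```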
I've done some testing and this change isn't necessary: those buffers are not affected by `with init_empty_weights():`, and the problem was somewhere else. Consider this suggestion invalid. Thank you!
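A small standalone check of that behaviour (a sketch, assuming accelerate's default `include_buffers=False`): parameters created under `init_empty_weights()` land on the meta device, while buffers are left untouched.

```python
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():  # include_buffers defaults to False
    bn = nn.BatchNorm1d(4)

print(bn.weight.device)        # meta (parameter, intercepted by init_empty_weights)
print(bn.running_mean.device)  # cpu  (buffer, left alone by default)
```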
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) | ||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) |
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) | |
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) | |
self.register_buffer("cos_cached", emb.cos()[None, None, :, :]) | |
self.register_buffer("sin_cached", emb.sin()[None, None, :, :]) |
Hi @BlackSamorez, would you like to open a PR with these suggested changes including details about the issue they resolve?
@BlackSamorez
Hi, I tested the following two scripts on my device. It seems the meta device works correctly with this PR.

```python
import pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

model_name_or_path = "decapoda-research/llama-7b-hf"
model1 = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
model1 = model1.to(torch.device("cuda:0"))

# Save an initialized cos_cached tensor to `cos1.pt`, for comparison with meta-device loading
cos1 = model1.model.layers[0].self_attn.rotary_emb.cos_cached.to(torch.device("cpu"))
pickle.dump(cos1, open("cos1.pt", 'wb'))
```

```python
import pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch

model_name_or_path = "decapoda-research/llama-7b-hf"
config = LlamaConfig.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
with init_empty_weights():
    model0 = LlamaForCausalLM(config)
model0 = load_checkpoint_and_dispatch(
    model0, model_name_or_path, device_map='auto',
)

# Compare the `cos_cached` tensor with the one saved from the non-meta load
cos0 = model0.model.layers[0].self_attn.rotary_emb.cos_cached.to(torch.device("cpu"))
cos1 = pickle.load(open("cos1.pt", 'rb'))
all((cos0 == cos1).tolist())  # True
```

@BlackSamorez Maybe you can check the results on your device.
Yes, you're right and I was wrong. It works, and the problem was in an entirely different part of my program.
I'm still facing this issue with the latest deepspeed (0.9.5+1491e14e) and transformers (4.31.0.dev0). I feel this issue is more likely related to the LLaMA implementation here (LlamaRotaryEmbedding).
I encountered exactly the same issue; training failed when using ZeRO-3.
Thanks for the input, will investigate!
Any updates? I'm also hitting this issue with ds==0.9.3 and transformers==4.32.0dev.
I did not have time to investigate; we are going to need a reproducer if you want some help here. We'll ping @pacman100 once a reproducer is shared!
What does this PR do?
The original implementation of LlamaRotaryEmbedding does not register the `cos_cached` and `sin_cached` tensors as PyTorch Parameters or Buffers, so these tensors do not move to the GPU when we call `model.to(gpu_id)` or `model.cuda()`; they stay on the CPU.

This PR registers the `cos_cached` and `sin_cached` tensors as Buffers with `persistent=False`. This lets the tensors move from CPU to GPU together with the model, while keeping them out of the model's state_dict as before.

Fixes: unnecessary moves of tensors from CPU to GPU in LlamaRotaryEmbedding, saving a large amount of CPU usage, especially during inference.
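For context, here is a minimal sketch of a rotary-embedding module using the registration pattern this PR adopts. It is a simplified stand-in, not the exact transformers implementation; the cache shapes follow the `[None, None, :, :]` indexing shown in the review comments above.

```python
import torch
import torch.nn as nn

class RotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(max_position_embeddings, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent buffers: they follow model.to(...) / model.cuda() but are
        # excluded from the state_dict, so checkpoints without them load cleanly.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len):
        # The caches are already on x.device after model.to(device),
        # so no per-call CPU-to-GPU copy is needed here.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
```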
Code for Reproducing the issue:
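A minimal sketch of the kind of GPU inference loop that exhibits the issue (the model path and prompt below are placeholder assumptions, not the author's original script):

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_name_or_path = "decapoda-research/llama-7b-hf"  # placeholder checkpoint
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).cuda()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
with torch.no_grad():
    # Repeat generation so the effect is easy to see in `top`: without this PR,
    # cos_cached / sin_cached stay on the CPU and are copied to the GPU on every
    # forward pass, which shows up as abnormally high CPU usage.
    for _ in range(50):
        model.generate(**inputs, max_new_tokens=64, do_sample=False)
```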
Use the `top` command in bash to watch the CPU usage. Here is the comparison before and after applying this PR:
Before: [screenshot of `top` output before this PR]
After: [screenshot of `top` output after this PR]
Here the CPU usage drops to a normal level because the `cos_cached` and `sin_cached` tensors move to the GPU correctly with the model. This avoids unnecessary moves of tensors from CPU to GPU in LlamaRotaryEmbedding.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker and @younesbelkada