
Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding #22234

Merged: 1 commit into huggingface:main on Mar 17, 2023

Conversation

@ma787639046 (Contributor)

What does this PR do?

The original implementation of LlamaRotaryEmbedding does not register the cos_cached & sin_cached tensors as PyTorch Parameters or Buffers, so these tensors do not move to the GPU when we call model.to(gpu_id) or model.cuda(); they stay on the CPU.

This PR registers the cos_cached & sin_cached tensors as buffers with persistent=False. This way the tensors move from CPU to GPU together with the model, while staying out of the model's state_dict as before.

Fixes:

Unnecessary moves of tensors from CPU to GPU in LlamaRotaryEmbedding, saving a large amount of CPU usage, especially during inference.
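
For reference, the fix amounts to registering the cached tensors as non-persistent buffers instead of assigning them as plain attributes. A minimal sketch of the idea (the "before" lines are paraphrased, not the exact original code; the "after" lines match the diff discussed below):

# Before: plain attributes; they do not follow model.to(device) / model.cuda()
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]

# After (this PR): non-persistent buffers move with the module but stay out of the state_dict
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)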

Code for reproducing the issue:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"    # Single card Generation

from tqdm import tqdm

import torch
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", torch_dtype=torch.float16)
model = model.cuda()
model.eval()

# Batch generation
inputs = [
    "LLaMa is a large language model developed by Meta AI, for",
] * 32

batch = tokenizer(inputs, return_tensors="pt", add_special_tokens=False)
batch = batch.to(model.device)

# Here we do some high computational batched generation
for i in tqdm(range(5000)):
    generated = model.generate(
        batch["input_ids"],
        temperature=0.7, top_p=0.9, do_sample=True,
        num_beams=1, max_new_tokens=600,
    )

Use the top command in bash to watch the CPU usage.

Here is a comparison of CPU usage before and after applying this PR:

Before:

USER  PR  NI  VIRT    RES   SHR     S  %CPU  %MEM  TIME+    COMMAND
root  20  0   108.6g  1.9g  411620  R  6263  0.2   40:28.1  python

After:

USER  PR  NI  VIRT    RES   SHR     S  %CPU  %MEM  TIME+    COMMAND
root  20  0   108.6g  1.8g  414360  R  98.3  0.2   03:21.6  python

Here the CPU usage drops to a normal level because the cos_cached & sin_cached tensors now move to the GPU correctly with the model, avoiding unnecessary tensor moves from CPU to GPU in LlamaRotaryEmbedding.
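
A quick way to confirm the fix after loading the model (a small check using the attribute path from the test scripts further below; with this PR both buffers should report the CUDA device after model.cuda()):

rotary = model.model.layers[0].self_attn.rotary_emb
print(rotary.cos_cached.device, rotary.sin_cached.device)  # expected: cuda:0 cuda:0 with this PR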

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker and @younesbelkada

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor) left a comment

Thanks a lot for fixing and for the benchmarks!
This looks very good to me

@amyeroberts (Collaborator) left a comment

LGTM - thanks for this fix and all the details!

I'd like @sgugger to also review this before we merge

@sgugger (Collaborator) left a comment

That's the right way to do it, thanks a lot for the fix!

@sgugger merged commit cf601b9 into huggingface:main on Mar 17, 2023
@BlackSamorez (Contributor)

Did you accidentally break meta loading? with init_empty_weights(): leaves cos_cached and sin_cached on the meta device, and they won't be initialized because they are not persistent.

Comment on lines +102 to +103
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
Contributor:

Suggested change:
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

@ma787639046 (Contributor Author):

Hi, I set persistent=False in this PR because cos_cached and sin_cached are not included in the state_dict of the original checkpoint. Setting persistent=True would induce missing-key warnings when loading the LLaMA model with from_pretrained().
But if this breaks model loading on the meta device, please feel free to correct it.
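
For illustration, a small self-contained sketch (a hypothetical Demo module, not LLaMA code) of why persistent=False keeps the caches out of the state_dict and therefore avoids missing-key warnings:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent buffer: saved in and loaded from the state_dict
        self.register_buffer("inv_freq", torch.ones(4))
        # non-persistent buffer: follows .to()/.cuda() but is excluded from the state_dict
        self.register_buffer("cos_cached", torch.ones(4), persistent=False)

print(list(Demo().state_dict().keys()))  # ['inv_freq']; cos_cached is not saved, so no missing key at load time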

Contributor:

If persistent=True is absolutely needed to make meta device loading work, you could add r".*.cos_cached" to _keys_to_ignore_on_missing and _keys_to_ignore_on_unexpected (same for sin_cached).
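
A minimal sketch of that suggestion (hypothetical, not part of this PR; transformers' PreTrainedModel checks the class attributes _keys_to_ignore_on_load_missing / _keys_to_ignore_on_load_unexpected during from_pretrained(), which the comment above abbreviates):

from transformers import PreTrainedModel

class LlamaPreTrainedModelSketch(PreTrainedModel):
    # regex patterns matched against checkpoint keys to silence the corresponding load warnings
    _keys_to_ignore_on_load_missing = [r".*\.cos_cached", r".*\.sin_cached"]
    _keys_to_ignore_on_load_unexpected = [r".*\.cos_cached", r".*\.sin_cached"]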

Contributor:

I've done some testing and this change isn't necessary: those buffers are not affected by with init_empty_weights(): and the problem was somewhere else.
Consider this suggestion invalid.
Thank you!

Comment on lines +114 to +115
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
Contributor:

Suggested change:
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

@BlackSamorez (Contributor)

Since inv_freq is a persistent buffer, it should be OK to also make the harmonics persistent.

@amyeroberts (Collaborator)

Hi @BlackSamorez, would you like to open a PR with these suggested changes including details about the issue they resolve?

@sgugger (Collaborator)

sgugger commented Mar 21, 2023

@BlackSamorez init_empty_weights ignores buffers by default, so this should not cause any problem. We have multiple instances of non-persistent buffers in the lib and this is not a problem. I've also run Llama without any issue after initializing it on the meta device.
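
A minimal illustration of that behaviour (assuming accelerate's default include_buffers=False for init_empty_weights): parameters end up on the meta device while buffers keep real values on the CPU.

import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.BatchNorm1d(4)     # has parameters (weight, bias) and buffers (running_mean, running_var)

print(layer.weight.device)        # meta
print(layer.running_mean.device)  # cpu; buffers are skipped unless include_buffers=True is passed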

@ma787639046 (Contributor Author)

Hi, I tested the following two scripts on my device. It seems meta device loading works correctly with this PR. The first script loads the model normally and saves the cos_cached tensor as a reference:

import pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

model_name_or_path = "decapoda-research/llama-7b-hf"

model1 = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
model1 = model1.to(torch.device("cuda:0"))

# Save an initialized cos_cached tensor to `cos1.pt` for comparison with meta device loading
cos1 = model1.model.layers[0].self_attn.rotary_emb.cos_cached.to(torch.device("cpu"))
pickle.dump(cos1, open("cos1.pt", 'wb'))

The second script loads the model on the meta device with accelerate and compares the buffer:

import pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch

model_name_or_path = "decapoda-research/llama-7b-hf"

config = LlamaConfig.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
with init_empty_weights():
    model0 = LlamaForCausalLM(config)

model0 = load_checkpoint_and_dispatch(
    model0, model_name_or_path, device_map='auto', 
)
# Compare the `cos_cached` tensor
cos0 = model0.model.layers[0].self_attn.rotary_emb.cos_cached.to(torch.device("cpu"))
cos1 = pickle.load(open("cos1.pt", 'rb'))
torch.equal(cos0, cos1)         # True

@BlackSamorez Maybe you can check the results on your device.

@BlackSamorez (Contributor)

BlackSamorez commented Mar 21, 2023

Yes, you're right and I was wrong. It works, and the problem was in an entirely different part of my program.
Consider #22234 (comment) and #22234 (comment) invalid.
Thank you!

@memray

memray commented Jun 25, 2023

I'm still facing this issue with the latest deepspeed (0.9.5+1491e14e) and transformers (4.31.0.dev0). I feel this issue is more likely related to the LLaMA implementation here (LlamaRotaryEmbedding).

(error output interleaved from multiple processes; deduplicated)

    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
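
For reference, this is the generic error PyTorch raises when a CPU tensor is indexed with GPU indices; a minimal sketch of the error class (illustrative only, not the deepspeed/ZeRO-3 setup):

import torch

cos = torch.randn(2048, 128)                    # e.g. a rotary cache that stayed on the CPU
position_ids = torch.arange(16, device="cuda")  # index tensor produced on the GPU
cos[position_ids]  # raises: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)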

@Rothsword

(quoting @memray's report and error output above)

I encountered exactly the same issue; training failed when using ZeRO-3.

@ArthurZucker (Collaborator)

Thanks for the input, will investigate!

@Zhanghahah

Zhanghahah commented Aug 15, 2023

Any updates? I also meet this issue, with ds==0.9.3 and transformers==4.32.0dev.

@ArthurZucker (Collaborator)

Did not have time to investigate; we are going to need a reproducer if you want some help here. Pinging @pacman100 once a reproducer is shared!
