
Generate: multi-device support for contrastive search #24635

Merged (1 commit) on Jul 3, 2023

Conversation

@gante (Member) commented on Jul 3, 2023

What does this PR do?

Fixes #24634

In multi-GPU settings, the past KV cache may be scattered across devices: the cache for a given layer lives on the same device as the layer itself, and different layers may live on different devices.
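
For illustration, here's one way to inspect that layout. The `hf_device_map` attribute is populated when a model is loaded with `device_map="auto"`; the map printed below is only a hypothetical example:

from transformers import AutoModelForCausalLM

# Loading with device_map="auto" may shard the model across available devices
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")

# Maps submodules to devices; on a 2-GPU machine it might look like
# {"model.embed_tokens": 0, "model.layers.0": 0, ..., "model.layers.31": 1, "lm_head": 1}
print(model.hf_device_map)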

In contrastive search, we must apply indexing operations to the past KV cache. The indices live in a tensor that, by default, sits on the same device as the model outputs. Applying these indices to the past KV cache currently raises an exception when the model is split across devices (see the issue linked above).
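
As a minimal sketch of the failure mode (hypothetical shapes, assuming at least two GPUs):

import torch

# KV cache entry for a layer that was placed on cuda:1
layer_key_cache = torch.randn(1, 8, 5, 64, device="cuda:1")
# Index tensor created next to the model outputs, i.e. on cuda:0
selected_idx = torch.tensor([0], device="cuda:0")

# Raises roughly: "RuntimeError: indices should be either on cpu or on the
# same device as the indexed tensor"
layer_key_cache[selected_idx]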

This leaves two options: copy the indexing tensor to every device that holds part of the cache, or move it to the CPU. Indexing in PyTorch is typically CPU-bound anyway, and the benchmarks on my end indicate that moving the indexing tensor to the CPU enables multi-device contrastive search with no noticeable throughput degradation 🙌
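
Here's a minimal sketch of the resulting approach (hypothetical shapes, not the exact implementation): PyTorch accepts a CPU index tensor when indexing a tensor on any device, so a single CPU-resident index works for every layer of the cache.

import torch

# Hypothetical two-layer cache, with each layer's KV tensors on a different device
past_key_values = (
    (torch.randn(1, 8, 5, 64, device="cuda:0"), torch.randn(1, 8, 5, 64, device="cuda:0")),
    (torch.randn(1, 8, 5, 64, device="cuda:1"), torch.randn(1, 8, 5, 64, device="cuda:1")),
)

selected_idx = torch.tensor([0])  # deliberately kept on the CPU

# The same CPU index is valid for tensors on any device
new_past = tuple((key[selected_idx], value[selected_idx]) for key, value in past_key_values)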

@HuggingFaceDocBuilderDev commented on Jul 3, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante requested a review from amyeroberts on July 3, 2023, 13:54
@amyeroberts (Collaborator) left a comment


Thanks for digging into this and fixing!

It would be great to share the tests run and the numbers, for future reference or follow-up work 🤗

@gante (Member, Author) commented on Jul 3, 2023

For future reference, here's the benchmark code:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

# Configuration options
DEVICE = "cuda:0"
NUM_RUNS = 10
MAX_NEW_TOKENS = 1000
TEXT_INPUT = "def sieve_of_eratosthenes():"

# Load the model and prepare the generate arguments
# (with device_map="auto", the model may be sharded across all available GPUs)
repo_id = "huggyllama/llama-7b"
model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto", load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
model_inputs = tokenizer(TEXT_INPUT, return_tensors="pt").to(DEVICE)

generate_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "top_k": 10,
    "penalty_alpha": 0.6,
}

# Warmup
print("Warming up...")
for _ in range(2):
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")


# Measure throughput and peak memory
def measure_generate(model, model_inputs, generate_kwargs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in tqdm(range(NUM_RUNS)):
        gen_out = model.generate(**model_inputs, **generate_kwargs)
    end_event.record()

    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated(DEVICE)
    print("Max memory (MB): ", max_memory * 1e-6)
    print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))

measure_generate(model, model_inputs, generate_kwargs)

On my end, with an RTX 3090, I get 150 tokens/s both before and after these changes.

@gante merged commit 9934bb1 into huggingface:main on Jul 3, 2023
@gante deleted the multi_device_contrastive branch on July 3, 2023, 15:08
@amyeroberts (Collaborator) commented

@gante Thanks for adding the script! ❤️

Closes: .generate() supports contrastive-search on multi-device? (#24634)