[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling #11394
Merged
Changes from 12 of 17 commits (all by WoosukKwon):

ccf53d1 [V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling
9df8ccf Merge branch 'main' into v1-topk-top
83d9aa4 update
0c6d409 Add warning
121cea5 fix
cf097f4 minor
98374e0 comment
e068d68 Minor
6e97c5f fix
15fda81 minor
3dcac1c Fix tests
5cac3e1 Minor
8061a16 comment
e968e18 Minor
0f784a5 Minor
6bea166 Consider VLLM_USE_FLASHINFER_SAMPLER
68ffc96 Minor
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
@@ -0,0 +1,57 @@
from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.utils import (
    apply_penalties as _apply_penalties)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(logits: torch.Tensor,
                              output_token_ids: List[List[int]],
                              stop_token_ids: List[Set[int]],
                              min_tokens: List[int]) -> None:
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf.
    """
    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
    for index, min_token in enumerate(min_tokens):
        if len(output_token_ids[index]) < min_token:
            for stop_token_id in stop_token_ids[index]:
                min_tokens_logits_to_penalize.append((index, stop_token_id))
    if min_tokens_logits_to_penalize:
        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor,
                    output_token_ids: List[List[int]]) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
                            presence_penalties, frequency_penalties,
                            repetition_penalties)


def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
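For orientation, the sketch below exercises apply_min_token_penalties from the file above on a toy batch. The batch shape, token ids, and minimum-token values are illustrative assumptions rather than values from this PR's tests, and the helper is assumed to be importable (the file path is not shown in this rendering).

# Hedged usage sketch (toy values; assumes apply_min_token_penalties is in scope).
import torch

batch_size, vocab_size = 2, 8
logits = torch.zeros(batch_size, vocab_size)

# Request 0 has generated 1 token but must generate at least 3, so its stop
# token (id 7) gets masked to -inf. Request 1 already met its minimum of 2.
apply_min_token_penalties(logits,
                          output_token_ids=[[5], [1, 2, 3]],
                          stop_token_ids=[{7}, {7}],
                          min_tokens=[3, 2])
assert logits[0, 7] == -float("inf")
assert logits[1, 7] == 0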
@@ -0,0 +1,156 @@
from typing import Dict

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    use_flashinfer = True
except ImportError:
    use_flashinfer = False


class TopKTopPSampler(nn.Module):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda:
            if use_flashinfer:
                self.forward = self.forward_cuda
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        # This might still be done here unnecessarily if there are greedies
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    assert not (no_top_k and no_top_p)
    max_top_k_round = 32
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    if generators:
        for i, generator in generators.items():
            uniform_samples[:, i].uniform_(generator=generator)

    if no_top_k:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif no_top_p:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        if not no_top_k:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if not no_top_p:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
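To make the PyTorch-native fallback concrete, here is a hedged sketch that drives apply_top_k_top_p and random_sample from the file above directly, without constructing TopKTopPSampler. The batch size, the k/p values, and the seeded generator for request 0 are illustrative assumptions; on CUDA with FlashInfer installed, forward_cuda would route the same inputs through flashinfer_sample instead.

# Hedged sketch of the native top-k/top-p path (toy values; assumes the
# functions above are in scope).
import torch

batch_size, vocab_size = 2, 16
logits = torch.randn(batch_size, vocab_size)
k = torch.tensor([4, vocab_size])   # k == vocab_size effectively disables top-k for request 1
p = torch.tensor([0.9, 1.0])        # p == 1.0 effectively disables top-p for request 1

# Mask the logits, turn them into probabilities, then sample. Request 0 uses
# its own seeded generator, mirroring the dict keyed by batch row index.
masked = apply_top_k_top_p(logits, no_top_k=False, k=k, no_top_p=False, p=p)
probs = masked.softmax(dim=-1, dtype=torch.float32)
generators = {0: torch.Generator().manual_seed(0)}
token_ids = random_sample(probs, generators)
print(token_ids.shape)  # torch.Size([2]): one sampled token id per request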
I modified the tests here because the sampler returns the original logprobs rather than the final logprobs obtained after applying penalties and scaling.