From ccf53d1dce19be71b38fc09f50b6afe2e58a4854 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 20 Dec 2024 22:56:56 -0800 Subject: [PATCH 01/16] [V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/__init__.py | 0 vllm/v1/sample/ops/topk_topp_sampler.py | 148 ++++++++++++++++++++++++ vllm/v1/sample/sampler.py | 119 +++++-------------- 3 files changed, 177 insertions(+), 90 deletions(-) create mode 100644 vllm/v1/sample/ops/__init__.py create mode 100644 vllm/v1/sample/ops/topk_topp_sampler.py diff --git a/vllm/v1/sample/ops/__init__.py b/vllm/v1/sample/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py new file mode 100644 index 0000000000000..7560a070b900a --- /dev/null +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -0,0 +1,148 @@ +from typing import Dict + +import torch +import torch.nn as nn + +from vllm.platforms import current_platform + + +class TopKTopPSampler(nn.Module): + + def forward( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + if current_platform.is_cuda: + return self.forward_cuda(logits, generators, no_top_k, k, no_top_p, + p) + return self.forward_native(logits, generators, no_top_k, k, no_top_p, + p) + + def forward_native( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + def forward_cuda( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + probs = logits.softmax(dim=-1, dtype=torch.float32) + if no_top_k and no_top_p: + return random_sample(probs, generators) + return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) + + +def apply_top_k_top_p( + logits: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, +) -> torch.Tensor: + if no_top_k and no_top_p: + return logits + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if not no_top_k: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if not no_top_p: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def random_sample( + probs: torch.Tensor, + generators: Dict[int, torch.Generator], +) -> torch.Tensor: + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. 
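+    # NOTE: Dividing the probabilities by Exp(1) noise and taking the argmax
+    # is the Gumbel-max trick: argmax(p / q) == argmax(log p - log q), and
+    # -log q is standard Gumbel noise, so this yields an exact sample from
+    # the categorical distribution defined by `probs`.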
+ if len(generators) != probs.shape[0]: + # This might still be done here unnecessarily if there are greedies + q.exponential_() + if generators: + # TODO(woosuk): This can be slow because we handle each request + # one by one. Optimize this. + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + +def flashinfer_sample( + probs: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + generators: Dict[int, torch.Generator], +) -> torch.Tensor: + assert not (no_top_k and no_top_p) + max_top_k_round = 32 + batch_size = probs.shape[0] + uniform_samples = torch.empty((max_top_k_round, batch_size), + device=probs.device) + if len(generators) != batch_size: + uniform_samples.uniform_() + if generators: + for i, generator in generators.items(): + uniform_samples[:, i].uniform_(generator=generator) + + import flashinfer.sampling + if no_top_k: + # Top-p only. + next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( + probs, uniform_samples, p, deterministic=True) + elif no_top_p: + # Top-k only. + next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( + probs, uniform_samples, k, deterministic=True) + else: + # Both top-k and top-p. + next_token_ids, success = ( + flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, uniform_samples, k, p, deterministic=True)) + + if not success.all(): + if not no_top_k: + probs = flashinfer.sampling.top_k_renorm_prob(probs, k) + if not no_top_p: + probs = flashinfer.sampling.top_p_renorm_prob(probs, p) + next_token_ids = flashinfer.sampling.sampling_from_probs( + probs, uniform_samples[0], deterministic=True) + return next_token_ids.view(-1) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index d1a755be01ff7..26ee9f1726e72 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,32 +1,38 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict - import torch import torch.nn as nn from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): + def __init__(self): + super().__init__() + self.topk_topp_sampler = TopKTopPSampler() + def forward( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - logits = self.apply_temperature(logits, sampling_metadata.temperature) - logits = self.apply_top_k_top_p(logits, sampling_metadata) + # Use float32 for the logits. + logits = logits.to(torch.float32) + orig_logits = logits - probs = self.get_probs(logits) - sampled = self.sample(probs, sampling_metadata) + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + # Sample the next token. + sampled = self.sample(logits, sampling_metadata) # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) if sampling_metadata.max_num_logprobs > 0: - logprobs = self.get_logprobs(logits) + logprobs = self.get_logprobs(orig_logits) # FIXME: Mask the sampled token_id, get topk logprobs, # and concatenate the topk with the sampled token_id. topk_logprobs, topk_indices = torch.topk( @@ -52,71 +58,35 @@ def apply_temperature( logits: torch.Tensor, temp: torch.Tensor, ) -> torch.Tensor: - # Use float32 to apply temperature scaling. - logits = logits.to(torch.float32) # Avoid division by zero. 
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) - # Use in-place division to avoid creating a new tensor. - logits.div_(temp.unsqueeze(dim=1)) - return logits + return logits / temp.unsqueeze(dim=1) + + def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1).view(-1) - def apply_top_k_top_p( + def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - return _apply_top_k_top_p( + assert not (sampling_metadata.all_greedy + and sampling_metadata.all_random) + if sampling_metadata.all_greedy: + return self.greedy_sample(logits) + + random_sampled = self.topk_topp_sampler( logits, + sampling_metadata.generators, sampling_metadata.no_top_k, sampling_metadata.top_k, sampling_metadata.no_top_p, sampling_metadata.top_p, ) - - def get_probs(self, logits: torch.Tensor) -> torch.Tensor: - return torch.softmax(logits, dim=-1, dtype=torch.float32) - - def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: - return torch.log_softmax(logits, dim=-1, dtype=torch.float32) - - def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: - return probs.argmax(dim=-1).view(-1) - - def random_sample( - self, - probs: torch.Tensor, - generators: Dict[int, torch.Generator], - ) -> torch.Tensor: - q = torch.empty_like(probs) - # NOTE(woosuk): To batch-process the requests without their own seeds, - # which is the common case, we first assume that every request does - # not have its own seed. Then, we overwrite the values for the requests - # that have their own seeds. - if len(generators) != probs.shape[0]: - # This might still be done here unnecessarily if there are greedies - q.exponential_() - if generators: - # TODO(woosuk): This can be slow because we handle each request - # one by one. Optimize this. - for i, generator in generators.items(): - q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1) - - def sample( - self, - probs: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - assert not (sampling_metadata.all_greedy - and sampling_metadata.all_random) - if sampling_metadata.all_greedy: - return self.greedy_sample(probs) if sampling_metadata.all_random: - return self.random_sample(probs, sampling_metadata.generators) + return random_sampled - greedy_sampled = self.greedy_sample(probs) - random_sampled = self.random_sample(probs, - sampling_metadata.generators) + greedy_sampled = self.greedy_sample(logits) sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, @@ -124,36 +94,5 @@ def sample( ) return sampled - -# TODO(woosuk): Optimize this with a custom kernel. -def _apply_top_k_top_p( - logits: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, -) -> torch.Tensor: - if no_top_k and no_top_p: - return logits - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - if not no_top_k: - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if not no_top_p: - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. 
- logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits + def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return torch.log_softmax(logits, dim=-1, dtype=torch.float32) From 83d9aa44d726f5509feac2c06f71d804e45a4d63 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 02:47:05 -0800 Subject: [PATCH 02/16] update Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/penalties.py | 57 ++++++++++++++++++ vllm/v1/sample/sampler.py | 100 ++++++++------------------------ 2 files changed, 81 insertions(+), 76 deletions(-) create mode 100644 vllm/v1/sample/ops/penalties.py diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py new file mode 100644 index 0000000000000..91ebaf9269f32 --- /dev/null +++ b/vllm/v1/sample/ops/penalties.py @@ -0,0 +1,57 @@ +from typing import List, Set, Tuple + +import torch + +from vllm.model_executor.layers.utils import ( + apply_penalties as _apply_penalties) +from vllm.utils import is_pin_memory_available, make_tensor_with_pad + + +def apply_min_token_penalties(logits: torch.Tensor, + output_token_ids: List[List[int]], + stop_token_ids: List[Set[int]], + min_tokens: List[int]) -> None: + """ + Applies minimum token penalty by setting the logits of the stop tokens + to -inf. + """ + min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] + for index, min_token in enumerate(min_tokens): + if (len(output_token_ids[index]) < min_token): + for stop_token_id in stop_token_ids[index]: + min_tokens_logits_to_penalize.append((index, stop_token_id)) + if min_tokens_logits_to_penalize: + logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") + + +def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: List[List[int]]) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. + """ + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + return _apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) + + +def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, + device: torch.device) -> torch.Tensor: + """ + Convert the different list data structures to tensors. + """ + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. 
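+        # `output_token_ids` is ragged (one list per request), so it is
+        # right-padded into a rectangular CPU tensor before the non-blocking
+        # copy to the target device below.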
+ pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=is_pin_memory_available(), + ) + return output_tokens_tensor.to(device, non_blocking=True) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 7b79075a24fed..983504d7529f1 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,13 +1,11 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Set, Tuple - import torch import torch.nn as nn -from vllm.model_executor.layers.utils import apply_penalties -from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.penalties import (apply_min_token_penalties, + apply_penalties) from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler _SAMPLING_EPS = 1e-5 @@ -26,36 +24,20 @@ def forward( ) -> SamplerOutput: # Use float32 for the logits. logits = logits.to(torch.float32) + needs_logprobs = sampling_metadata.max_num_logprobs > 0 + if needs_logprobs: + orig_logits = logits.clone() - _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, - sampling_metadata.stop_token_ids, - sampling_metadata.min_tokens) - if not sampling_metadata.no_penalties: - assert sampling_metadata.prompt_token_ids is not None - _apply_penalties(logits, sampling_metadata.prompt_token_ids, - sampling_metadata.presence_penalties, - sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties, - sampling_metadata.output_token_ids) - - logits = self.apply_temperature(logits, sampling_metadata.temperature) - logits = self.apply_top_k_top_p(logits, sampling_metadata) - probs = self.get_probs(logits) - sampled = self.sample(probs, sampling_metadata) - - - orig_logits = logits - + # Apply penalties (e.g., min_tokens, freq_penalties). + logits = self.apply_penalties(logits, sampling_metadata) # Apply temperature. logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. sampled = self.sample(logits, sampling_metadata) - - # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) - if sampling_metadata.max_num_logprobs > 0: + if needs_logprobs: logprobs = self.get_logprobs(orig_logits) # FIXME: Mask the sampled token_id, get topk logprobs, # and concatenate the topk with the sampled token_id. @@ -84,7 +66,7 @@ def apply_temperature( ) -> torch.Tensor: # Avoid division by zero. temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) - return logits / temp.unsqueeze(dim=1) + return logits.div_(temp.unsqueeze(dim=1)) def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: return logits.argmax(dim=-1).view(-1) @@ -121,52 +103,18 @@ def sample( def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return torch.log_softmax(logits, dim=-1, dtype=torch.float32) - -def _apply_min_token_penalties(logits: torch.Tensor, - output_token_ids: List[List[int]], - stop_token_ids: List[Set[int]], - min_tokens: List[int]): - """ - Applies minimum token penalty by setting the logits of the stop tokens - to -inf. 
- """ - min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] - for index, min_token in enumerate(min_tokens): - if (len(output_token_ids[index]) < min_token): - for stop_token_id in stop_token_ids[index]: - min_tokens_logits_to_penalize.append((index, stop_token_id)) - if min_tokens_logits_to_penalize: - logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") - - -def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor, - output_token_ids: List[List[int]]): - """ - Applies presence, frequency and repetition penalties to the logits. - """ - _, vocab_size = logits.shape - output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, - logits.device) - return apply_penalties(logits, prompt_token_ids, output_tokens_t, - presence_penalties, frequency_penalties, - repetition_penalties) - - -def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, - device: torch.device) -> torch.Tensor: - """ - Convert the different list data structures to tensors. - """ - output_tokens_tensor = make_tensor_with_pad( - output_token_ids, - # Use the value of vocab_size as a pad since we don't have a - # token_id of this value. - pad=vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=is_pin_memory_available(), - ) - return output_tokens_tensor.to(device, non_blocking=True) + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + apply_min_token_penalties(logits, sampling_metadata.output_token_ids, + sampling_metadata.stop_token_ids, + sampling_metadata.min_tokens) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + apply_penalties(logits, sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids) From 0c6d409a9f899dcb21d7f5892e5da36b6da7f87b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 02:54:12 -0800 Subject: [PATCH 03/16] Add warning Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/topk_topp_sampler.py | 35 +++++++++++++++---------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 7560a070b900a..d1e54da1b2716 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -3,25 +3,33 @@ import torch import torch.nn as nn +from vllm.logger import init_logger from vllm.platforms import current_platform +logger = init_logger(__name__) + +try: + import flashinfer.sampling + use_flashinfer = True +except ImportError: + use_flashinfer = False + class TopKTopPSampler(nn.Module): - def forward( - self, - logits: torch.Tensor, - generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, - ) -> torch.Tensor: + def __init__(self): + super().__init__() if current_platform.is_cuda: - return self.forward_cuda(logits, generators, no_top_k, k, no_top_p, - p) - return self.forward_native(logits, generators, no_top_k, k, no_top_p, - p) + if use_flashinfer: + self.forward = self.forward_cuda + else: + logger.warning( + "flashinfer.sampling is not available. Falling back to " + "Pytorch-native implementation of sampling. 
For the best " + "performance, please install FalshInfer.") + self.forward = self.forward_native + else: + self.forward = self.forward_native def forward_native( self, @@ -123,7 +131,6 @@ def flashinfer_sample( for i, generator in generators.items(): uniform_samples[:, i].uniform_(generator=generator) - import flashinfer.sampling if no_top_k: # Top-p only. next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( From 121cea5f093ea851ff2c7efc6b7fae2f7e67cd52 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 02:57:55 -0800 Subject: [PATCH 04/16] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 983504d7529f1..f520fdd4aa51c 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -113,8 +113,10 @@ def apply_penalties( sampling_metadata.min_tokens) if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None - apply_penalties(logits, sampling_metadata.prompt_token_ids, - sampling_metadata.presence_penalties, - sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties, - sampling_metadata.output_token_ids) + logits = apply_penalties(logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids) + return logits From cf097f4a7b2ef8b5ff6a8f4dd1c8acad343131bc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 03:11:47 -0800 Subject: [PATCH 05/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index f520fdd4aa51c..54f282782574f 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -22,14 +22,14 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - # Use float32 for the logits. - logits = logits.to(torch.float32) needs_logprobs = sampling_metadata.max_num_logprobs > 0 if needs_logprobs: orig_logits = logits.clone() # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) + # Use float32 for the logits. + logits = logits.to(torch.float32) # Apply temperature. logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. @@ -66,7 +66,9 @@ def apply_temperature( ) -> torch.Tensor: # Avoid division by zero. temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) - return logits.div_(temp.unsqueeze(dim=1)) + # Use in-place division to avoid creating a new tensor. 
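+        # The in-place update is safe because `forward` clones the original
+        # logits beforehand whenever they are still needed for logprobs.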
+ logits.div_(temp.unsqueeze(dim=1)) + return logits def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: return logits.argmax(dim=-1).view(-1) From 98374e09b8d648166ac787b7c349b5b104388f97 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 03:16:26 -0800 Subject: [PATCH 06/16] comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/topk_topp_sampler.py | 1 + vllm/v1/sample/sampler.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index d1e54da1b2716..9ed1d5ddeaa4b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -145,6 +145,7 @@ def flashinfer_sample( flashinfer.sampling.top_k_top_p_sampling_from_probs( probs, uniform_samples, k, p, deterministic=True)) + # NOTE: CPU-GPU synchronization happens here. if not success.all(): if not no_top_k: probs = flashinfer.sampling.top_k_renorm_prob(probs, k) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 54f282782574f..2fd70a636c11f 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -24,6 +24,8 @@ def forward( ) -> SamplerOutput: needs_logprobs = sampling_metadata.max_num_logprobs > 0 if needs_logprobs: + # NOTE: We need to clone the tensor because the below ops may + # modify the logits tensor in-place. orig_logits = logits.clone() # Apply penalties (e.g., min_tokens, freq_penalties). From e068d68f5b426d6502081c27971b88b3e3128384 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 03:21:00 -0800 Subject: [PATCH 07/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 2fd70a636c11f..74660be152000 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -27,11 +27,11 @@ def forward( # NOTE: We need to clone the tensor because the below ops may # modify the logits tensor in-place. orig_logits = logits.clone() + # Use float32 for the logits. + logits = logits.to(torch.float32) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Use float32 for the logits. - logits = logits.to(torch.float32) # Apply temperature. logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. From 6e97c5f01aaaaa4bfa35515b539f962b32da73f3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 03:28:46 -0800 Subject: [PATCH 08/16] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 74660be152000..588d57ab4131e 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -40,6 +40,8 @@ def forward( sampled = sampled.to(torch.int32) if needs_logprobs: + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. logprobs = self.get_logprobs(orig_logits) # FIXME: Mask the sampled token_id, get topk logprobs, # and concatenate the topk with the sampled token_id. 
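Aside: the equivalence that `random_sample` relies on (dividing the probabilities by Exp(1) noise and taking the argmax yields a faithful categorical sample) can be sanity-checked with a short, self-contained PyTorch sketch. The probabilities and trial count below are made up for illustration and are not part of the patch:

import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]]).repeat(100_000, 1)

# Gumbel-max via exponential noise, as in random_sample().
q = torch.empty_like(probs).exponential_()
trick_samples = probs.div(q).argmax(dim=-1)

# Reference sampling with torch.multinomial.
ref_samples = torch.multinomial(probs, num_samples=1).view(-1)

# Both empirical distributions should be close to [0.1, 0.2, 0.3, 0.4].
print(torch.bincount(trick_samples, minlength=4) / trick_samples.numel())
print(torch.bincount(ref_samples, minlength=4) / ref_samples.numel())

The same equivalence is why `forward_cuda` can route batches that use neither top-k nor top-p through `random_sample` without changing the output distribution.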
From 15fda81ff024f0457489ce31cd906f4399c3e7d8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 03:37:47 -0800 Subject: [PATCH 09/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 588d57ab4131e..6d0a367aa1009 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -42,6 +42,8 @@ def forward( if needs_logprobs: # NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). logprobs = self.get_logprobs(orig_logits) # FIXME: Mask the sampled token_id, get topk logprobs, # and concatenate the topk with the sampled token_id. From 3dcac1c06f682553cc08f66aedd3a3363b2c75ea Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 04:23:43 -0800 Subject: [PATCH 10/16] Fix tests Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_sampler.py | 54 ++++++++++++++------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index d8d055805cbea..5ebf72927cfd6 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -68,7 +68,7 @@ def _create_default_sampling_metadata( no_top_p=True, no_top_k=True, generators={}, - max_num_logprobs=VOCAB_SIZE, + max_num_logprobs=0, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, @@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): sampling_metadata.min_tokens = min_tokens sampling_metadata.stop_token_ids = stop_token_ids sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - for vocab in range(VOCAB_SIZE): - # Verify that the logprobs for stop token ids is set - # to -inf. - logprob_index = torch.where( - sampler_output.logprob_token_ids[batch_idx] == - vocab)[0].item() - if vocab in stop_token_ids[batch_idx]: - assert sampler_output.logprobs[batch_idx][ - logprob_index] == -float("inf") + for token_id in range(VOCAB_SIZE): + if token_id in stop_token_ids[batch_idx]: + assert logits[batch_idx][token_id] == -float("inf") else: - assert sampler_output.logprobs[batch_idx][ - logprob_index] != -float("inf") + assert logits[batch_idx][token_id] != -float("inf") @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int, batch_size, presence_penalty, torch.device(device)) sampling_metadata.no_penalties = False sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - # The logprobs in the SamplerOutput are arranged in descending order. - # Since all tokens initially have the same logprobs, the non-penalized - # tokens will appear at the beginning, while the penalized tokens - # will appear at the end of the list. 
- penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ - VOCAB_SIZE - 1] - penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] - non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0] - non_penalized_log_prod = sampler_output.logprobs[batch_idx][0] - assert non_penalized_log_prod > penalized_log_prod + # Since all tokens initially have the same logits, the non-penalized + # token ID will be the one with the highest logit value, while the + # penalized token ID will be the one with the lowest logit value. + non_penalized_token_id = logits[batch_idx].argmax().item() + penalized_token_id = logits[batch_idx].argmin().item() if presence_penalty > 0: # If `presence_penalty` is set to a value greater than 0, it # indicates a preference for new tokens over those already @@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, sampling_metadata.output_token_ids = output_token_ids sampling_metadata.no_penalties = False sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] - non_penalized_token_id = logprobs_token_ids[0] - penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + non_penalized_token_id = logits[batch_idx].argmax().item() + penalized_token_id = logits[batch_idx].argmin().item() distinct_sorted_token_ids_in_output = \ sorted_token_ids_in_output[batch_idx] most_frequent_token_id = distinct_sorted_token_ids_in_output[ @@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, batch_size, repetition_penalty, torch.device(device)) sampling_metadata.no_penalties = False sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] - non_penalized_token_id = logprobs_token_ids[0] - penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + non_penalized_token_id = logits[batch_idx].argmax().item() + penalized_token_id = logits[batch_idx].argmin().item() prompt_tokens = sampling_metadata.prompt_token_ids[ batch_idx][:].tolist() output_tokens = sampling_metadata.output_token_ids[batch_idx] From 5cac3e1b3b800226c1e3f474a08eec8e85635628 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 04:44:37 -0800 Subject: [PATCH 11/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/topk_topp_sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 9ed1d5ddeaa4b..e8731102077bb 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -24,9 +24,9 @@ def __init__(self): self.forward = self.forward_cuda else: logger.warning( - "flashinfer.sampling is not available. Falling back to " - "Pytorch-native implementation of sampling. For the best " - "performance, please install FalshInfer.") + "FlashInfer is not available. Falling back to the PyTorch-" + "native implementation of top-p & top-k sampling. 
For the " + "best performance, please install FalshInfer.") self.forward = self.forward_native else: self.forward = self.forward_native From 8061a16a07d9de94811be27b45fb7a9d15e3062a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 11:40:37 -0800 Subject: [PATCH 12/16] comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/topk_topp_sampler.py | 29 ++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e8731102077bb..9ec7fb79b130b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -40,6 +40,7 @@ def forward_native( no_top_p: bool, p: torch.Tensor, ) -> torch.Tensor: + """PyTorch-native implementation of top-k and top-p sampling.""" logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) @@ -53,8 +54,12 @@ def forward_cuda( no_top_p: bool, p: torch.Tensor, ) -> torch.Tensor: + """More optimized implementation for top-k and top-p sampling.""" probs = logits.softmax(dim=-1, dtype=torch.float32) if no_top_k and no_top_p: + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. return random_sample(probs, generators) return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) @@ -66,6 +71,10 @@ def apply_top_k_top_p( no_top_p: bool, p: torch.Tensor, ) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + This function sorts the logits tensor, which can be slow for large batches. + """ if no_top_k and no_top_p: return logits logits_sort, logits_idx = logits.sort(dim=-1, descending=False) @@ -96,13 +105,17 @@ def random_sample( probs: torch.Tensor, generators: Dict[int, torch.Generator], ) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-GPU synchronization. + """ q = torch.empty_like(probs) # NOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. if len(generators) != probs.shape[0]: - # This might still be done here unnecessarily if there are greedies q.exponential_() if generators: # TODO(woosuk): This can be slow because we handle each request @@ -120,6 +133,20 @@ def flashinfer_sample( p: torch.Tensor, generators: Dict[int, torch.Generator], ) -> torch.Tensor: + """Sample from the probabilities using FlashInfer. + + Statistically, this function is equivalent to the `random_sample` function. + However, this function is faster because it avoids sorting the logits tensor + via rejection sampling. + + NOTE: The outputs of this function do not necessarily match the outputs of + the `random_sample` function. It only guarantees that the outputs are + statistically equivalent. + + NOTE: This function includes CPU-GPU synchronization, while `random_sample` + does not. Call this function at the end of the forward pass to minimize + the synchronization overhead. 
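+
+    NOTE: `uniform_samples` pre-allocates noise for up to 32 rejection
+    sampling rounds. If any request has not converged after those rounds
+    (reported via `success`), the code falls back to renormalizing the
+    probabilities and sampling from them directly.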
+ """ assert not (no_top_k and no_top_p) max_top_k_round = 32 batch_size = probs.shape[0] From e968e18da750bab666fec293c10a83c89bb8d48d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 11:48:33 -0800 Subject: [PATCH 13/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/topk_topp_sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 9ec7fb79b130b..eb75e7b82d217 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -10,9 +10,9 @@ try: import flashinfer.sampling - use_flashinfer = True + is_flashinfer_available = True except ImportError: - use_flashinfer = False + is_flashinfer_available = False class TopKTopPSampler(nn.Module): @@ -20,7 +20,8 @@ class TopKTopPSampler(nn.Module): def __init__(self): super().__init__() if current_platform.is_cuda: - if use_flashinfer: + if is_flashinfer_available: + logger.info("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda else: logger.warning( From 0f784a5c9daf3f4d969e2c1bd0cfe42c8c8482da Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 11:57:23 -0800 Subject: [PATCH 14/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 50 ++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 6d0a367aa1009..1e38453a0ff28 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,4 +1,6 @@ """A layer that samples the next tokens from the model's outputs.""" +from typing import Tuple + import torch import torch.nn as nn @@ -24,12 +26,21 @@ def forward( ) -> SamplerOutput: needs_logprobs = sampling_metadata.max_num_logprobs > 0 if needs_logprobs: - # NOTE: We need to clone the tensor because the below ops may - # modify the logits tensor in-place. - orig_logits = logits.clone() + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # NOTE: We compute logprobs first because the below ops may + # modify the logits tensor in-place (and we don't want to clone + # the logits tensor for memory efficiency). + topk_logprobs, topk_indices = self.get_topk_logprobs( + logits, sampling_metadata) + else: + topk_logprobs = None + topk_indices = None + # Use float32 for the logits. logits = logits.to(torch.float32) - # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) # Apply temperature. @@ -39,22 +50,6 @@ def forward( # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) - if needs_logprobs: - # NOTE(woosuk): Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. - # This is different from the V0 sampler, which uses the logits that - # is used for sampling (after penalties and temperature scaling). - logprobs = self.get_logprobs(orig_logits) - # FIXME: Mask the sampled token_id, get topk logprobs, - # and concatenate the topk with the sampled token_id. - topk_logprobs, topk_indices = torch.topk( - logprobs, sampling_metadata.max_num_logprobs, dim=-1) - # Use int32 to reduce the tensor size. 
- topk_indices = topk_indices.to(torch.int32) - else: - topk_logprobs = None - topk_indices = None - # NOTE: CPU-GPU synchronization happens here. sampler_output = SamplerOutput( sampled_token_ids=sampled.tolist(), @@ -108,8 +103,19 @@ def sample( ) return sampled - def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: - return torch.log_softmax(logits, dim=-1, dtype=torch.float32) + def get_topk_logprobs( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) + # FIXME: Mask the sampled token_id, get topk logprobs, + # and concatenate the topk with the sampled token_id. + topk_logprobs, topk_indices = torch.topk( + logprobs, sampling_metadata.max_num_logprobs, dim=-1) + # Use int32 to reduce the tensor size. + topk_indices = topk_indices.to(torch.int32) + return topk_logprobs, topk_indices def apply_penalties( self, From 6bea16658c797ed11531a7eff8d492f3aeec3ef7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 16:09:59 -0800 Subject: [PATCH 15/16] Consider VLLM_USE_FLASHINFER_SAMPLER Signed-off-by: Woosuk Kwon --- vllm/envs.py | 5 +++-- vllm/v1/sample/ops/topk_topp_sampler.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 18870c1c6b51a..c4a568c680db0 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -30,7 +30,7 @@ VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None - VLLM_USE_FLASHINFER_SAMPLER: bool = False + VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None @@ -277,7 +277,8 @@ def get_default_config_root(): # If set, vllm will use flashinfer sampler "VLLM_USE_FLASHINFER_SAMPLER": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))), + lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) + if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, # If set, vllm will force flashinfer to use tensor cores; # otherwise will use heuristic based on model architecture. diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index eb75e7b82d217..0c2224e3d2b16 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform @@ -20,7 +21,15 @@ class TopKTopPSampler(nn.Module): def __init__(self): super().__init__() if current_platform.is_cuda: - if is_flashinfer_available: + if (is_flashinfer_available + and envs.VLLM_USE_FLASHINFER_SAMPLER is not False): + # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for + # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by + # default it is unused). For backward compatibility, we set + # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and interpret + # it differently in V0 and V1 samplers: In V0, None means False, + # while in V1, None means True. This is why we use the condition + # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. 
logger.info("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda else: From 68ffc9630b7f1acb3d8f9207d0e4fd8a1bd9f40c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 26 Dec 2024 16:12:57 -0800 Subject: [PATCH 16/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/topk_topp_sampler.py | 30 ++++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 0c2224e3d2b16..c088c3c129ca5 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -21,17 +21,25 @@ class TopKTopPSampler(nn.Module): def __init__(self): super().__init__() if current_platform.is_cuda: - if (is_flashinfer_available - and envs.VLLM_USE_FLASHINFER_SAMPLER is not False): - # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for - # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by - # default it is unused). For backward compatibility, we set - # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and interpret - # it differently in V0 and V1 samplers: In V0, None means False, - # while in V1, None means True. This is why we use the condition - # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - logger.info("Using FlashInfer for top-p & top-k sampling.") - self.forward = self.forward_cuda + if is_flashinfer_available: + if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: + # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for + # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by + # default it is unused). For backward compatibility, we set + # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and + # interpret it differently in V0 and V1 samplers: In V0, + # None means False, while in V1, None means True. This is + # why we use the condition + # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. + logger.info("Using FlashInfer for top-p & top-k sampling.") + self.forward = self.forward_cuda + else: + logger.warning( + "FlashInfer is available, but it is not enabled. " + "Falling back to the PyTorch-native implementation of " + "top-p & top-k sampling. For the best performance, " + "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + self.forward = self.forward_native else: logger.warning( "FlashInfer is not available. Falling back to the PyTorch-"