Introduce delayed sampling mechanism #84

Merged 1 commit on Jul 4, 2024
11 changes: 11 additions & 0 deletions vllm/config.py
@@ -605,6 +605,7 @@ def __init__(
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
enable_delayed_sampling: bool = False,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -626,6 +627,7 @@ def __init__(
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.enable_delayed_sampling = enable_delayed_sampling

self._verify_args()

@@ -652,6 +654,15 @@ def _verify_args(self) -> None:
f"({self.num_lookahead_slots}) must be greater than or "
"equal to 0.")

if self.enable_delayed_sampling and self.num_lookahead_slots != 1:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be 1 for delayed sampling.")

if self.enable_delayed_sampling and not self.use_v2_block_manager:
raise ValueError(
"use_v2_block_manager "
f"({self.use_v2_block_manager}) must be True for delayed sampling.")

class DeviceConfig:

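The two new checks in `_verify_args` tie delayed sampling to exactly one lookahead slot and the v2 block manager. Below is a minimal sketch of how those checks surface to callers; only the keyword arguments visible in this diff are confirmed by it, and the other constructor parameters (`max_num_seqs`, `max_model_len`, `use_v2_block_manager`) are assumed.

```python
from vllm.config import SchedulerConfig  # patched vllm.config from this PR

try:
    SchedulerConfig(
        max_num_batched_tokens=4096,     # assumed parameters beyond this diff
        max_num_seqs=128,
        max_model_len=2048,
        use_v2_block_manager=True,
        num_lookahead_slots=0,           # invalid: must be 1 for delayed sampling
        enable_delayed_sampling=True,
    )
except ValueError as err:
    print(err)  # "num_lookahead_slots (0) must be 1 for delayed sampling."
```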
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
@@ -78,6 +78,7 @@ class EngineArgs:
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
enable_delayed_sampling: bool = False

guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
@@ -451,6 +452,13 @@ def add_cli_args(
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument(
'--enable-delayed-sampling',
action='store_true',
help='If set, sampling will be delayed by one step. The first '
'model execution (prefill) will return an invalid token id '
'that is discarded. Actual sampling of valid token ids '
'starts from the second model execution.')

parser.add_argument(
'--speculative-model',
@@ -571,6 +579,7 @@ def create_engine_config(self, ) -> EngineConfig:
speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
enable_delayed_sampling=self.enable_delayed_sampling,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
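Taken together with the scheduler checks in config.py, the new flag is only valid alongside one lookahead slot and the v2 block manager. A hedged sketch of enabling it programmatically follows; the model name is a placeholder, and only the fields appearing in this diff are guaranteed by it to exist on `EngineArgs`.

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent CLI invocation (flag names from this diff and vLLM's existing flags):
#   --enable-delayed-sampling --num-lookahead-slots 1 --use-v2-block-manager
args = EngineArgs(
    model="facebook/opt-125m",        # placeholder model
    enable_delayed_sampling=True,     # new flag added by this PR
    num_lookahead_slots=1,            # required by SchedulerConfig._verify_args
    use_v2_block_manager=True,        # likewise required for delayed sampling
)
engine_config = args.create_engine_config()
# Assumes EngineConfig exposes the scheduler config as .scheduler_config.
assert engine_config.scheduler_config.enable_delayed_sampling
```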
3 changes: 2 additions & 1 deletion vllm/engine/output_processor/interfaces.py
@@ -37,7 +37,8 @@ def create_output_processor(
This returns a single-step output processor if num_lookahead_slots is
zero, else returns a multi-step output processor.
"""
if scheduler_config.num_lookahead_slots == 0:
if (scheduler_config.num_lookahead_slots == 0
        or (scheduler_config.num_lookahead_slots == 1
            and scheduler_config.enable_delayed_sampling)):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
7 changes: 3 additions & 4 deletions vllm/engine/output_processor/multi_step.py
@@ -85,10 +85,9 @@ def process_outputs(self, sequence_group: SequenceGroup,
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples

self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
if valid_samples:
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)

def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
9 changes: 6 additions & 3 deletions vllm/engine/output_processor/single_step.py
@@ -108,9 +108,12 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
# -1 means the output token is not valid (e.g. the first token
# when delayed sampling is enabled).
if last_child_sample.output_token != -1:
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize and self.detokenizer:
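Both output processors now treat -1 as a placeholder: the multi-step processor drops its assert and simply skips empty sample lists, and the single-step processor avoids appending the invalid token. A small illustrative sketch of that filtering rule is below; it only assumes the `SequenceOutput` type from `vllm.sequence`, and the helper itself is hypothetical, not part of the PR.

```python
from typing import List

from vllm.sequence import SequenceOutput


def drop_placeholder_samples(
        samples: List[SequenceOutput]) -> List[SequenceOutput]:
    # -1 is the invalid token id emitted for the first (prefill) model
    # execution when delayed sampling is enabled; it must never be appended
    # to a sequence or detokenized.
    return [sample for sample in samples if sample.output_token != -1]
```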
16 changes: 11 additions & 5 deletions vllm/model_executor/layers/sampler.py
@@ -46,6 +46,7 @@ def __init__(self):
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
self.sample_token_positions_only = False

def forward(
self,
@@ -100,6 +101,7 @@ def forward(
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
token_positions_only=self.sample_token_positions_only,
)

if self.include_gpu_probs_tensor:
@@ -285,6 +287,7 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
token_positions_only: bool = False,
) -> SampleResultType:
"""Run greedy sampling on a given samples.

@@ -298,7 +301,8 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples = samples.tolist()
if not token_positions_only:
samples = samples.tolist()
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
@@ -311,7 +315,7 @@
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples[sample_idx]]
next_token_ids = [sample_idx if token_positions_only else samples[sample_idx]]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
@@ -468,6 +472,7 @@ def _sample_with_torch(
sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
token_positions_only: bool = False,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
@@ -545,14 +550,14 @@
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")

# GPU<->CPU sync happens in the loop below.
# GPU<->CPU sync happens in the loop below, unless only token
# positions are collected (token_positions_only=True).
# This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
sample_results = _greedy_sample(seq_groups, greedy_samples, token_positions_only)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
@@ -651,7 +656,7 @@ def _sample_with_triton_kernel(
def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
include_gpu_probs_tensor: bool, modify_greedy_probs: bool, token_positions_only: bool
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
@@ -671,6 +676,7 @@ def _sample(
sampling_metadata,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
token_positions_only=token_positions_only,
)

# TODO: Enable once Triton kernel & associated code is faster.
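The new `sample_token_positions_only` / `token_positions_only` path makes greedy "sampling" return each sequence's row position in the sampled batch instead of the argmax token id, so the usual GPU-to-CPU copy of token ids is deferred until later. A standalone sketch of the idea (deliberately not the real `_greedy_sample` signature, which works on sequence groups):

```python
import torch


def greedy_sample_sketch(logits: torch.Tensor, token_positions_only: bool):
    """Return (next_token_ids, parent_ids) pairs for a batch of greedy rows."""
    if token_positions_only:
        # Defer the device-to-host sync: hand back only the row index of each
        # sequence, to be resolved into a real token id later.
        return [([row], [0]) for row in range(logits.shape[0])]
    # Eager path: argmax per row, copied back to the host as Python ints.
    token_ids = torch.argmax(logits, dim=-1).tolist()
    return [([tok], [0]) for tok in token_ids]
```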
4 changes: 2 additions & 2 deletions vllm/sequence.py
@@ -124,6 +124,8 @@ def __init__(
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
self.prev_logits = None
self.prev_logits_idx = None

def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
@@ -148,8 +150,6 @@ def get_num_computed_tokens(self) -> int:
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
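The two new fields cache a reference to the logits tensor produced by the previous model step and the row in that tensor belonging to this sequence; the removed num-computed-tokens assert accommodates the extra in-flight step. A minimal standalone sketch of that bookkeeping follows; the class is illustrative only, not vLLM's `SequenceData`.

```python
from typing import Optional

import torch


class SeqDataSketch:
    """Stripped-down stand-in for the new fields on vllm.sequence.SequenceData."""

    def __init__(self) -> None:
        self.prev_logits: Optional[torch.Tensor] = None  # last step's logits
        self.prev_logits_idx: Optional[int] = None       # this sequence's row


# One decode step produced logits for a batch of four sequences.
batch_logits = torch.randn(4, 32000)
seqs = [SeqDataSketch() for _ in range(4)]
for idx, seq in enumerate(seqs):
    seq.prev_logits = batch_logits      # shared reference, no copy
    seq.prev_logits_idx = idx           # remember which row is ours

# Next step: gather the cached rows and sample from them before running the model.
rows = torch.tensor([s.prev_logits_idx for s in seqs])
next_tokens = torch.argmax(seqs[0].prev_logits[rows], dim=-1)
```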
98 changes: 90 additions & 8 deletions vllm/worker/habana_model_runner.py
@@ -300,11 +300,14 @@ def load_model(self) -> None:
htcore.mark_step()
torch.hpu.synchronize()

if self.scheduler_config.enable_delayed_sampling:
self.model.sampler.include_gpu_probs_tensor = True
self.model.sampler.sample_token_positions_only = True

# FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
with HabanaMemoryProfiler() as m_wrap:
self.model = _maybe_wrap_in_hpu_graph(self.model)
logger.info(f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}")

self.model_memory_usage = m.consumed_device_memory
logger.info(f"Loading model weights took in total {m.get_summary_string()}")

@@ -576,7 +579,8 @@ def _prepare_decode(
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])

seq_len = seq_data.get_len()
seq_len = ((seq_data.get_num_computed_tokens() + 1)
if self.scheduler_config.enable_delayed_sampling else seq_data.get_len())
position = seq_len - 1
input_positions.append([position])

@@ -873,30 +877,108 @@ def execute_model(
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update({"bypass_hpu_graphs":not use_graphs, "warmup_mode":warmup_mode})
htorch.core.mark_step()
# Sample the next token based on previous logits if any.
if self.scheduler_config.enable_delayed_sampling and not is_prompt:
logits_ids_list = []
logits_tensor = None
logits_tensor_list = []
for seq_group_metadata in seq_group_metadata_list:
assert len(seq_group_metadata.seq_data) == 1
for seq_data in seq_group_metadata.seq_data.values():
if seq_data.prev_logits is not None:
if logits_tensor is None:
logits_tensor = seq_data.prev_logits
if seq_data.prev_logits is logits_tensor:
# accumulate row ids from the same tensor
logits_ids_list.append(seq_data.prev_logits_idx)
else:
# new logits tensor, gather all previously collected rows
logits_tensor_list.append(logits_tensor[torch.tensor(logits_ids_list, device=seq_data.prev_logits.device)])
logits_ids_list = [seq_data.prev_logits_idx]
logits_tensor = seq_data.prev_logits
else:
# Warmup-only path (no previous logits yet); TODO: add an explicit check.
logits_tensor_list.append(torch.zeros([1, 32000], dtype=torch.float, device="hpu"))
if logits_tensor is not None:
logits_tensor_list.append(logits_tensor[torch.tensor(logits_ids_list, device=seq_data.prev_logits.device)])

prev_logits = torch.cat(logits_tensor_list, dim=0)

with self.profiler.record_event('internal', f'sample_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
output = self.model.sample(
logits=prev_logits,
sampling_metadata=sampling_metadata,
)

execute_model_kwargs["input_ids"] = output.sampled_token_ids
htorch.core.mark_step()

if self.is_driver_worker:
model_event_name = f"model_{'prompt' if is_prompt else 'decode'}_bs{batch_size}_seq{seq_len}_graphs{'T' if use_graphs else 'F'}"
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices)

if self.scheduler_config.enable_delayed_sampling:
if not is_prompt:
htorch.core.mark_step()
# Read and update the previous step's token ids to return only after
# the next model.forward() has been dispatched.
sampled_token_ids = output.sampled_token_ids.tolist()
for seq_group_output in output.outputs[:real_batch_size]:
for sample in seq_group_output.samples:
sample.output_token = sampled_token_ids[sample.output_token][0]
else:
# For prompts, compose a placeholder output with an invalid (-1) token.
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupOutput, SequenceOutput)
sampler_output = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
next_token_id, parent_id = -1, 0
seq_outputs = []
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, {-1: Logprob(0.0)}))
sampler_output.append(
SequenceGroupOutput(seq_outputs, None))

sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None)
output = SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
)

output.outputs = output.outputs[:real_batch_size]
htorch.core.mark_step()

# Compute the logits.
with self.profiler.record_event('internal', f'compute_logits_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states, sampling_metadata)

if self.scheduler_config.enable_delayed_sampling:
for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
assert len(seq_group_metadata.seq_data) == 1
for seq_data in seq_group_metadata.seq_data.values():
seq_data.prev_logits = logits
seq_data.prev_logits_idx = idx

htorch.core.mark_step()

# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None

# Sample the next token.
with self.profiler.record_event('internal', f'sample_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
output.outputs = output.outputs[:real_batch_size]
if not self.scheduler_config.enable_delayed_sampling:
with self.profiler.record_event('internal', f'sample_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
output.outputs = output.outputs[:real_batch_size]
htorch.core.mark_step()

if self.is_driver_worker and self.profiler.enabled:
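To summarize the decode-side control flow in habana_model_runner.py above: each step samples on-device from the logits cached in the previous step, feeds those token ids straight into the next forward pass, and only afterwards copies the sampled ids back to the host, so the host-side sync overlaps with model execution. The toy, self-contained sketch below mirrors that ordering; the embedding and logits head are stand-ins for the real HPU model runner, not its API.

```python
import torch


def delayed_sampling_loop(run_model, compute_logits, input_ids, steps=4):
    """Greedy toy loop mirroring the ordering used for delayed sampling."""
    prev_logits = None
    generated = []
    for _ in range(steps):
        if prev_logits is not None:
            # Sample this step's input from the *previous* step's logits,
            # keeping the result on the device.
            input_ids = torch.argmax(prev_logits, dim=-1, keepdim=True)
        hidden = run_model(input_ids)            # dispatch the forward pass
        if prev_logits is not None:
            # Only now read the sampled ids back on the host; on HPU this copy
            # would overlap with the forward pass launched just above.
            generated.append(input_ids.tolist())
        prev_logits = compute_logits(hidden)     # cache for the next step
    return generated


# Toy stand-ins for the model and its logits head.
torch.manual_seed(0)
embed = torch.nn.Embedding(32000, 64)
head = torch.nn.Linear(64, 32000)
print(delayed_sampling_loop(embed,
                            lambda h: head(h).squeeze(1),
                            torch.tensor([[1]]),
                            steps=3))
```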