diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py
new file mode 100644
index 0000000000000..6a1819e990f44
--- /dev/null
+++ b/tests/spec_decode/e2e/test_eagle_correctness.py
@@ -0,0 +1,268 @@
+"""This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+However, we still need to verify that the scenarios below pass:
+    * Batch size 1 greedy equality
+    * Batch size >1 greedy equality
+    * Test greedy equality under preemption
+    * Test greedy equality under various numbers of speculative tokens.
+
+With those tests, we can say at least that EAGLE would not break the
+correctness of the target model outputs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+# main model
+MAIN_MODEL = "JackFram/llama-68m"
+
+# speculative model
+SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"
+
+# max. number of speculative tokens: this corresponds to
+# num_heads in the config.json of the speculator model.
+MAX_SPEC_TOKENS = 4
+
+# precision
+PRECISION = "float32"
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
+                                      test_llm_generator, batch_size: int,
+                                      output_len: int):
+    """Verify greedy equality with different batch size."""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "enforce_eager": False,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that eagle speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that eagle speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7e4a6cc62d02b..de4b2ab796a3c 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -70,8 +70,9 @@ ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_medusa_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, @@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -116,10 +160,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size: int, - output_len: int): +def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): """Verify greedy equality, even when some sequences are preempted mid- generation. """ @@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_different_k(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality +def test_medusa_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that medusa speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ run_greedy_equality_correctness_test(baseline_llm_generator, @@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality +def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that medusa speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. 
""" diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 442e40f07f0bb..ada6c37d9af8d 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -6,7 +6,8 @@ import torch from vllm.model_executor.utils import set_random_seed -from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, + SamplerOutput, get_all_seq_ids) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer @@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step(): worker.execute_model(execute_model_req=execute_model_req) call_args_list = worker.model_runner._gpu_advance_step.call_args_list assert len(call_args_list) == 1 + + +@torch.inference_mode() +def test_expand_execute_model_request_sync_with_expand_hidden_states(): + """ + In this test we verify that the logic for expanding the + seq_group_metadata_list remains in sync with the expansion logic of + the HiddenStates in _expand_execute_model_request. + """ + k = 5 + batch_size = 16 + seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15] + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + + execute_model_request = ExecuteModelRequest( + seq_group_metadata_list, + previous_hidden_states=HiddenStates( + torch.arange(batch_size), seq_group_metadata_list, + torch.arange(batch_size, 2 * batch_size))) + + expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\ + _expand_execute_model_request(execute_model_request, + seq_with_bonus_token_in_last_step) + + all_seq_ids = torch.tensor( + get_all_seq_ids( + expanded_execute_model_request.seq_group_metadata_list)) + ref_expanded_hidden_states = all_seq_ids + batch_size + ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size + + assert (ref_expanded_hidden_states == expanded_execute_model_request. 
+ previous_hidden_states.hidden_states).all().item() diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index bdf6e502ea112..8591c276b0013 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -60,6 +60,7 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MedusaModel": ("medusa", "Medusa"), + "EAGLEModel": ("eagle", "EAGLE"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), } diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py new file mode 100644 index 0000000000000..99c825ff63572 --- /dev/null +++ b/vllm/model_executor/models/eagle.py @@ -0,0 +1,161 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.transformers_utils.configs.eagle import EAGLEConfig + + +class EAGLE(nn.Module): + """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 + Reference implementation: https://github.com/SafeAILab/EAGLE + + Differences from reference implementation: + 1. In reference, LlamaDecoderLayer implementation doesn't have + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427) + but we do as HF implementation also does. + 2. We allow any decoder layer to be used in EAGLE whereas in reference + decoder layer is fixed to be LlamaDecoderLayer. + 3. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). 
Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" + + def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: + super().__init__() + self.config = config + + architectures = getattr(self.config.model, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) + + self.model = model_cls(self.config.model, *args, **kwargs) + self.fc = nn.Linear(config.model.hidden_size * 2, + config.model.hidden_size, + bias=False) + + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. + self.token_map = None + + @property + def sampler(self): + return self.model.sampler + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + + tok_embeds = self.model.model.embed_tokens(input_ids) + inputs_embeds = self.fc( + torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + + hidden_states = self.model.model( + input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + if self.token_map is not None: + _logits = logits + logits = -torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype) + + logits[..., self.token_map] = _logits + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B + # due to missing lm_head weights and its config being that of a + # Llama model. 
Here's a compatible version with the same weights: + # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm + # Also, here's an example script for converting trained EAGLE + # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d + model_weights = {} + for name, loaded_weight in weights: + if name == "token_map": + if self.config.truncated_vocab_size < self.config.vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name.startswith("fc."): + weight_loader = getattr(self.fc.weight, "weight_loader", + default_weight_loader) + weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("model.lm_head.") or name.startswith( + "model.model."): + model_weights[name.split("model.", 1)[-1]] = loaded_weight + elif name.startswith("lm_head.") or name.startswith("model."): + model_weights[name] = loaded_weight + else: + model_weights[f"model.{name}"] = loaded_weight + + lm_head_weight = model_weights.pop("lm_head.weight") + + if self.token_map is not None and\ + lm_head_weight.shape[0] > self.token_map.shape[0]: + + lm_head_weight = lm_head_weight[self.token_map] + + weight_loader = getattr(self.lm_head.weight, "weight_loader", + default_weight_loader) + weight_loader(self.lm_head.weight, lm_head_weight) + + self.model.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index c2a61ca52011e..55d42952cd0cc 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -30,6 +30,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Medusa(nn.Module): + """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 + Reference implementation: https://github.com/FasterDecoding/Medusa + + Differences from reference implementation: + 1. Currently this only supports generating proposals from top-1 tokens. + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" def __init__(self, config: MedusaConfig, **_) -> None: super().__init__() @@ -57,6 +70,12 @@ def __init__(self, config: MedusaConfig, **_) -> None: self.truncated_vocab_size, logit_scale) + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. self.token_map = None def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: diff --git a/vllm/sequence.py b/vllm/sequence.py index 206da192193dc..2fe8ae9d7b270 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1092,6 +1092,10 @@ class SamplerOutput( # Optional last hidden states from the model. 
hidden_states: Optional[torch.Tensor] = None + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + # Time taken in the forward pass for this across all workers model_forward_time: Optional[float] = None @@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True, omit_defaults=True): # type: ignore[call-arg] """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from - the target model to the proposer model in the subsequent step. + the target model to the proposer model. seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - - seq_group_metadata_list: List[SequenceGroupMetadata] + # Scorer hidden states. For prefill step, it is used for hidden states of + # all tokens, whereas for decode step, it use used for last accepted tokens. hidden_states: torch.Tensor + # The sequence group metadata list. Only needed for decode step. + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + # Scorer hidden states of the 2nd last token proposed by the proposer ( + # irrespective of whether it was accepted or not). Only used for cases when + # last proposed token is accepted (i.e., in case of bonus tokens). For the + # case of no bonus tokens, these are ignored. + second_last_token_hidden_states: Optional[torch.Tensor] = None + _seq_ids: List[int] = msgspec.field(default_factory=list) def __post_init__(self): - self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) - assert len(self.seq_group_metadata_list) == len(self.hidden_states) + if self.seq_group_metadata_list is not None: + assert len(self.seq_group_metadata_list) == len(self.hidden_states) + self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) @property def seq_ids(self) -> List[int]: return self._seq_ids - def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor) -> None: - """Update hidden states from target model invocation.""" + def update(self, + hidden_states: torch.Tensor, + seq_group_metadata_list: List[SequenceGroupMetadata], + second_last_token_hidden_states: Optional[torch.Tensor] = None): + """Update hidden states from target model invocation. Only used for + decode steps""" assert len(seq_group_metadata_list) == len(hidden_states) self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) + if self.second_last_token_hidden_states is not None: + # Adding dummy hidden_states to this to maintain same shape + self.second_last_token_hidden_states = torch.cat([ + self.second_last_token_hidden_states, + torch.zeros_like(hidden_states) + if second_last_token_hidden_states is None else + second_last_token_hidden_states + ]) + def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids.""" + """Prune to provided list of sequence ids. Only used for decode steps. + """ + # Currently this prunes all seq_ids not present in + # seq_group_metadata_list which might cause problems where a sequence + # may be "paused" then "resumed" later. This should only prune sequences + # which are confirmed to be aborted. seq_ids = get_all_seq_ids(seq_group_metadata_list) if seq_ids != self._seq_ids: # Batch contents changed - prune removed sequences. 
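+            # Illustrative example (not part of the original change): if
+            # self._seq_ids is [7, 3, 5] and the new batch keeps seq_ids
+            # [5, 7], then index is [2, 0], and both hidden-state tensors
+            # below are reordered and filtered to match the new batch order.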
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] + if self.second_last_token_hidden_states is not None: + self.second_last_token_hidden_states = self\ + .second_last_token_hidden_states[index] self._seq_ids = seq_ids + def expand_with_bonus_tokens( + self, seq_with_bonus_token_in_last_step: set) -> None: + """Expand hidden states for sequences with bonus tokens. This is in + alignment with `MultiStepWorker._expand_execute_model_request`.""" + if self.second_last_token_hidden_states is None \ + or not seq_with_bonus_token_in_last_step: + return + + index = [] + for seq_id in self._seq_ids: + i = self._seq_ids.index(seq_id) + if seq_id in seq_with_bonus_token_in_last_step: + index.append(i + len(self._seq_ids)) + index.append(i) + + self.hidden_states = torch.cat( + [self.hidden_states, self.second_last_token_hidden_states])[index] + class ExecuteModelRequest( msgspec.Struct, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 053e9203e01eb..aedf0a83da07d 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -203,6 +203,7 @@ def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: @@ -280,13 +281,30 @@ def execute_model( graph_batch_size = model_input.input_tokens.shape[0] model_executable = (self.graph_runners[model_input.virtual_engine] [graph_batch_size]) + + if previous_hidden_states is not None: + hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + hidden_states = None else: model_executable = self.model + hidden_states = previous_hidden_states outputs: List[SamplerOutput] = [] for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} + kwargs = {"previous_hidden_states": hidden_states} \ + if previous_hidden_states is not None else {} + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, @@ -296,6 +314,7 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(multi_modal_kwargs, device=self.device), + **kwargs, ) # Compute the logits. 
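The padding in the draft-runner hunk above exists because CUDA graphs are captured for fixed batch sizes, so the scorer's hidden states must be padded up to the captured graph batch size before replay; the extra rows only fill padded graph slots, which is presumably why uninitialized torch.empty rows suffice. A minimal standalone sketch of that step (the helper name and sizes are illustrative, not part of this change):

    import torch

    def pad_hidden_states_for_graph(prev: torch.Tensor,
                                    graph_batch_size: int) -> torch.Tensor:
        # Append uninitialized rows so the batch dimension matches the
        # batch size the CUDA graph was captured with (sketch only).
        pad = torch.empty([graph_batch_size - prev.shape[0], *prev.shape[1:]],
                          dtype=prev.dtype,
                          device=prev.device)
        return torch.cat([prev, pad])

    # Example: 3 live sequences padded up to a graph captured for batch size 8.
    states = torch.randn(3, 16)
    assert pad_hidden_states_for_graph(states, 8).shape == (8, 16)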
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 65bfb5dc8d5c6..2dfbacfb7b759 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -4,8 +4,8 @@ import torch -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, - SequenceGroupMetadata) +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, + SequenceData, SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) @@ -157,6 +157,12 @@ def _expand_execute_model_request( updated_execute_model_req.seq_group_metadata_list =\ updated_seq_group_metadata_list + + if isinstance(updated_execute_model_req.previous_hidden_states, + HiddenStates): + updated_execute_model_req.previous_hidden_states\ + .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) + return updated_execute_model_req, indices_of_original_sequence_groups @staticmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index acf77a7349eef..2762b8388029f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -147,6 +147,11 @@ def create_worker( draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: + if draft_worker_kwargs[ + "model_config"].hf_config.model_type == "eagle": + raise NotImplementedError( + "EAGLE does not support TP > 1 yet") + allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) @@ -355,14 +360,34 @@ def execute_model( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch. + # In any of these cases, the proposer and scorer workers + # are called normally. + no_spec = num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list + ) == 0 or disable_all_speculation + # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers. # This is required as if the number of draft model runs changes # dynamically, the non-driver workers won't know unless we perform a # communication to inform them. + + # no_spec is used to signal non-driver worker about prefill vs decode + # stage. This is needed to ensure that order of execution of proposer + # and scorer is same in both driver and non-driver workers (i.e., + # scorer -> proposer for prefill and proposer -> scorer in decode). This + # order is needed to support models like EAGLE that take scorer states + # as inputs. broadcast_dict = dict( num_lookahead_slots=num_lookahead_slots, + no_spec=no_spec, disable_all_speculation=disable_all_speculation, ) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) @@ -373,17 +398,7 @@ def execute_model( self._maybe_disable_speculative_tokens( disable_all_speculation, execute_model_req.seq_group_metadata_list) - # Speculative decoding is disabled in the following cases: - # 1. Prefill phase: Speculative decoding is not - # used during the prefill phase. - # 2. Auto-disable enabled: The running queue size exceeds - # the specified threshold. - # 3. No request: There are no requests in the batch. 
- # In any of these cases, the proposer and scorer workers - # are called normally. - if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list - ) == 0 or disable_all_speculation: + if no_spec: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) return self._run_speculative_decoding_step(execute_model_req, @@ -464,8 +479,6 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ - if not skip_proposer: - self.proposer_worker.execute_model(execute_model_req) sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 @@ -476,10 +489,20 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, if hidden_states is not None: if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( - execute_model_req.seq_group_metadata_list, hidden_states) + hidden_states, execute_model_req.seq_group_metadata_list) else: self.previous_hidden_states.update( - execute_model_req.seq_group_metadata_list, hidden_states) + hidden_states, execute_model_req.seq_group_metadata_list) + + if not skip_proposer: + # We prepare the prefill hidden states here so that there no + # additional complexity in worker for spec_decode vs non_spec_decode + # flow and execute_model doesn't need additional modifications. + execute_model_req.previous_hidden_states = \ + prepare_prefill_hidden_states( + sampler_output.prefill_hidden_states) + + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) @@ -507,15 +530,23 @@ def _run_non_driver_rank(self) -> bool: return False num_lookahead_slots = data["num_lookahead_slots"] - # Even if num_lookahead_slots is zero, we want to run the proposer model - # as it may have KV. - # - # We run the proposer once per lookahead slot. In the future we should - # delegate how many times it runs to the proposer. - for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model() + # In case of prefill, scorer_worker has to be run before proposer so + # that the hidden states can be propagated to proposer when needed. + if data["no_spec"]: + self.scorer_worker.execute_model() + + if not data["disable_all_speculation"]: + # Even if num_lookahead_slots is zero, we want to run the + # proposer model as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we + # should delegate how many times it runs to the proposer. 
+ for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model() + + if not data["no_spec"]: + self.scorer_worker.execute_model() - self.scorer_worker.execute_model() return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") @@ -546,6 +577,8 @@ def _run_speculative_decoding_step( raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") + execute_model_req.previous_hidden_states = None + with Timer() as scoring_timer: proposal_scores = self.scorer.score_proposals( execute_model_req, @@ -651,10 +684,12 @@ def _verify_tokens( accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) + second_last_token_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step - self.previous_hidden_states = HiddenStates(seq_group_metadata_list, - hidden_states) + self.previous_hidden_states = HiddenStates( + hidden_states, seq_group_metadata_list, + second_last_token_hidden_states) return accepted_token_ids, logprobs @@ -951,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) return new_num_gpu_blocks + + +def prepare_prefill_hidden_states( + prefill_hidden_states: torch.Tensor) -> HiddenStates: + # For prefill step in proposer, we run the model for N-1 tokens + # because Nth token will be processed in the first decode step. For + # N-1 tokens, the input should be 0:N-1 hidden states which should + # be concatanated with 1:N token (since output of scorer has to be + # the input for proposer). Therefore, we shift the hidden states to + # align n-1th hidden state with nth token. + return HiddenStates(prefill_hidden_states.roll( + shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 0f86b02deb21a..c2276b075c1dd 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -11,10 +11,11 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - InternVLChatConfig, JAISConfig, - MedusaConfig, MLPSpeculatorConfig, - MPTConfig, NemotronConfig, - RWConfig, UltravoxConfig) + EAGLEConfig, InternVLChatConfig, + JAISConfig, MedusaConfig, + MLPSpeculatorConfig, MPTConfig, + NemotronConfig, RWConfig, + UltravoxConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -32,6 +33,7 @@ "jais": JAISConfig, "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, + "eagle": EAGLEConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 22b906a3149ec..dc2fd6a859e3c 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,5 +1,6 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig +from vllm.transformers_utils.configs.eagle import EAGLEConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. 
Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -20,6 +21,7 @@ "InternVLChatConfig", "JAISConfig", "MedusaConfig", + "EAGLEConfig", "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py new file mode 100644 index 0000000000000..b357a785e4dc4 --- /dev/null +++ b/vllm/transformers_utils/configs/eagle.py @@ -0,0 +1,49 @@ +import os +from typing import Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class EAGLEConfig(PretrainedConfig): + model_type = "eagle" + + def __init__(self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + model_config = None if model is None else (AutoConfig.for_model( + **model) if isinstance(model, dict) else model) + + for k, v in kwargs.items(): + if k != "architectures" and k != "model_type" and hasattr( + model_config, k): + setattr(model_config, k, v) + + self.model = model_config + + if self.model is None: + self.truncated_vocab_size = None + else: + self.truncated_vocab_size = self.model.vocab_size if \ + truncated_vocab_size is None else truncated_vocab_size + + if "architectures" not in kwargs: + kwargs["architectures"] = ["EAGLEModel"] + + super().__init__(**kwargs) + + if self.model is not None: + for k, v in self.model.to_dict().items(): + if not hasattr(self, k): + setattr(self, k, v) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "EAGLEConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 793f03456e997..5d930919b8ae5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,6 @@ import dataclasses import gc +import inspect import itertools import time import warnings @@ -1192,6 +1193,18 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + + # Prepare dummy previous_hidden_states only if needed by the model. + # This is used by draft models such as EAGLE. 
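+        # ("Needed by the model" means the model's forward() declares a
+        # `previous_hidden_states` parameter; this is detected via
+        # inspect.signature below.)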
+ previous_hidden_states = None + if "previous_hidden_states" in inspect.signature( + self.model.forward).parameters: + previous_hidden_states = torch.empty( + [max_batch_size, + self.model_config.get_hidden_size()], + dtype=self.model_config.dtype, + device=self.device) + intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( @@ -1264,6 +1277,11 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "stream": graph_capture_context.stream } + if previous_hidden_states is not None: + capture_inputs[ + "previous_hidden_states"] = previous_hidden_states[: + batch_size] + if self.has_seqlen_agnostic: # Only used by Mamba-based models CUDA graph atm (Jamba) capture_inputs.update({ @@ -1462,6 +1480,7 @@ def execute_model( if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( 0, indices) + output.prefill_hidden_states = hidden_or_intermediate_states elif decode_meta.use_cuda_graph: hidden_states = hidden_or_intermediate_states[:len(indices)] else: @@ -1510,11 +1529,11 @@ def capture( # Note one iteration is not enough for torch.jit.script for _ in range(_NUM_WARMUP_ITERS): self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_inputs, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, **kwargs, ) torch.cuda.synchronize() @@ -1523,11 +1542,11 @@ def capture( self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): output_hidden_or_intermediate_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_inputs, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, **kwargs, ) if hidden_or_intermediate_states is not None: @@ -1588,6 +1607,11 @@ def forward( if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) + + if "previous_hidden_states" in self.input_buffers: + self.input_buffers["previous_hidden_states"].copy_( + kwargs["previous_hidden_states"], non_blocking=True) + if intermediate_tensors is not None: for key in intermediate_tensors.tensors: if key != "model_execute_time" and key != "model_forward_time": diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 6a6caba9371eb..2ed77dd698f5c 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple + +import torch from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -43,7 +45,7 @@ def __init__(self, *args, **kwargs): def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. 
""" @@ -85,7 +87,9 @@ def _get_driver_input_and_broadcast( broadcast_data.update(model_input.as_broadcastable_tensor_dict()) broadcast_tensor_dict(broadcast_data, src=0) - return model_input, worker_input + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} def _prepare_last_sampled_token_ids_for_tp_workers( self, @@ -130,7 +134,8 @@ def _prepare_last_sampled_token_ids_for_tp_workers( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]: + ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, + torch.Tensor]]]: """ Depending on the current state of the request and multi step worker, this method may skip the normal _prepare_model_input and @@ -148,8 +153,8 @@ def prepare_input( return None virtual_engine = execute_model_req.virtual_engine - model_input, worker_input = self._get_driver_input_and_broadcast( - execute_model_req) + (model_input, worker_input, + kwargs) = self._get_driver_input_and_broadcast(execute_model_req) assert isinstance(model_input, StatefulModelInput) if execute_model_req.is_first_multi_step: # cache the worker input and model input for the next steps @@ -162,7 +167,7 @@ def prepare_input( # loop if broadcast_data is None: return None - model_input, worker_input = broadcast_data + model_input, worker_input, kwargs = broadcast_data assert isinstance(model_input, StatefulModelInput) virtual_engine = worker_input.virtual_engine if model_input.is_first_multi_step: @@ -186,4 +191,4 @@ def prepare_input( assert model_input is not None assert worker_input is not None - return model_input, worker_input + return model_input, worker_input, kwargs diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 331a805caba9a..7ed609c3b447c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -86,7 +86,7 @@ def __init__( or (speculative_config.draft_model_config.model == model_config.model) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator"]) \ + not in ["medusa", "mlp_speculator", "eagle"]) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 9fddc863548eb..516e386595195 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -222,7 +222,9 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def _get_worker_input_from_broadcast( - self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: + self + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker @@ -235,11 +237,13 @@ def _get_worker_input_from_broadcast( self.model_runner.make_model_input_from_broadcasted_tensor_dict( broadcast_data)) - return model_input, worker_input + kwargs = extract_previous_hidden_states(broadcast_data) + + return model_input, worker_input, kwargs def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. 
""" assert self.is_driver_worker @@ -251,17 +255,21 @@ def _get_driver_input_and_broadcast( execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + kwargs = extract_previous_hidden_states(execute_model_req) + if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_data.update(kwargs) broadcast_tensor_dict(broadcast_data, src=0) - return model_input, worker_input + return model_input, worker_input, kwargs def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: """ Prepare the inputs to ModelRunner and workers. """ @@ -291,7 +299,7 @@ def execute_model( if inputs is None: return None - model_input, worker_input = inputs + model_input, worker_input, kwargs = inputs num_steps = worker_input.num_steps self.execute_worker(worker_input) @@ -312,9 +320,14 @@ def execute_model( "model_execute_time", torch.tensor(0)).item() output = self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, intermediate_tensors, - num_steps) + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) + model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: # output is IntermediateTensors @@ -360,9 +373,15 @@ def _execute_model_spmd( if worker_input.num_seq_groups == 0: return [] + kwargs = extract_previous_hidden_states(execute_model_req) + return self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, intermediate_tensors) + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + **kwargs, + ) class WorkerWrapperBase: @@ -439,3 +458,23 @@ def execute_method(self, method, *args, **kwargs): "This might cause deadlock in distributed execution.") logger.exception(msg) raise e + + +def extract_previous_hidden_states( + data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ + Dict[str, torch.Tensor]: + """If data contains previous_hidden_states, extract it. This returns a dict + which can be used directly as additional kwargs in any following + execute_model calls. This is used in draft models like EAGLE.""" + output = {} + + # When called from non-driver worker, data is dict but when called from + # driver worker, data is ExecuteModelRequest. + if isinstance(data, dict): + if "previous_hidden_states" in data: + output["previous_hidden_states"] = data["previous_hidden_states"] + elif data.previous_hidden_states is not None: + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states + + return output