From 927f30010d4c301959584b2b08022e98b0359e14 Mon Sep 17 00:00:00 2001
From: Sirej Dua
Date: Tue, 2 Jul 2024 07:20:29 -0700
Subject: [PATCH] [Speculative Decoding] MLPSpeculator Tensor Parallel support
 (1/2) (#6050)

Co-authored-by: Sirej Dua
Co-authored-by: Sirej Dua
---
 .../e2e/test_integration_dist_tp2.py    | 36 ++++++++++++-------
 vllm/config.py                          |  6 ----
 vllm/spec_decode/spec_decode_worker.py  | 18 ++++++----
 3 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py
index 86fd88d57967c..ec593be983c9c 100644
--- a/tests/spec_decode/e2e/test_integration_dist_tp2.py
+++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -75,10 +75,6 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        # Use a small model for a fast test.
-        # Note this is repeated in the test body; to initialize a tokenizer.
-        "model": "JackFram/llama-68m",
-
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
@@ -93,15 +89,31 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
         # second run of the test to fail with internal NCCL error.
         "use_async": True,
     }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-        "speculative_draft_tensor_parallel_size": 1,
-    },
-])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs, test_llm_kwargs",
+    [
+        (
+            {
+                # Use a small model for a fast test.
+                # Note this is repeated in the test body; to initialize a
+                # tokenizer.
+                "model": "JackFram/llama-68m",
+            },
+            {
+                "speculative_model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "speculative_draft_tensor_parallel_size": 1,
+            }),
+        ({
+            "model": "ibm-granite/granite-3b-code-instruct",
+        }, {
+            "speculative_model":
+            "ibm-granite/granite-3b-code-instruct-accelerator",
+            "num_speculative_tokens": 5,
+            "speculative_draft_tensor_parallel_size": 1,
+        })
+    ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
diff --git a/vllm/config.py b/vllm/config.py
index 4deb19007ad57..5d16da27bdb6f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -989,12 +989,6 @@ def maybe_create_spec_config(
             )
 
             draft_hf_config = draft_model_config.hf_config
-            if (draft_hf_config.model_type == "mlp_speculator"
-                    and target_parallel_config.world_size != 1):
-                # MLPSpeculator TP support will be added very soon
-                raise ValueError(
-                    "Speculative decoding with mlp_speculator models does not "
-                    "yet support distributed inferencing (TP > 1).")
 
             if (num_speculative_tokens is not None
                     and hasattr(draft_hf_config, "num_lookahead_tokens")):
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index ca470bee21c91..43ce987de1e16 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -113,24 +113,28 @@ def create_worker(
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
 
         disable_bonus_tokens = True
+
         if ngram_prompt_lookup_max > 0:
             disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
-        elif draft_worker_kwargs[
-                "model_config"].hf_config.model_type == "mlp_speculator":
-            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
-            disable_bonus_tokens = False
         else:
             draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                 'parallel_config']
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
-            if draft_tp == 1:
-                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
-            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_worker_kwargs[
+                    "model_config"].hf_config.model_type == "mlp_speculator":
+                disable_bonus_tokens = False
+                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            else:
+                if draft_tp == 1:
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)