-
-
Notifications
You must be signed in to change notification settings - Fork 4.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)
- Loading branch information
1 parent
b3856be
commit a3fce56
Showing
17 changed files
with
854 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
"""This docstring details important information on the testing methodology. | ||
Most of the tests rely on "greedy equality", where we expect the output of | ||
speculative decoding on a sequence to exactly match the output of normal non- | ||
speculative decoding. | ||
Since speculative decoding with rejection sampling guarantees that the output | ||
distribution matches the target model's output distribution (up to hardware | ||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy | ||
equality. | ||
However, we still need to verify below scenario could be passed: | ||
* Batch size 1 greedy equality | ||
* Batch size >1 greedy equality | ||
* Test greedy equality under preemption | ||
* Test greedy equality under various number of speculative tokens. | ||
With those tests, we can say at least, EAGLE would not break the | ||
correctess for the target model outputs. | ||
""" | ||
|
||
import pytest | ||
|
||
from .conftest import run_greedy_equality_correctness_test | ||
|
||
# main model | ||
MAIN_MODEL = "JackFram/llama-68m" | ||
|
||
# speculative model | ||
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random" | ||
|
||
# max. number of speculative tokens: this corresponds to | ||
# num_heads in the config.json of the speculator model. | ||
MAX_SPEC_TOKENS = 4 | ||
|
||
# precision | ||
PRECISION = "float32" | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"common_llm_kwargs", | ||
[{ | ||
# Skip cuda graph recording for fast test. | ||
"enforce_eager": True, | ||
# Required for spec decode. | ||
"use_v2_block_manager": True, | ||
# Print spec metrics. | ||
"disable_log_stats": False, | ||
# Precision | ||
"dtype": PRECISION, | ||
# Main model | ||
"model": MAIN_MODEL, | ||
}]) | ||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("test_llm_kwargs", [ | ||
{ | ||
"speculative_model": SPEC_MODEL, | ||
"num_speculative_tokens": MAX_SPEC_TOKENS, | ||
}, | ||
]) | ||
@pytest.mark.parametrize("output_len", [ | ||
128, | ||
]) | ||
@pytest.mark.parametrize("batch_size", [1, 32]) | ||
@pytest.mark.parametrize("seed", [1]) | ||
def test_eagle_e2e_greedy_correctness(baseline_llm_generator, | ||
test_llm_generator, batch_size: int, | ||
output_len: int): | ||
"""Verify greedy equality with different batch size.""" | ||
run_greedy_equality_correctness_test(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size, | ||
max_output_len=output_len, | ||
force_output_len=True) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"common_llm_kwargs", | ||
[{ | ||
"enforce_eager": False, | ||
# Required for spec decode. | ||
"use_v2_block_manager": True, | ||
# Print spec metrics. | ||
"disable_log_stats": False, | ||
# Precision | ||
"dtype": PRECISION, | ||
# Main model | ||
"model": MAIN_MODEL, | ||
}]) | ||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("test_llm_kwargs", [ | ||
{ | ||
"speculative_model": SPEC_MODEL, | ||
"num_speculative_tokens": MAX_SPEC_TOKENS, | ||
}, | ||
]) | ||
@pytest.mark.parametrize("output_len", [ | ||
128, | ||
]) | ||
@pytest.mark.parametrize("batch_size", [1, 32]) | ||
@pytest.mark.parametrize("seed", [1]) | ||
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size: int, | ||
output_len: int): | ||
"""Verify greedy equality with cuda graph enabled and different | ||
batch sizes.""" | ||
run_greedy_equality_correctness_test(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size, | ||
max_output_len=output_len, | ||
force_output_len=True) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"common_llm_kwargs", | ||
[{ | ||
"block_size": 8, | ||
# 2 for small prompt, 256//8 for generated. | ||
"num_gpu_blocks_override": 2 + 256 // 8, | ||
"max_model_len": (2 + 256 // 8) * 8, | ||
# Skip cuda graph recording for fast test. | ||
"enforce_eager": True, | ||
# Required for spec decode. | ||
"use_v2_block_manager": True, | ||
# Precision | ||
"dtype": PRECISION, | ||
# Main model | ||
"model": MAIN_MODEL, | ||
}]) | ||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("test_llm_kwargs", [ | ||
{ | ||
"speculative_model": SPEC_MODEL, | ||
"num_speculative_tokens": MAX_SPEC_TOKENS, | ||
}, | ||
]) | ||
@pytest.mark.parametrize( | ||
"output_len", | ||
[ | ||
# Use small output len for fast test. | ||
128, | ||
]) | ||
@pytest.mark.parametrize("batch_size", [4]) | ||
@pytest.mark.parametrize("seed", [1]) | ||
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size: int, | ||
output_len: int): | ||
"""Verify greedy equality, even when some sequences are preempted mid- | ||
generation. | ||
""" | ||
run_greedy_equality_correctness_test(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size, | ||
max_output_len=output_len, | ||
force_output_len=True) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"common_llm_kwargs", | ||
[{ | ||
# Skip cuda graph recording for fast test. | ||
"enforce_eager": True, | ||
# Required for spec decode. | ||
"use_v2_block_manager": True, | ||
# Precision | ||
"dtype": PRECISION, | ||
# Main model | ||
"model": MAIN_MODEL, | ||
}]) | ||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize( | ||
"test_llm_kwargs", | ||
[ | ||
{ | ||
"speculative_model": SPEC_MODEL, | ||
"num_speculative_tokens": k, | ||
} | ||
# Try a range of num. speculative tokens | ||
for k in range(1, 1 + MAX_SPEC_TOKENS) | ||
]) | ||
@pytest.mark.parametrize("batch_size", [2]) | ||
@pytest.mark.parametrize( | ||
"output_len", | ||
[ | ||
# Use smaller output len for fast test. | ||
32, | ||
]) | ||
@pytest.mark.parametrize("seed", [1]) | ||
def test_eagle_different_k(baseline_llm_generator, test_llm_generator, | ||
batch_size: int, output_len: int): | ||
"""Verify that eagle speculative decoding produces exact equality | ||
to without spec decode with different values of num_speculative_tokens. | ||
""" | ||
run_greedy_equality_correctness_test(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size, | ||
max_output_len=output_len, | ||
force_output_len=True) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"common_llm_kwargs", | ||
[{ | ||
# Skip cuda graph recording for fast test. | ||
"enforce_eager": True, | ||
# Required for spec decode. | ||
"use_v2_block_manager": True, | ||
# Precision | ||
"dtype": PRECISION, | ||
# Main model | ||
"model": MAIN_MODEL, | ||
}]) | ||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) | ||
@pytest.mark.parametrize("test_llm_kwargs", | ||
[{ | ||
"speculative_model": SPEC_MODEL, | ||
"num_speculative_tokens": MAX_SPEC_TOKENS, | ||
"speculative_disable_by_batch_size": 4 | ||
}]) | ||
@pytest.mark.parametrize("batch_size", [1, 5]) | ||
@pytest.mark.parametrize( | ||
"output_len", | ||
[ | ||
# Use smaller output len for fast test. | ||
32, | ||
]) | ||
@pytest.mark.parametrize("seed", [1]) | ||
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator, | ||
batch_size: int, output_len: int): | ||
"""Verify that eagle speculative decoding produces exact equality | ||
to without spec decode when speculation is disabled for large | ||
batch sizes. | ||
""" | ||
run_greedy_equality_correctness_test(baseline_llm_generator, | ||
test_llm_generator, | ||
batch_size, | ||
max_output_len=output_len, | ||
force_output_len=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
import pytest | ||
pytest.main([__file__]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.