diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 44b22c2bd8a21..01066ef796d67 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -1,20 +1,16 @@
-import random
+import pytest
 import torch
 
 from vllm.config import ModelConfig
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
+from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
-def get_aligned_size(batch_size: int, alignment: int):
-    return ((batch_size + alignment - 1) // alignment * alignment)
-
-
-def test_prepare_prompt():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_prompt(batch_size):
     model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     block_tables = {0: [1]}
@@ -111,7 +107,8 @@ def test_prepare_prompt():
     torch.testing.assert_close(actual, expected)
 
 
-def test_prepare_decode_cuda_graph():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
     model_config = ModelConfig(
         "facebook/opt-125m",
         "facebook/opt-125m",
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
     model_runner = ModelRunner(model_config, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
     input_tokens, input_positions, input_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
 
+    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert input_metadata.is_prompt is False
     assert input_metadata.prompt_lens is None
     assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
+    assert input_metadata.num_generation_tokens == expected_bs
     assert input_metadata.max_seq_len is None
     assert input_metadata.subquery_start_loc is None
     assert input_metadata.seq_start_loc is None
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"
 
-    assert input_tokens.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
-    assert input_positions.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
+    assert input_tokens.shape == (expected_bs, )
+    assert input_positions.shape == (expected_bs, )
    torch.testing.assert_close(input_tokens, input_positions)
 
     # Verify Sampling
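
For context, the padding rule the test now delegates to _get_graph_batch_size can be sketched from the helper this diff removes: decode batch sizes are rounded up to the size of the nearest CUDA-graph-captured batch. The snippet below is a minimal sketch, not vLLM's actual implementation; the rounding formula is taken from the removed get_aligned_size helper, and the alignment value of 8 is an assumption.

    _BATCH_SIZE_ALIGNMENT = 8  # assumed value of vLLM's internal constant

    def _get_graph_batch_size(batch_size: int) -> int:
        # Round batch_size up to the next multiple of the alignment,
        # using the same formula as the removed get_aligned_size helper.
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

Under this rule a decode batch of 100 sequences, for example, would be padded to 104, so the test would expect input_tokens.shape == (104, ).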