[CI/Build] fix flaky test (vllm-project#3602)
youkaichao authored Mar 25, 2024
1 parent 3455b28 commit ac8029a
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions tests/worker/test_model_runner.py
@@ -1,20 +1,16 @@
-import random
+import pytest
 import torch
 
 from vllm.config import ModelConfig
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
+from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
-def get_aligned_size(batch_size: int, alignment: int):
-    return ((batch_size + alignment - 1) // alignment * alignment)
-
-
-def test_prepare_prompt():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_prompt(batch_size):
     model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     block_tables = {0: [1]}
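
The heart of the fix is visible in this first hunk: a single random.randint draw per CI run is replaced by an exhaustive, deterministic sweep over every batch size from 1 to 256. Below is a minimal sketch of the two patterns, with illustrative function names that are not part of the diff:

```python
# Illustrative contrast between the old and new test-selection patterns.
import random

import pytest


def old_style():
    # Pre-commit pattern: one random size per run. A size that triggers
    # a bug shows up only occasionally, so failures look "flaky" and are
    # hard to reproduce locally.
    batch_size = random.randint(1, 256)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_new_style(batch_size):
    # Post-commit pattern: pytest collects 256 independent test cases,
    # one per size, and a failure names the exact value, e.g.
    # test_new_style[3].
    assert 1 <= batch_size <= 256
```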
@@ -111,7 +107,8 @@ def test_prepare_prompt():
     torch.testing.assert_close(actual, expected)
 
 
-def test_prepare_decode_cuda_graph():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
     model_config = ModelConfig(
         "facebook/opt-125m",
         "facebook/opt-125m",
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
     model_runner = ModelRunner(model_config, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
     input_tokens, input_positions, input_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
 
+    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert input_metadata.is_prompt is False
     assert input_metadata.prompt_lens is None
     assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
+    assert input_metadata.num_generation_tokens == expected_bs
     assert input_metadata.max_seq_len is None
     assert input_metadata.subquery_start_loc is None
     assert input_metadata.seq_start_loc is None
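
This hunk replaces the test's own rounding with the padding rule the runner actually uses. The removed get_aligned_size helper rounded every batch size up to a multiple of _BATCH_SIZE_ALIGNMENT, while _get_graph_batch_size keeps small batches at 1, 2, or 4. A sketch of the mismatch follows; the _get_graph_batch_size body is an assumption reconstructed from vLLM of this era, not quoted from the repository:

```python
# ASSUMPTION: this _get_graph_batch_size mirrors vLLM around this commit,
# where _BATCH_SIZE_ALIGNMENT == 8; treat it as a reconstruction.
_BATCH_SIZE_ALIGNMENT = 8


def get_aligned_size(batch_size: int, alignment: int) -> int:
    # The removed test helper: round up to a multiple of `alignment`.
    return (batch_size + alignment - 1) // alignment * alignment


def _get_graph_batch_size(batch_size: int) -> int:
    # CUDA graphs are captured for batch sizes 1, 2, 4, 8, 16, ...
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


for bs in (1, 2, 3, 4, 5, 8):
    print(bs, get_aligned_size(bs, _BATCH_SIZE_ALIGNMENT),
          _get_graph_batch_size(bs))
# 1 8 1   <- helper expects 8 where the runner pads to 1
# 2 8 2
# 3 8 4
# 4 8 4
# 5 8 8   <- the two rules agree from 5 upward
# 8 8 8
```

If that reconstruction is right, a random.randint(1, 256) draw of 4 or less (roughly a 1.6% chance per run) hit the disagreement, which would explain the intermittent failures named in the commit title; asserting against _get_graph_batch_size makes the expectation track the runner's real padding for every batch size.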
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"
 
-    assert input_tokens.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
-    assert input_positions.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
+    assert input_tokens.shape == (expected_bs, )
+    assert input_positions.shape == (expected_bs, )
     torch.testing.assert_close(input_tokens, input_positions)
 
     # Verify Sampling
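
A practical consequence: running `pytest tests/worker/test_model_runner.py` now executes each of the two tests 256 times, once per batch size, so any size-dependent regression fails on every run rather than intermittently, and the failing size appears directly in the test id (for example, test_prepare_decode_cuda_graph[3]).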