
[Bugfix] [Encoder-Decoder] Bugfix for encoder-specific metadata construction during decode of encoder-decoder models. #8545

Merged: 46 commits merged on Sep 19, 2024
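
This PR fixes how encoder-specific attention metadata is built during the decode phase. Previously, _prepare_encoder_model_input_tensors appended one encoder sequence length and one cross-attention block table per sequence group, while the decoder-side metadata is built per sequence; when a group holds more than one sequence (for example during beam search), the encoder metadata no longer lined up with the decoding sequences. The fix repeats the encoder metadata once per sequence in each group, and the model-runner tests gain a multiple_seqs_per_seq_group parametrization to cover that case.

The snippet below is a hypothetical reproduction sketch, not code from this PR: the model name, the use_beam_search flag, and the exact SamplingParams are assumptions about the vLLM API of this era.

# Hypothetical reproduction: decode an encoder-decoder model with beam
# search so that a single sequence group holds several sequences.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn")  # assumed encoder-decoder model

params = SamplingParams(
    n=4,                   # several sequences per sequence group
    use_beam_search=True,  # beam search, as referenced in the commit log
    temperature=0.0,
    max_tokens=32,
)

# Before this fix, encoder_seq_lens and the cross-attention block tables
# were built once per sequence group, so their lengths stopped matching
# the number of decoding sequences once a group held n > 1 sequences.
outputs = llm.generate(["A long article to summarize ..."], params)
for request_output in outputs:
    for completion in request_output.outputs:
        print(completion.text)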
Changes from all commits
Commits (46)
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
9ba12f8
Merge branch 'vllm-project:main' into main
sroy745 Sep 2, 2024
25cef3d
Merge branch 'vllm-project:main' into main
sroy745 Sep 3, 2024
9d4cd09
Merge branch 'vllm-project:main' into main
sroy745 Sep 4, 2024
c48cacb
Merge branch 'vllm-project:main' into main
sroy745 Sep 5, 2024
c42c399
Merge branch 'vllm-project:main' into main
sroy745 Sep 7, 2024
3d13e43
Merge branch 'vllm-project:main' into main
sroy745 Sep 9, 2024
7479775
Merge branch 'vllm-project:main' into main
sroy745 Sep 11, 2024
df9b966
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
9a7ed92
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
7d02fc9
Fix bug in encoder-decoder during beam search
sroy745 Sep 17, 2024
31247e0
Format tests
sroy745 Sep 17, 2024
6b9c5b6
Fixing tests
sroy745 Sep 18, 2024
f9f3bee
Fix logic
sroy745 Sep 18, 2024
2b978ff
Fix tests
sroy745 Sep 18, 2024
1ad1651
Format
sroy745 Sep 18, 2024
437b4cb
format
sroy745 Sep 18, 2024
13ca3ac
Dummy
sroy745 Sep 18, 2024
88 changes: 63 additions & 25 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_decode(batch_size):
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
Arguments:

* batch_size
* multiple_seqs_per_seq_group
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
@@ -305,29 +307,40 @@ def test_prepare_decode(batch_size):
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
block_tables = {
0: [1],
1: [3]
} if multiple_seqs_per_seq_group else {
0: [1]
}
cross_block_table = [2]
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))

seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
seq_data={
0: seq_data,
1: seq_data
} if multiple_seqs_per_seq_group else {0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
seq_lens.extend(
[seq_len for _ in range(len(seq_group_metadata.seq_data))])
encoder_seq_lens.extend(
[encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])

# Build
# * Decoder model inputs
@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):

# Verify block tables are correct for prompts
# - Decoder self-attention
expected = torch.tensor(
[block_tables[0] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
flattened_block_tables = [
block_table for block_table in block_tables.values()
]
expected = torch.tensor(flattened_block_tables *
len(seq_group_metadata_list),
dtype=torch.int32,
device=model_runner.device)
assert torch.equal(
attn_metadata.block_tables,
expected,
)
# - Encoder/decoder cross-attention
expected = torch.tensor(
[cross_block_table for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
expected = torch.tensor([
cross_block_table for seq_group_metadata in seq_group_metadata_list
for _ in range(len(seq_group_metadata.seq_data))
],
dtype=torch.int32,
device=model_runner.device)
assert torch.equal(
attn_metadata.cross_block_tables,
expected,
@@ -476,7 +494,8 @@ def test_prepare_decode(batch_size):


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
"""
Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded
@@ -491,32 +510,45 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
enforce_eager=False,
)

block_tables = {
0: [1],
1: [3]
} if multiple_seqs_per_seq_group else {
0: [1]
}
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}

cross_block_table = [2]
expanded_batch_size = 0
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
seq_data={
0: seq_data,
1: seq_data
} if multiple_seqs_per_seq_group else {0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == 1
seq_lens.extend(
[seq_len for _ in range(len(seq_group_metadata.seq_data))])
encoder_seq_lens.extend(
[encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
expanded_batch_size = expanded_batch_size + len(
seq_group_metadata.seq_data)
seq_group_metadata_list.append(seq_group_metadata)

model_input = model_runner.prepare_model_input(seq_group_metadata_list)
@@ -532,8 +564,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size = _get_graph_batch_size(batch_size)
cuda_graph_pad_size = graph_batch_size - batch_size
graph_batch_size = _get_graph_batch_size(expanded_batch_size)
cuda_graph_pad_size = graph_batch_size - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list(
itertools.repeat(1, cuda_graph_pad_size))
@@ -562,10 +594,13 @@ def test_prepare_decode_cuda_graph(batch_size):

# Verify block tables are correct for prompts
# - Decoder self-attention. Pad the block tables as expected.
expected = [block_tables[0] for _ in range(batch_size)]
expected.extend([[] for _ in range(cuda_graph_pad_size)])
flattened_block_tables = [
block_table for _ in range(len(seq_group_metadata_list))
for block_table in block_tables.values()
]
flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
expected = make_tensor_with_pad(
expected,
flattened_block_tables,
max_len=64,
pad=0,
dtype=torch.int32,
@@ -577,7 +612,10 @@ def test_prepare_decode_cuda_graph(batch_size):
)
# - Encoder/decoder cross-attention. Pad the cross-attention block tables
# as expected.
expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
expected = [
cross_block_table for seq_group_metadata in seq_group_metadata_list
for _ in range(len(seq_group_metadata.seq_data))
]
expected.extend([[] for _ in range(cuda_graph_pad_size)])
expected = make_tensor_with_pad(
expected,
Expand Down
12 changes: 6 additions & 6 deletions vllm/worker/enc_dec_model_runner.py
@@ -435,18 +435,18 @@ def _prepare_encoder_model_input_tensors(
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()

# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)

if (model_input.attn_metadata is not None
and model_input.attn_metadata.use_cuda_graph):
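The runner-side change above repeats the encoder sequence length and the cross-attention block table once for every sequence in a group. A minimal standalone sketch of that corrected loop, using simplified stand-in types rather than the actual vLLM classes:

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class FakeSeqGroupMetadata:
    # Simplified stand-in for vLLM's SequenceGroupMetadata.
    seq_data: Dict[int, object]             # one entry per decoding sequence
    encoder_seq_len: int                    # stands in for encoder_seq_data.get_len()
    cross_block_table: Optional[List[int]]  # None during memory profiling

def expand_encoder_metadata(
    seq_groups: List[FakeSeqGroupMetadata],
) -> Tuple[List[int], List[List[int]]]:
    # Build encoder sequence lengths and cross-attention block tables per
    # sequence rather than per sequence group, mirroring the corrected loop.
    encoder_seq_lens: List[int] = []
    cross_block_tables: List[List[int]] = []
    for group in seq_groups:
        for _ in range(len(group.seq_data)):
            encoder_seq_lens.append(group.encoder_seq_len)
            cross_block_tables.append(
                [] if group.cross_block_table is None else group.cross_block_table)
    return encoder_seq_lens, cross_block_tables

# One group holding two sequences and one group holding a single sequence.
groups = [
    FakeSeqGroupMetadata({0: None, 1: None}, encoder_seq_len=5, cross_block_table=[2]),
    FakeSeqGroupMetadata({2: None}, encoder_seq_len=7, cross_block_table=[4]),
]
print(expand_encoder_metadata(groups))
# ([5, 5, 7], [[2], [2], [4]])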