From 7d02fc9835bee0aff70e50d808eecd54b43fefa6 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Sep 2024 18:53:47 +0000 Subject: [PATCH 1/8] Fix bug in encoder-decoder during beam search --- .../test_encoder_decoder_model_runner.py | 29 ++++++++++++++----- vllm/worker/enc_dec_model_runner.py | 11 +++---- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index a00d46ddeb007..d88d2ce373723 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -305,22 +305,24 @@ def test_prepare_decode(batch_size): seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} + block_tables = {0: [1], 1: [3]} cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) + #seq_lens.append(seq_len) + seq_lens.extend([seq_len, seq_len]) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) + #encoder_seq_lens.append(encoder_seq_len) + encoder_seq_lens.extend([encoder_seq_len, encoder_seq_len]) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data}, + seq_data={0: seq_data, 1: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, @@ -398,19 +400,32 @@ def test_prepare_decode(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention + flattened_block_tables = [block_table for block_table in block_tables.values()] + + expected = torch.tensor( - [block_tables[0] for _ in range(len(seq_group_metadata_list))], + flattened_block_tables * len(seq_group_metadata_list), dtype=torch.int32, - device=model_runner.device) + device=model_runner.device + ) + #expected = torch.tensor( + # [block_tables[0] for _ in range(len(seq_group_metadata_list))], + # dtype=torch.int32, + # device=model_runner.device + # ) + print('expected ' + str(expected)) + print('attn_metadata.block_tables ' + str(attn_metadata.block_tables)) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention expected = torch.tensor( - [cross_block_table for _ in range(len(seq_group_metadata_list))], + [cross_block_table for seq_group_metadata in seq_group_metadata_list for _ in range(len(seq_group_metadata.seq_data))], dtype=torch.int32, device=model_runner.device) + print('expected ' + str(expected)) + print('attn_metadata.cross_block_tables ' + str(attn_metadata.cross_block_tables)) assert torch.equal( attn_metadata.cross_block_tables, expected, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 09dab0135f390..20515b94eb9a6 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -442,11 +442,12 @@ def _prepare_encoder_model_input_tensors( # during vLLM memory profiling. 
cross_block_tables = [] for seq_group_metadata in seq_group_metadata_list: - encoder_seq_lens.append( - seq_group_metadata.encoder_seq_data.get_len()) - cross_block_table = seq_group_metadata.cross_block_table - cross_block_tables.append([] if ( - cross_block_table is None) else cross_block_table) + for _ in enumerate(seq_group_metadata.seq_data.items()): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) if (model_input.attn_metadata is not None and model_input.attn_metadata.use_cuda_graph): From 31247e07a4b6ac3ec32d51c0ac31193013f2817d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Sep 2024 20:22:20 +0000 Subject: [PATCH 2/8] Format tests --- .../test_encoder_decoder_model_runner.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index d88d2ce373723..fee56e4e06c7a 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -310,19 +310,18 @@ def test_prepare_decode(batch_size): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - #seq_lens.append(seq_len) - seq_lens.extend([seq_len, seq_len]) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - #encoder_seq_lens.append(encoder_seq_len) - encoder_seq_lens.extend([encoder_seq_len, encoder_seq_len]) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data, 1: seq_data}, + seq_data={ + 0: seq_data, + 1: seq_data + }, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, @@ -330,6 +329,10 @@ def test_prepare_decode(batch_size): ) assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) + seq_lens.extend( + [seq_len for _ in range(len(seq_group_metadata.seq_data))]) + encoder_seq_lens.extend( + [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) # Build # * Decoder model inputs @@ -400,32 +403,24 @@ def test_prepare_decode(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention - flattened_block_tables = [block_table for block_table in block_tables.values()] - - - expected = torch.tensor( - flattened_block_tables * len(seq_group_metadata_list), - dtype=torch.int32, - device=model_runner.device - ) - #expected = torch.tensor( - # [block_tables[0] for _ in range(len(seq_group_metadata_list))], - # dtype=torch.int32, - # device=model_runner.device - # ) - print('expected ' + str(expected)) - print('attn_metadata.block_tables ' + str(attn_metadata.block_tables)) + flattened_block_tables = [ + block_table for block_table in block_tables.values() + ] + expected = torch.tensor(flattened_block_tables * + len(seq_group_metadata_list), + dtype=torch.int32, + device=model_runner.device) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention - expected = torch.tensor( - [cross_block_table for seq_group_metadata in seq_group_metadata_list for _ in range(len(seq_group_metadata.seq_data))], - 
dtype=torch.int32, - device=model_runner.device) - print('expected ' + str(expected)) - print('attn_metadata.cross_block_tables ' + str(attn_metadata.cross_block_tables)) + expected = torch.tensor([ + cross_block_table for seq_group_metadata in seq_group_metadata_list + for _ in range(len(seq_group_metadata.seq_data)) + ], + dtype=torch.int32, + device=model_runner.device) assert torch.equal( attn_metadata.cross_block_tables, expected, @@ -510,7 +505,7 @@ def test_prepare_decode_cuda_graph(batch_size): seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} + block_tables = {0: [1], 1: [3]} cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block @@ -525,13 +520,20 @@ def test_prepare_decode_cuda_graph(batch_size): seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data}, + seq_data={ + 0: seq_data, + 1: seq_data + }, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, ) assert seq_group_metadata.token_chunk_size == 1 + seq_lens.extend( + [seq_len for _ in range(len(seq_group_metadata.seq_data))]) + encoder_seq_lens.extend( + [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) seq_group_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) From 6b9c5b6f0694b8fd1f7e05c5b084c8828126c9c3 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Sep 2024 00:06:58 +0000 Subject: [PATCH 3/8] Fixing tests --- .../test_encoder_decoder_model_runner.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index fee56e4e06c7a..643b0213156f3 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -510,11 +510,9 @@ def test_prepare_decode_cuda_graph(batch_size): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( @@ -549,8 +547,9 @@ def test_prepare_decode_cuda_graph(batch_size): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. - graph_batch_size = _get_graph_batch_size(batch_size) - cuda_graph_pad_size = graph_batch_size - batch_size + expanded_batch_size = batch_size * 2 + graph_batch_size = _get_graph_batch_size(expanded_batch_size) + cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( itertools.repeat(1, cuda_graph_pad_size)) @@ -579,10 +578,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention. Pad the block tables as expected. 
- expected = [block_tables[0] for _ in range(batch_size)] - expected.extend([[] for _ in range(cuda_graph_pad_size)]) + flattened_block_tables = [ + block_table for _ in range(len(seq_group_metadata_list)) + for block_table in block_tables.values() + ] + flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( - expected, + flattened_block_tables, max_len=64, pad=0, dtype=torch.int32, @@ -594,7 +596,10 @@ def test_prepare_decode_cuda_graph(batch_size): ) # - Encoder/decoder cross-attention. Pad the cross-attention block tables # as expected. - expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] + expected = [ + cross_block_table for seq_group_metadata in seq_group_metadata_list + for _ in range(len(seq_group_metadata.seq_data)) + ] expected.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( expected, From f9f3bee6ad8c94d2464114325ad01e469101504c Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Sep 2024 00:23:01 +0000 Subject: [PATCH 4/8] Fix logic --- vllm/worker/enc_dec_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 20515b94eb9a6..060b85b9c80f0 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -442,7 +442,7 @@ def _prepare_encoder_model_input_tensors( # during vLLM memory profiling. cross_block_tables = [] for seq_group_metadata in seq_group_metadata_list: - for _ in enumerate(seq_group_metadata.seq_data.items()): + for _ in range(len(seq_group_metadata.seq_data.items())): encoder_seq_lens.append( seq_group_metadata.encoder_seq_data.get_len()) cross_block_table = seq_group_metadata.cross_block_table From 2b978ffad1ef7dcd79c5bdefc5dca5b0ed9d9eba Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Sep 2024 01:03:57 +0000 Subject: [PATCH 5/8] Fix tests --- .../test_encoder_decoder_model_runner.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 643b0213156f3..f660890cca686 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size): "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -def test_prepare_decode(batch_size): +@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) +def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. @@ -288,6 +289,7 @@ def test_prepare_decode(batch_size): Arguments: * batch_size + * multiple_seqs_per_seq_group * backend_name: The attention backend under test * enforce_eager: Enforce eager mode if True (i.e. 
no CUDAGraph) ''' @@ -305,7 +307,12 @@ def test_prepare_decode(batch_size): seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1], 1: [3]} + block_tables = { + 0: [1], + 1: [3] + } if multiple_seqs_per_seq_group else { + 0: [1] + } cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block @@ -315,13 +322,14 @@ def test_prepare_decode(batch_size): encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, seq_data={ 0: seq_data, 1: seq_data - }, + } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, @@ -486,7 +494,8 @@ def test_prepare_decode(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) +def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): """ Tests that for encoder-decoder models with CUDA Graph capture and replay enabled, the tensors used during the decode phase are correctly padded @@ -501,12 +510,18 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, enforce_eager=False, ) - + block_tables = { + 0: [1], + 1: [3] + } if multiple_seqs_per_seq_group else { + 0: [1] + } seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1], 1: [3]} + cross_block_table = [2] + expanded_batch_size = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 @@ -521,7 +536,7 @@ def test_prepare_decode_cuda_graph(batch_size): seq_data={ 0: seq_data, 1: seq_data - }, + } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, @@ -532,6 +547,7 @@ def test_prepare_decode_cuda_graph(batch_size): [seq_len for _ in range(len(seq_group_metadata.seq_data))]) encoder_seq_lens.extend( [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) + expanded_batch_size = expanded_batch_size + len(seq_group_metadata.seq_data) seq_group_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) @@ -547,7 +563,6 @@ def test_prepare_decode_cuda_graph(batch_size): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. 
- expanded_batch_size = batch_size * 2 graph_batch_size = _get_graph_batch_size(expanded_batch_size) cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) From 1ad16513bff45c56ac6d7e54735298430ce440d0 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Sep 2024 01:11:03 +0000 Subject: [PATCH 6/8] Format --- tests/worker/test_encoder_decoder_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index f660890cca686..e5c0ff21934f4 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -519,7 +519,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - + cross_block_table = [2] expanded_batch_size = 0 for i in range(batch_size): @@ -547,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): [seq_len for _ in range(len(seq_group_metadata.seq_data))]) encoder_seq_lens.extend( [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - expanded_batch_size = expanded_batch_size + len(seq_group_metadata.seq_data) + expanded_batch_size = expanded_batch_size + len( + seq_group_metadata.seq_data) seq_group_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) From 437b4cbf70b940832b3f07fb564d1cfc0f744ad3 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Sep 2024 03:16:16 +0000 Subject: [PATCH 7/8] format --- vllm/worker/enc_dec_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 060b85b9c80f0..066b0e500653c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -442,7 +442,7 @@ def _prepare_encoder_model_input_tensors( # during vLLM memory profiling. cross_block_tables = [] for seq_group_metadata in seq_group_metadata_list: - for _ in range(len(seq_group_metadata.seq_data.items())): + for _ in range(len(seq_group_metadata.seq_data)): encoder_seq_lens.append( seq_group_metadata.encoder_seq_data.get_len()) cross_block_table = seq_group_metadata.cross_block_table From 13ca3aca6d5b911c45f7a1deacf87f2b8c3f7713 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Sep 2024 17:32:57 +0000 Subject: [PATCH 8/8] Dummy --- vllm/worker/enc_dec_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 066b0e500653c..709efdc8b9d57 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -435,7 +435,6 @@ def _prepare_encoder_model_input_tensors( encoder_input_tokens_tensor = self._empty_long_tensor() encoder_input_positions_tensor = self._empty_long_tensor() cross_slot_mapping_tensor = self._empty_long_tensor() - # Extract cross-attention block tables & # seq len from each sequence group metadata. # Cross-attention block tables are empty
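
Taken together, patches 1, 4, and 7 converge on one rule: encoder sequence lengths and
cross-attention block tables must be replicated once per decoder *sequence*, not once per
sequence *group*, because beam search places several live sequences in a single group.
Below is a minimal, self-contained sketch of that replication logic. It is illustrative
only: FakeSeqGroupMetadata is a hypothetical stand-in for vllm.sequence.SequenceGroupMetadata
(reduced to the fields the sketch needs), and the function approximates the loop inside
_prepare_encoder_model_input_tensors rather than reproducing the actual implementation.

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class FakeSeqGroupMetadata:
        # Hypothetical stand-in for SequenceGroupMetadata; not a real vLLM type.
        seq_data: Dict[int, object]            # one entry per decoder sequence
        encoder_seq_len: int                   # stands in for encoder_seq_data.get_len()
        cross_block_table: Optional[List[int]]

    def replicate_encoder_metadata(groups: List[FakeSeqGroupMetadata]):
        """Replicate per-group encoder metadata once per decoder sequence.

        Before the fix, one entry was appended per group; with beam search a
        group holds several sequences, so the encoder-side metadata ended up
        shorter than the decoder-side metadata it must line up with.
        """
        encoder_seq_lens: List[int] = []
        cross_block_tables: List[List[int]] = []
        for group in groups:
            for _ in range(len(group.seq_data)):   # once per sequence in the group
                encoder_seq_lens.append(group.encoder_seq_len)
                cross_block_tables.append(group.cross_block_table or [])
        return encoder_seq_lens, cross_block_tables

    # A beam-search group with two live sequences contributes two entries each:
    lens, tables = replicate_encoder_metadata(
        [FakeSeqGroupMetadata(seq_data={0: None, 1: None},
                              encoder_seq_len=5,
                              cross_block_table=[2])])
    assert lens == [5, 5] and tables == [[2], [2]]

Note that patch 4's change from `for _ in enumerate(...)` to `for _ in range(len(...))`
is cosmetic rather than behavioral (both iterate once per sequence); the substantive fix
is the inner per-sequence loop introduced in patch 1, which the tests in patches 2-6
exercise by parametrizing over groups with one and with two sequences.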