[Core] CUDA Graphs for Multi-Step + Chunked-Prefill #8645

Merged
Changes from 8 commits
11 changes: 11 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}

int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x >= num_query_blocks) {
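
The new branch above fills the trailing num_seqs - num_queries entries of a CUDA-graph-padded decode batch with benign values. As a reading aid only, here is a host-side Python sketch of the same logic (an editorial illustration, not vLLM code; the helper name pad_decode_inputs is hypothetical, and the -1 sentinel mirrors the PAD_SLOT_ID used elsewhere in this PR):

```python
from typing import List

PAD_SLOT_ID = -1  # assumed to match the -1 the kernel writes into slot_mapping

def pad_decode_inputs(input_tokens: List[int],
                      input_positions: List[int],
                      slot_mapping: List[int],
                      num_queries: int,
                      num_seqs: int) -> None:
    """Give the padded entries [num_queries, num_seqs) harmless values."""
    for i in range(num_queries, num_seqs):
        input_tokens[i] = 0            # dummy token id
        input_positions[i] = 0         # dummy position
        slot_mapping[i] = PAD_SLOT_ID  # cache-write kernels skip this slot
```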
47 changes: 27 additions & 20 deletions vllm/attention/backends/flash_attn.py
@@ -494,6 +494,29 @@ def _add_seq_group(
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)

def _use_graph_block_tables(self, num_seqs: int,
Contributor Author
refactor

Collaborator
Suggest naming: _get_block_table_with_cuda_graph

Contributor Author
I gave it some thought. _get_block_table_with_cuda_graph seems to suggest that CUDA graphs are a property of the block tables, while we simply copy the Python block tables into the block-table tensor in self.runner.graph_block_tables.

IMHO _use_graph_block_tables captures the intent, given that self.runner.graph_block_tables is the tensor being filled.

Perhaps _prepare_graph_block_tables / _get_graph_runner_block_tables is a better alternative.

What do you think? Happy to make the change if you think otherwise.

Collaborator
_get_graph_runner_block_tables looks better to me.

Contributor Author
Changed it to _get_graph_runner_block_tables. Thanks.

block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs

graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due to the
# lookahead slots of multi-step; however, they are not used
# anyway, so they can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]

return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
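
As a reading aid, here is a toy NumPy sketch of what this helper does before the copy to the GPU (an editorial illustration with made-up shapes; graph_block_tables stands in for self.runner.graph_block_tables):

```python
import numpy as np

graph_block_tables = np.zeros((8, 4), dtype=np.int32)  # [max batch size, max blocks]
block_tables = [[3, 7], [], [5, 9, 11, 2, 6]]           # per-sequence physical block ids

num_seqs = len(block_tables)
view = graph_block_tables[:num_seqs]
max_blocks = view.shape[1]
for i, bt in enumerate(block_tables):
    if bt:
        # Extra blocks (e.g. lookahead slots of multi-step) are dropped.
        n = min(len(bt), max_blocks)
        view[i, :n] = bt[:n]
# The helper then moves this buffer to the GPU with
# torch.from_numpy(...).to(device, non_blocking=True).
```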

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
@@ -522,29 +545,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]

block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._use_graph_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
50 changes: 38 additions & 12 deletions vllm/worker/model_runner.py
@@ -729,14 +729,42 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):

def _use_captured_graph(self,
batch_size: int,
decode_only: bool,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
return (decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
and batch_size <= self.runner.max_batchsize_to_capture)

def _cuda_graph_pad_size(self,
batch_size: int,
num_seqs: int,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> int:
is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
self.runner.scheduler_config.chunked_prefill_enabled
# The input batch_size is the number of input tokens, which includes
# both the prefill and decode tokens. Generally, when the batch has
# prefills, we don't use CUDA graphs, i.e. _use_captured_graph() will
# be False.
# However, in the multi-step + chunked-prefill case, only the first
# step has prefills (if any). The rest of the steps are guaranteed to
# be all decodes. In this case, we set up the padding as if all the
# sequences are decodes so we may run all steps except the first step
# in CUDA graph mode.
batch_size = num_seqs if is_mscp else batch_size
Collaborator
Is it valid if we just do batch_size = num_seqs all the time? It seems to me that if you already have num_seqs, it's not necessary to take batch_size in this function.

Contributor Author
It is valid in the case when all the input sequences are decodes. batch_size in the rest of the code refers to the number of input tokens being batched, and I did not want to conflate the two ideas.

However, I have made some changes to this function that simplify the interface a little and make this conflation legitimate. Please take a look. Appreciate any feedback on this! Thanks.

decode_only = self.decode_only or is_mscp
if not self._use_captured_graph(batch_size, decode_only,
max_decode_seq_len,
max_encoder_seq_len):
return -1

graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
return graph_batch_size - batch_size
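
To make the padding arithmetic concrete, here is a toy sketch (editorial only; the capture sizes below are assumptions for illustration, while the real values come from _BATCH_SIZES_TO_CAPTURE and _get_graph_batch_size):

```python
BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8, 16, 32]  # assumed, for illustration

def graph_batch_size(batch_size: int) -> int:
    # Smallest captured batch size that can hold the batch.
    return next(s for s in BATCH_SIZES_TO_CAPTURE if s >= batch_size)

def pad_size(num_seqs: int) -> int:
    g = graph_batch_size(num_seqs)
    assert g >= num_seqs
    return g - num_seqs

# Under multi-step + chunked-prefill, a batch of 5 sequences (all decodes
# after the first step) is padded up to the captured size 8, i.e. a pad of 3.
assert pad_size(5) == 3
```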

def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
@@ -796,20 +824,18 @@ def build(self) -> ModelInputForGPU:
}

batch_size = len(input_tokens)
use_captured_graph = self._use_captured_graph(

cuda_graph_pad_size = self._cuda_graph_pad_size(
batch_size,
max_decode_seq_len,
num_seqs=len(seq_lens),
max_decode_seq_len=max_decode_seq_len,
max_encoder_seq_len=max_encoder_seq_len)

# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size = -1
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
batch_size = graph_batch_size
if cuda_graph_pad_size != -1:
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
batch_size += cuda_graph_pad_size

# Tokens and positions.
if cuda_graph_pad_size:
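
For context, the feature this PR enables would be exercised roughly as follows (an editorial sketch; the flag names num_scheduler_steps, enable_chunked_prefill and enforce_eager are assumed from vLLM's engine arguments at the time and may differ):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    num_scheduler_steps=8,        # multi-step scheduling
    enable_chunked_prefill=True,  # chunked prefill
    enforce_eager=False,          # allow CUDA graph capture/replay for decode steps
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```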