[Core] CUDA Graphs for Multi-Step + Chunked-Prefill #8645
Changes from 8 commits
c0769f4
c0ab4dc
e01088e
48beebc
a128e88
b70a0d5
3b49f43
5b05521
b9f76ab
1665e7e
7494f49
```diff
@@ -729,14 +729,42 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):

     def _use_captured_graph(self,
                             batch_size: int,
+                            decode_only: bool,
                             max_decode_seq_len: int,
                             max_encoder_seq_len: int = 0) -> bool:
-        return (self.decode_only and not self.runner.model_config.enforce_eager
+        return (decode_only and not self.runner.model_config.enforce_eager
                 and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)

+    def _cuda_graph_pad_size(self,
+                             batch_size: int,
+                             num_seqs: int,
+                             max_decode_seq_len: int,
+                             max_encoder_seq_len: int = 0) -> int:
+        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
+            self.runner.scheduler_config.chunked_prefill_enabled
+        # The input batch_size is the number of input tokens, which includes
+        # both the prefill and decode tokens. Generally, when the batch has
+        # prefills, we don't use CUDA graphs, i.e. _use_captured_graph() will
+        # be False.
+        # However, in the multi-step + chunked-prefill case, only the first
+        # step has prefills (if any). The rest of the steps are guaranteed to
+        # be all decodes. In this case, we set up the padding as if all the
+        # sequences are decodes so we may run all steps except the first step
+        # in CUDA graph mode.
+        batch_size = num_seqs if is_mscp else batch_size
+        decode_only = self.decode_only or is_mscp
+        if not self._use_captured_graph(batch_size, decode_only,
+                                        max_decode_seq_len,
+                                        max_encoder_seq_len):
+            return -1
+
+        graph_batch_size = _get_graph_batch_size(batch_size)
+        assert graph_batch_size >= batch_size
+        return graph_batch_size - batch_size

     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
```

Review discussion (marked resolved) on `batch_size = num_seqs if is_mscp else batch_size`:

**Reviewer:** Is it valid if we just …

**varun-sundar-rabindranath:** It is valid in the case when all the input sequences are decodes. `batch_size` in the rest of the code refers to the number of input tokens being batched, and I did not want to conflate the two ideas. However, I have made some changes to this function that simplify the interface a little and make this conflation legitimate. Please take a look. Appreciate any feedback on this! Thanks.
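For readers new to this code path, here is a minimal, self-contained sketch of the padding arithmetic that `_cuda_graph_pad_size` performs. The capture-size list and helper names below are assumptions made for the example (the real `_BATCH_SIZES_TO_CAPTURE` and `_get_graph_batch_size` live in the model runner and may differ); the point is only that the effective batch size is rounded up to the nearest captured size and the difference is returned as padding.

```python
# Illustrative capture sizes only -- the real _BATCH_SIZES_TO_CAPTURE is
# defined in vLLM's model runner and may differ.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


def _get_graph_batch_size_sketch(batch_size: int) -> int:
    """Round batch_size up to the nearest captured CUDA-graph batch size."""
    for captured in _BATCH_SIZES_TO_CAPTURE:
        if captured >= batch_size:
            return captured
    raise ValueError(f"batch_size {batch_size} exceeds the largest captured size")


def _cuda_graph_pad_size_sketch(num_seqs: int) -> int:
    """Padding needed so an all-decode step of num_seqs sequences can replay
    a captured graph (mirrors the pad-size contract above)."""
    graph_batch_size = _get_graph_batch_size_sketch(num_seqs)
    return graph_batch_size - num_seqs


# Multi-step + chunked-prefill example: step 1 might batch 2 prefills of
# 100 tokens each plus 3 decodes (203 input tokens), but steps 2..N are all
# decodes -- one token per sequence -- so padding is sized from the number
# of sequences (5), not from the token count.
assert _cuda_graph_pad_size_sketch(5) == 3    # 5 sequences -> captured size 8
assert _cuda_graph_pad_size_sketch(13) == 3   # 13 sequences -> captured size 16
```

This is also why the multi-step + chunked-prefill case keys the padding off `num_seqs` rather than the token count: from the second step onward every sequence contributes exactly one decode token, so the number of sequences is the batch size the captured graph will see.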
```diff
@@ -796,20 +824,18 @@ def build(self) -> ModelInputForGPU:
         }

         batch_size = len(input_tokens)
-        use_captured_graph = self._use_captured_graph(
-            batch_size,
-            max_decode_seq_len,
-            max_encoder_seq_len=max_encoder_seq_len)
-
-        # If cuda graph can be used, pad tensors accordingly.
-        # See `capture_model` API for more details.
-        # vLLM uses cuda graph only for decoding requests.
-        cuda_graph_pad_size = -1
-        if use_captured_graph:
-            graph_batch_size = _get_graph_batch_size(batch_size)
-            assert graph_batch_size >= batch_size
-            cuda_graph_pad_size = graph_batch_size - batch_size
-            batch_size = graph_batch_size
+        cuda_graph_pad_size = self._cuda_graph_pad_size(
+            batch_size,
+            num_seqs=len(seq_lens),
+            max_decode_seq_len=max_decode_seq_len,
+            max_encoder_seq_len=max_encoder_seq_len)
+        if cuda_graph_pad_size != -1:
+            # If cuda graph can be used, pad tensors accordingly.
+            # See `capture_model` API for more details.
+            # vLLM uses cuda graph only for decoding requests.
+            batch_size += cuda_graph_pad_size

         # Tokens and positions.
         if cuda_graph_pad_size:
```
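The `# Tokens and positions.` block that begins at the end of this hunk extends the flattened per-token buffers up to the captured batch size. Below is a simplified sketch of that step, not the PR's exact code; using zero as the dummy token and position value is an assumption made for illustration.

```python
import itertools


def _pad_for_cuda_graph_sketch(input_tokens: list[int],
                               input_positions: list[int],
                               cuda_graph_pad_size: int) -> None:
    """Extend the flattened per-token buffers so their length matches the
    captured graph's batch size; the padded slots are ignored at replay."""
    if cuda_graph_pad_size > 0:
        input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
        input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))


# Example: 5 live decode tokens padded up to a captured batch size of 8.
tokens = [11, 22, 33, 44, 55]
positions = [7, 3, 9, 1, 4]
_pad_for_cuda_graph_sketch(tokens, positions, cuda_graph_pad_size=3)
assert len(tokens) == len(positions) == 8
```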
Review discussion (marked resolved) on the block-table refactor:

**Reviewer:** refactor

**Reviewer:** Suggest naming: `_get_block_table_with_cuda_graph`

**varun-sundar-rabindranath:** I gave it some thought. `_get_block_table_with_cuda_graph` seems to suggest that CUDA graphs are a property of the block tables, while we simply copy the Python block tables into the block-table tensor in `self.runner.graph_block_tables`. IMHO `_use_graph_block_tables` captures the intent, given that `self.runner.graph_block_tables` is the tensor being filled. Perhaps `_prepare_graph_block_tables` / `_get_graph_runner_block_tables` is a better alternative. What do you think? Happy to make the change if you think otherwise.

**Reviewer:** `_get_graph_runner_block_tables` looks better to me.

**varun-sundar-rabindranath:** Changed it to `_get_graph_runner_block_tables`. Thanks.
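For context on what the renamed helper does, the sketch below illustrates the behaviour described in the thread: copying the ragged per-sequence Python block tables into the runner's preallocated, fixed-shape `graph_block_tables` array so a captured graph can consume them. The signature, shapes, and dtype here are assumptions for illustration, not the PR's exact code.

```python
import numpy as np


def _get_graph_runner_block_tables_sketch(
        graph_block_tables: np.ndarray,        # preallocated [max_batch, max_blocks]
        block_tables: list[list[int]]) -> np.ndarray:
    """Copy ragged per-sequence block tables into the runner's fixed-shape
    array so a captured CUDA graph can read them; unused slots stay zero."""
    max_blocks = graph_block_tables.shape[1]
    for i, block_table in enumerate(block_tables):
        if block_table:
            num_blocks = min(len(block_table), max_blocks)
            graph_block_tables[i, :num_blocks] = block_table[:num_blocks]
    return graph_block_tables[:len(block_tables)]


# Example: two sequences with different numbers of allocated KV-cache blocks.
runner_tables = np.zeros((8, 4), dtype=np.int32)
out = _get_graph_runner_block_tables_sketch(runner_tables, [[3, 7], [5]])
assert out.shape == (2, 4) and out[0, 1] == 7 and out[1, 0] == 5
```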