
Commit

add cuda-graph changes
Varun Sundar Rabindranath committed Sep 19, 2024
1 parent 293a7a6 commit 0341276
Showing 7 changed files with 112 additions and 44 deletions.
11 changes: 11 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}

int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x >= num_query_blocks) {
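The added block zero-fills the CUDA-graph padding region: slots in [num_queries, num_seqs) get token id 0, position 0, and slot mapping -1, so a graph captured at a fixed batch size can be replayed with fewer live queries without writing stray KV-cache entries. A host-side Python sketch of the same logic follows; the function name and arguments are illustrative, not part of vLLM's API.

```python
# Host-side sketch of the padding the first thread block now performs on
# the GPU. Illustrative only; names are not part of the vLLM API.
def pad_advance_step_inputs(input_tokens, input_positions, slot_mapping,
                            num_queries, num_seqs):
    n_pad = num_seqs - num_queries       # CUDA-graph padding sequences
    for i in range(num_queries, num_queries + n_pad):
        input_tokens[i] = 0              # dummy token id
        input_positions[i] = 0           # dummy position
        slot_mapping[i] = -1             # no KV-cache write for this slot
```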
3 changes: 2 additions & 1 deletion vllm/attention/backends/abstract.py
@@ -204,7 +204,8 @@ def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:

@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
cuda_graph_pad_size: int, batch_size: int,
use_graph_block_tables: bool) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError

69 changes: 42 additions & 27 deletions vllm/attention/backends/flash_attn.py
@@ -488,8 +488,32 @@ def _add_seq_group(
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)

def _use_graph_block_tables(self, num_seqs: int,
block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs

graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]

return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
cuda_graph_pad_size: int, batch_size: int,
use_graph_block_tables: bool):
"""Build attention metadata with on-device tensors.
Args:
@@ -516,36 +540,27 @@ def build(self, seq_lens: List[int], query_lens: List[int],
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]

block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True)
num_decode_tokens = batch_size - self.num_prefill_tokens
assert use_graph_block_tables, \
("Must use graph block tables from the runner when using "
"the captured graph")
block_tables = self._use_graph_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
if use_graph_block_tables:
block_tables = self._use_graph_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

assert device is not None
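The new `_use_graph_block_tables` helper copies each sequence's ragged block table into the runner's preallocated `graph_block_tables` array, so the resulting tensor always has the fixed shape the CUDA graph was captured with; extra lookahead blocks beyond the array width are dropped because they are never read. A rough NumPy sketch of that copy, with illustrative names and shapes, is below.

```python
import numpy as np
from typing import List

# Rough sketch of the copy performed by _use_graph_block_tables.
# Shapes and names here are illustrative assumptions.
def fill_graph_block_tables(graph_block_tables: np.ndarray,
                            block_tables: List[List[int]]) -> np.ndarray:
    max_batch_size, max_blocks = graph_block_tables.shape
    assert max_batch_size >= len(block_tables)
    out = graph_block_tables[:len(block_tables)]
    for i, block_table in enumerate(block_tables):
        if not block_table:
            continue          # padded/empty sequence: leave the row as-is
        n = min(len(block_table), max_blocks)
        # Blocks beyond max_blocks (e.g. multi-step lookahead slots) are
        # never read, so truncating them is safe.
        out[i, :n] = block_table[:n]
    return out
```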
7 changes: 6 additions & 1 deletion vllm/attention/backends/flashinfer.py
@@ -580,7 +580,8 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
self.paged_kv_last_page_len.append(last_page_len)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
cuda_graph_pad_size: int, batch_size: int,
use_graph_block_tables: bool):
"""Build attention metadata with on-device tensors.
Args:
@@ -606,6 +607,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size

assert use_graph_block_tables, \
("Must use graph block tables from the runner when using "
"the captured graph")
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
@@ -630,6 +634,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size)
self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
else:
assert not use_graph_block_tables
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
5 changes: 4 additions & 1 deletion vllm/attention/backends/utils.py
@@ -190,7 +190,8 @@ def _add_seq_group(
self.block_size, inter_data.block_tables)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
cuda_graph_pad_size: int, batch_size: int,
use_graph_block_tables: bool):
"""Build attention metadata with on-device tensors.
Args:
@@ -217,6 +218,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size

assert use_graph_block_tables
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
@@ -226,6 +228,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
else:
assert not use_graph_block_tables
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
55 changes: 42 additions & 13 deletions vllm/worker/model_runner.py
@@ -729,14 +729,37 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):

def _use_captured_graph(self,
batch_size: int,
decode_only: bool,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
return (decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
and batch_size <= self.runner.max_batchsize_to_capture)

def _cuda_graph_pad_size(self,
batch_size: int,
num_seqs: int,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> int:
is_mscp: bool = \
self.runner.scheduler_config.is_multi_step_chunked_prefill
# In multi-step chunked-prefill, starting from the second step
# all the sequences are guaranteed to be decodes. So, we may
# run the first-step in eager mode and the rest of the steps
# in graph mode.
batch_size = batch_size if not is_mscp else num_seqs
decode_only = self.decode_only or is_mscp
if not self._use_captured_graph(batch_size, decode_only,
max_decode_seq_len,
max_encoder_seq_len):
return -1

graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
return graph_batch_size - batch_size

def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
@@ -796,20 +819,18 @@ def build(self) -> ModelInputForGPU:
}

batch_size = len(input_tokens)
use_captured_graph = self._use_captured_graph(

cuda_graph_pad_size = self._cuda_graph_pad_size(
batch_size,
max_decode_seq_len,
num_seqs=len(seq_lens),
max_decode_seq_len=max_decode_seq_len,
max_encoder_seq_len=max_encoder_seq_len)

# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size = -1
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
batch_size = graph_batch_size
if cuda_graph_pad_size != -1:
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
batch_size += cuda_graph_pad_size

# Tokens and positions.
if cuda_graph_pad_size:
@@ -837,8 +858,16 @@ def build(self) -> ModelInputForGPU:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

# Attention metadata.
# TODO (varun) : Handle flashinfer unsupported
use_graph_block_tables = cuda_graph_pad_size != -1 or \
(self.scheduler_config.is_multi_step_chunked_prefill and \
len(seq_lens) in _BATCH_SIZES_TO_CAPTURE)
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
seq_lens,
query_lens,
cuda_graph_pad_size,
batch_size,
use_graph_block_tables=use_graph_block_tables)

# LoRA data.
lora_requests = set()
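`_cuda_graph_pad_size` folds the old `use_captured_graph` check and the padding arithmetic into one helper: it returns -1 when the batch should run eagerly, otherwise the number of dummy sequences needed to round the batch up to the nearest captured graph size. A self-contained sketch of that rounding follows; the capture sizes are an assumption for illustration (vLLM derives them from `_BATCH_SIZES_TO_CAPTURE`).

```python
# Sketch of the padding arithmetic in _cuda_graph_pad_size.
# The capture list below is an assumption for illustration only.
_CAPTURE_SIZES = [1, 2, 4] + [8 * i for i in range(1, 33)]  # 1, 2, 4, 8, ..., 256

def graph_pad_size(batch_size: int, can_use_graph: bool) -> int:
    """Return -1 for eager mode, else the number of padding sequences
    needed to reach the next captured graph batch size."""
    if not can_use_graph or batch_size > _CAPTURE_SIZES[-1]:
        return -1
    graph_batch_size = next(s for s in _CAPTURE_SIZES if s >= batch_size)
    return graph_batch_size - batch_size

# Example: a 13-sequence decode batch is padded with 3 dummy sequences so it
# can replay the graph captured at batch size 16.
assert graph_pad_size(13, can_use_graph=True) == 3
```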
6 changes: 5 additions & 1 deletion vllm/worker/multi_step_model_runner.py
@@ -16,7 +16,8 @@
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import PyObjectCache
from vllm.worker.model_runner import (GPUModelRunnerBase,
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
@@ -592,6 +593,9 @@ def _advance_step(self, model_input: StatefulModelInput,
counts_update=model_input.maybe_get_counts_update(),
)

if model_input.num_seqs in _BATCH_SIZES_TO_CAPTURE:
attn_metadata.use_cuda_graph = True

return model_input

def load_model(self) -> None:
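The multi-step runner change marks the advanced input as CUDA-graph-eligible when the sequence count exactly matches a captured batch size: the first step of a multi-step chunked-prefill batch may mix prefills and decodes and run eagerly, while later steps are pure decodes and can replay a captured graph. An illustrative sketch of that per-step decision, not vLLM API, is below.

```python
# Illustrative sketch (not vLLM API) of the per-step execution mode in
# multi-step chunked prefill: step 0 may contain prefills and runs eagerly,
# later steps are pure decodes and can replay a captured CUDA graph when
# the sequence count matches a captured batch size.
CAPTURED_SIZES = {1, 2, 4, 8, 16, 32, 64}   # assumed capture set

def plan_steps(num_steps: int, num_seqs: int, has_prefills: bool) -> list:
    plan = []
    for step in range(num_steps):
        mixed = has_prefills and step == 0
        use_graph = (not mixed) and num_seqs in CAPTURED_SIZES
        plan.append("cuda-graph" if use_graph else "eager")
    return plan

# e.g. 4 steps, 16 sequences, first step has chunked prefills:
# ['eager', 'cuda-graph', 'cuda-graph', 'cuda-graph']
print(plan_steps(4, 16, has_prefills=True))
```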
