Fix CUDA graph compilation #627

Merged · 6 commits · Oct 2, 2024
@@ -555,4 +555,3 @@ def forward(
mean_hidden_states = hidden_states.mean(1)
embeddings = nn.functional.linear(mean_hidden_states, self.output_weight, self.output_bias)
return embeddings, None

35 changes: 21 additions & 14 deletions server/lorax_server/models/flash_causal_lm.py
@@ -26,8 +26,9 @@
from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
from lorax_server.utils.attention.utils import block_tables_to_ragged
from lorax_server.utils.dist import MEMORY_FRACTION
from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM
from lorax_server.utils.graph import GraphCache
from lorax_server.utils.import_utils import get_cuda_free_memory
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
from lorax_server.utils.sources import HUB
from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_speculative_tokens, warmup_mode
@@ -406,9 +407,18 @@ def from_pb(
),
prefill_cache_indices=prefill_cache_indices if SLIDING_WINDOW is not None else None,
)

@classmethod
def from_pb_embed(self, pb: generate_pb2.EmbedRequest, tokenizer: PreTrainedTokenizerBase, tokenizers: TokenizerManager, processor, config, dtype, device) -> "FlashCausalLMBatch":
def from_pb_embed(
self,
pb: generate_pb2.EmbedRequest,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
processor,
config,
dtype,
device,
) -> "FlashCausalLMBatch":
return self.from_pb(pb, tokenizer, tokenizers, None, None, dtype, device)

@tracer.start_as_current_span("filter")
@@ -885,6 +895,7 @@ def adapter_memory_size(self) -> int:
return ADAPTER_MEMORY_FRACTION * total_gpu_memory

def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model: bool = False):
# The warmup batch is the biggest batch we could ever receive
max_total_tokens = batch.max_seqlen + max_new_tokens + get_speculative_tokens()

torch.cuda.empty_cache()
@@ -954,21 +965,16 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

total_free_memory, _ = torch.cuda.mem_get_info(self.device)
total_free_memory -= graph_cache_memory

total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

free_memory = max(
0,
total_free_memory - (1 - MEMORY_FRACTION + ADAPTER_MEMORY_FRACTION) * total_gpu_memory,
)
free_memory = get_cuda_free_memory(self.device, MEMORY_FRACTION - ADAPTER_MEMORY_FRACTION)
free_memory -= graph_cache_memory
logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)

batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
int(free_memory // total_cache_size)
# Leave 5% for some wiggle room
int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch.num_blocks
+ batch_num_blocks
)

del batch
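
Note: to make the new budget concrete, here is a back-of-the-envelope sketch of the block calculation above. All shapes and the free-memory figure are illustrative assumptions, not LoRAX defaults.

```python
# Illustrative only: how the kv-cache block budget falls out of the numbers above.
BLOCK_SIZE = 16                                   # tokens per cache block (assumed)
num_layers, num_kv_heads, head_size = 32, 8, 128  # fp16 7B-class shape (assumed)
dtype_size = 2                                    # bytes per element for fp16

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # K and V planes

free_memory = 20_000_000_000  # bytes left after weights, graphs, adapters (assumed)
MEMORY_WIGGLE_ROOM = 0.95     # keep 5% headroom, matching the dist.py default
batch_num_blocks = 64         # blocks already held by the warmup batch (assumed)

num_blocks = int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size) + batch_num_blocks
print(num_blocks, num_blocks * BLOCK_SIZE)  # -> 9123 blocks, 145968 cacheable tokens
```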
Expand All @@ -987,6 +993,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
if self.model_graph_wrapper is not None:
# Warmup the graph cache. Needs to be done after setting cache manager as
# tracing will use the static kv cache tensors
self.model_graph_wrapper.kv_cache = self.kv_cache
self.model_graph_wrapper.warmup()
torch.cuda.synchronize(self.device)

2 changes: 2 additions & 0 deletions server/lorax_server/utils/dist.py
@@ -11,6 +11,8 @@
# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.95"))
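
Note: both values are parsed once at import time, so any override has to be in the environment before the server process imports this module. A minimal sketch of that ordering, with made-up values:

```python
# Sketch only: set the knobs before lorax_server.utils.dist is first imported.
import os

os.environ["CUDA_MEMORY_FRACTION"] = "0.9"  # cap LoRAX at 90% of GPU memory
os.environ["MEMORY_WIGGLE_ROOM"] = "0.97"   # keep 3% slack instead of the default 5%

from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM  # noqa: E402

assert (MEMORY_FRACTION, MEMORY_WIGGLE_ROOM) == (0.9, 0.97)
```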


class FakeBarrier:
def wait(self):
54 changes: 42 additions & 12 deletions server/lorax_server/utils/graph.py
@@ -1,6 +1,7 @@
# CUDA Graph implementation modified from vLLM:
# https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py

import os
from dataclasses import dataclass
from functools import lru_cache
from statistics import median
@@ -25,21 +26,26 @@
from lorax_server.models.model import Model


# TODO(travis): make this configurable by model / user
MAX_BATCH_SIZE = 256
MAX_RANK = BGMV_MAX_RANK
MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 96))
MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", BGMV_MAX_RANK))

SLOT_PAD_VALUE = -1
SEGMENT_PAD_VALUE = -1

# Cached batch sizes used in vLLM. This and the helper function `get_cached_batch_size` below
# must be kept in sync.
BATCH_SIZE_INCREMENT = 32
CACHED_BATCH_SIZES = [1, 2, 4, 8, 16] + [BATCH_SIZE_INCREMENT * (i + 1) for i in range(8)]

# set CACHED_BATCH_SIZES to 1, 2, 3, 4, 8, 16 and then increments of BATCH_SIZE_INCREMENT up to MAX_BATCH_SIZE
CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
]
CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= MAX_BATCH_SIZE]

# Include 0 to ensure we can use cuda graphs without adapters
# TODO(travis): use padding to allow for more ranks without increasing memory usage
CACHED_MAX_RANKS = [0, 8, 16, 32, 64]
CACHED_MAX_RANKS = [r for r in CACHED_MAX_RANKS if r <= MAX_RANK]
_allowed_ranks = set(CACHED_MAX_RANKS)

assert all([r <= BGMV_MAX_RANK for r in _allowed_ranks]), f"Invalid ranks: {_allowed_ranks}"
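
Note on the "kept in sync" comment above: at runtime the live decode batch size is rounded up to the nearest entry of CACHED_BATCH_SIZES so an already-traced graph can be reused. A hypothetical helper with that behavior (the real get_cached_batch_size may differ) could look like:

```python
from bisect import bisect_left
from typing import List

def round_to_cached_batch_size(batch_size: int, cached_sizes: List[int]) -> int:
    """Hypothetical stand-in: round a live batch size up to the nearest traced size."""
    sizes = sorted(cached_sizes)
    idx = bisect_left(sizes, batch_size)
    if idx == len(sizes):
        raise ValueError(f"batch size {batch_size} exceeds the largest traced graph ({sizes[-1]})")
    return sizes[idx]

# Default LORAX_COMPILE_MAX_BATCH_SIZE=96 yields [1, 2, 3, 4, 8, 16, 32, 64, 96].
cached = [1, 2, 3, 4, 8, 16, 32, 64, 96]
assert round_to_cached_batch_size(5, cached) == 8
assert round_to_cached_batch_size(33, cached) == 64
```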
@@ -104,6 +110,8 @@ def get_max_graph_state(
position_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)
slots = torch.full((MAX_BATCH_SIZE,), SLOT_PAD_VALUE, dtype=torch.int64, device=device)
input_lengths = torch.ones((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)
prefix_lens = [0] * MAX_BATCH_SIZE
prefix_lens_tensor = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)

adapter_weight_data = {}
for layer_name in adapter_layers:
@@ -116,7 +124,7 @@
rank=MAX_RANK,
lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
indices=torch.full((MAX_BATCH_SIZE,), SEGMENT_PAD_VALUE, dtype=torch.int64, device=device),
segment_starts=None,
segment_ends=None,
tmp_shrink=None,
@@ -134,6 +142,8 @@
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
adapter_data=AdapterBatchData(
meta=AdapterBatchMetadata(
adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
@@ -226,8 +236,8 @@ def trace(

block_tables = max_input_state.block_tables[:batch_size]
input_lengths = max_input_state.input_lengths[:batch_size]
prefix_lengths = [0] * batch_size
prefix_lengths_tensor = torch.zeros(batch_size, dtype=torch.int32, device=device)
prefix_lengths = max_input_state.prefix_lens[:batch_size]
prefix_lengths_tensor = max_input_state.prefix_lens_tensor[:batch_size]
state = None

if FLASH_INFER:
@@ -285,6 +295,22 @@ def trace(
prefix_lens_tensor=prefix_lengths_tensor,
state=input_state.state,
):
# warmup
output_states = model.forward(
input_ids=input_state.input_ids,
position_ids=input_state.position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=input_state.block_tables,
slots=input_state.slots,
input_lengths=input_state.input_lengths,
max_s=max_total_tokens,
adapter_data=input_state.adapter_data,
prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=memory_pool): # noqa: SIM117
output_states = model.forward(
@@ -297,6 +323,7 @@
input_lengths=input_state.input_lengths,
max_s=max_total_tokens,
adapter_data=input_state.adapter_data,
prefill_cache_indices=None,
lm_head_indices=None,
)
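
Note: the eager forward pass added above, before graph capture, is the standard PyTorch warmup idiom: one regular run lets lazy allocations, workspace setup, and kernel autotuning happen outside the capture so they are not baked into (or do not break) the recorded graph. A self-contained sketch of the same pattern on a toy module, not LoRAX code:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda().eval()
static_input = torch.zeros(8, 1024, device="cuda")

# Warmup run on a side stream so one-time setup happens outside the capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side), torch.no_grad():
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(side)
torch.cuda.synchronize()

# Capture: kernels are recorded against the static input/output buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_output = model(static_input)

# Replay: refill the static input in place, then replay the recorded kernels.
static_input.copy_(torch.randn_like(static_input))
graph.replay()
print(static_output.shape)  # torch.Size([8, 1024])
```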

@@ -323,9 +350,8 @@ def forward(
pad_and_fill(self.input_state.position_ids, position_ids, 0)
pad_and_fill(self.input_state.slots, slots, SLOT_PAD_VALUE)
pad_and_fill(self.input_state.input_lengths, input_lengths + prefix_lens_tensor, 0)

self.input_state.block_tables.zero_()
self.input_state.block_tables[: block_tables.shape[0], : block_tables.shape[1]] = block_tables
self.input_state.prefix_lens[: len(prefix_lens)] = prefix_lens
pad_and_fill(self.input_state.prefix_lens_tensor, prefix_lens_tensor, 0)

if FLASH_INFER:
block_tables = block_tables_to_ragged(
@@ -359,8 +385,8 @@ def forward(
cu_seqlen_prefill=None,
input_lengths=input_lengths,
input_lengths_tensor=self.input_state.input_lengths,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
prefix_lens=self.input_state.prefix_lens,
prefix_lens_tensor=self.input_state.prefix_lens_tensor,
state=self.input_state.state,
):
self.graph.replay()
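
Note: at replay time the graph always reads the same static buffers captured at trace time, so the live inputs are copied into them in place before self.graph.replay(). The pad_and_fill helper is not shown in this diff; a hypothetical equivalent with the same call shape would be:

```python
import torch

def pad_and_fill(dest: torch.Tensor, src: torch.Tensor, pad_value: int) -> None:
    """Hypothetical stand-in for the helper used above (the real one may differ):
    copy live values into the front of the static buffer and pad the unused tail
    so the replayed kernels see well-defined data."""
    dest[: src.shape[0]].copy_(src)
    dest[src.shape[0]:].fill_(pad_value)

static_slots = torch.full((8,), -1, dtype=torch.int64)  # SLOT_PAD_VALUE == -1
live_slots = torch.tensor([5, 9, 12], dtype=torch.int64)
pad_and_fill(static_slots, live_slots, -1)
print(static_slots)  # tensor([ 5,  9, 12, -1, -1, -1, -1, -1])
```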
@@ -511,6 +537,8 @@ def forward(
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
prefix_lens: List[int],
prefix_lens_tensor: torch.Tensor,
max_s: int,
adapter_data: AdapterBatchData,
lm_head_indices: Optional[torch.Tensor] = None,
@@ -552,6 +580,8 @@ def forward(
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
max_s=max_s,
adapter_data=adapter_data,
lm_head_indices=lm_head_indices,
5 changes: 1 addition & 4 deletions server/lorax_server/utils/sources/hub.py
@@ -48,10 +48,7 @@ def weight_hub_files(
if embedding_tensor_file not in filenames:
raise ValueError(f"No embedding tensor file found for embedding dim {embedding_dim}")
filenames = [
filename
for filename in filenames
if len(filename.split("/")) < 2
or filename == embedding_tensor_file
filename for filename in filenames if len(filename.split("/")) < 2 or filename == embedding_tensor_file
]
else:
filenames = [