Support chunked prefill when radix cache is disabled #811

Merged 9 commits on Aug 1, 2024
Changes from 8 commits
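
For context: this PR removes the check in server_args.py that previously rejected combining chunked prefill with a disabled radix cache, and routes that combination to a new ChunkCache in the TP worker. A minimal sketch of the now-accepted configuration follows; chunked_prefill_size and disable_radix_cache are the ServerArgs fields referenced in this diff, while model_path and the exact ServerArgs construction are illustrative assumptions, not part of this PR.

# Hedged usage sketch: only chunked_prefill_size and disable_radix_cache are
# taken from this PR's diff; everything else is illustrative.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    chunked_prefill_size=4096,    # enable chunked prefill
    disable_radix_cache=True,     # no longer rejected by check_server_args()
)
# With both options set, the TP worker builds a ChunkCache instead of a
# RadixCache (see the tp_worker.py diff below).
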
python/sglang/srt/constrained/base_tool_cache.py (renamed from base_cache.py)
@@ -13,12 +13,12 @@
limitations under the License.
"""

"""Base cache class."""
"""Base tool cache for constrained decoding tools."""

import time


class BaseCache:
class BaseToolCache:
def __init__(self, enable=True):
self.enable = enable
self.reset()
4 changes: 2 additions & 2 deletions python/sglang/srt/constrained/fsm_cache.py
@@ -16,10 +16,10 @@
"""Cache for the compressed finite state machine."""

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.base_tool_cache import BaseToolCache


class FSMCache(BaseCache):
class FSMCache(BaseToolCache):
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
super().__init__(enable=enable)

4 changes: 2 additions & 2 deletions python/sglang/srt/constrained/jump_forward.py
@@ -30,7 +30,7 @@
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.base_tool_cache import BaseToolCache

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

@@ -151,7 +151,7 @@ def is_jump_forward_symbol_state(self, state):
)


class JumpForwardCache(BaseCache):
class JumpForwardCache(BaseToolCache):
def __init__(self):
super().__init__()

38 changes: 29 additions & 9 deletions python/sglang/srt/managers/schedule_batch.py
@@ -28,6 +28,7 @@
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache

@@ -486,15 +487,33 @@ def retract_decode(self):
req = self.reqs[idx]
retracted_reqs.append(req)

# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)

# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))

# release the last node
self.tree_cache.dec_lock_ref(req.last_node)

# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

req.prefix_indices = None
req.last_node = None
@@ -575,6 +594,7 @@ def check_for_jump_forward(self, model_runner):
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
self.tree_cache.cache_req(
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
27 changes: 21 additions & 6 deletions python/sglang/srt/managers/tp_worker.py
@@ -43,6 +43,7 @@
ForwardMode,
Req,
)
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -144,11 +145,20 @@ def __init__(
)

# Init cache
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = PolicyScheduler(
self.schedule_policy,
@@ -354,7 +364,10 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
# Compute matched prefix length
for req in self.waiting_queue:
req.input_ids = req.origin_input_ids + req.output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
prefix_indices, last_node = self.tree_cache.match_prefix(
rid=req.rid,
key=req.input_ids,
)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -614,6 +627,7 @@ def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
@@ -771,6 +785,7 @@ def handle_finished_requests(self, batch: Batch):
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
43 changes: 43 additions & 0 deletions python/sglang/srt/mem_cache/base_cache.py
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod


class BaseCache(ABC):
"""Cache can be indexed by either rid or key."""

@abstractmethod
def reset(self):
pass

@abstractmethod
def match_prefix(self, **kwargs):
pass

@abstractmethod
def insert(self, **kwargs):
pass

@abstractmethod
def cache_req(self, **kwargs):
pass

@abstractmethod
def evict(self, num_tokens, evict_callback):
pass

@abstractmethod
def inc_lock_ref(self, node):
pass

@abstractmethod
def dec_lock_ref(self, node):
pass

@abstractmethod
def evictable_size(self):
pass

def total_size(self):
raise NotImplementedError

def pretty_print(self):
raise NotImplementedError
60 changes: 60 additions & 0 deletions python/sglang/srt/mem_cache/chunk_cache.py
@@ -0,0 +1,60 @@
"""Cache for chunked prefill, used when RadixCache is disabled."""

from sglang.srt.mem_cache.base_cache import BaseCache


class ChunkCacheEntry:
def __init__(self, rid, value):
self.rid = rid
self.value = value


class ChunkCache(BaseCache):
def __init__(self, req_to_token_pool, token_to_kv_pool):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool

self.reset()

def reset(self):
self.entries = {}

def match_prefix(self, rid, **kwargs):
if rid not in self.entries:
return [], None

entry = self.entries[rid]
return entry.value, entry

def cache_req(
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
):
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
if del_in_memory_pool:
assert rid in self.entries
self.req_to_token_pool.free(req_pool_idx)
self.token_to_kv_pool.free(indices)
return

if rid not in self.entries:
self.entries[rid] = ChunkCacheEntry(rid, indices)

entry = self.entries[rid]
entry.value = indices
return indices, entry

def insert(self):
raise NotImplementedError

def evict(self, num_tokens, evict_callback):
pass

def inc_lock_ref(self, node):
return 0

def dec_lock_ref(self, node):
return 0

def evictable_size(self):
return 0
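
To make the intended flow of ChunkCache concrete, here is a hedged sketch of its request lifecycle (not part of the PR): during chunked prefill, cache_req is called with del_in_memory_pool=False so the request's KV indices are remembered under its rid, a later chunk's match_prefix returns exactly those indices, and the final cache_req with del_in_memory_pool=True frees the pool slots. The two pool classes below are stand-ins for ReqToTokenPool and TokenToKVPool, not the real implementations.

import torch

from sglang.srt.mem_cache.chunk_cache import ChunkCache


class StubReqToTokenPool:
    # Stand-in for ReqToTokenPool: a request-slot x position index table.
    def __init__(self, size, max_context_len):
        self.req_to_token = torch.arange(
            size * max_context_len, dtype=torch.int64
        ).reshape(size, max_context_len)

    def free(self, req_pool_idx):
        pass  # a real pool would mark the slot reusable


class StubTokenToKVPool:
    # Stand-in for TokenToKVPool.
    def free(self, indices):
        pass  # a real pool would release the KV slots


cache = ChunkCache(StubReqToTokenPool(4, 128), StubTokenToKVPool())
rid = "req-0"

# First chunk: nothing is cached for this rid yet.
prefix, entry = cache.match_prefix(rid=rid)  # -> [], None

# After prefilling the first 8 tokens of request slot 0, remember their indices.
cache.cache_req(rid=rid, token_ids=list(range(8)), req_pool_idx=0,
                del_in_memory_pool=False)

# Next chunk: the previously filled indices come back as the matched prefix.
prefix, entry = cache.match_prefix(rid=rid)  # -> the 8 cached indices

# Request finished: free the request slot and its KV indices.
cache.cache_req(rid=rid, token_ids=list(range(12)), req_pool_idx=0,
                del_in_memory_pool=True)

This is the bookkeeping that lets get_new_prefill_batch and retract_decode treat a partially prefilled request's KV indices as a reusable prefix without any radix-tree locking or eviction.
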
7 changes: 5 additions & 2 deletions python/sglang/srt/mem_cache/radix_cache.py
@@ -23,6 +23,8 @@

import torch

from sglang.srt.mem_cache.base_cache import BaseCache


class TreeNode:
def __init__(self):
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
return i


class RadixCache:
class RadixCache(BaseCache):
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
@@ -62,7 +64,7 @@ def reset(self):
self.root_node.lock_ref = 1
self.evictable_size_ = 0

def match_prefix(self, key):
def match_prefix(self, key, **kwargs):
if self.disable:
return [], self.root_node

@@ -90,6 +92,7 @@ def cache_req(
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
4 changes: 0 additions & 4 deletions python/sglang/srt/server_args.py
@@ -419,10 +419,6 @@ def check_server_args(self):
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"

assert not (
self.chunked_prefill_size is not None and self.disable_radix_cache
), "chunked prefill is not supported with radix cache disabled currently"


@dataclasses.dataclass
class PortArgs: