Prefix caching #2402

Merged (13 commits) on Aug 20, 2024
13 changes: 9 additions & 4 deletions backends/v3/src/queue.rs
@@ -316,10 +316,15 @@ impl State {
+ self.speculate
- 1;

match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
// If the user wants the prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details {
None
} else {
entry.request.input_ids.clone()
};

match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
1 change: 1 addition & 0 deletions backends/v3/src/radix.rs
@@ -205,6 +205,7 @@ pub struct RadixTrie {
/// call that a real time lookup would require.
time: u64,
}

impl Default for RadixTrie {
fn default() -> Self {
Self::new()
6 changes: 3 additions & 3 deletions flake.lock

1 change: 1 addition & 0 deletions flake.nix
@@ -81,6 +81,7 @@
grpcio-status
grpcio-tools
hf-transfer
ipdb
loguru
mamba-ssm
marlin-kernels
7 changes: 6 additions & 1 deletion server/text_generation_server/layers/attention/__init__.py
@@ -6,7 +6,12 @@
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
from .cuda import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
27 changes: 17 additions & 10 deletions server/text_generation_server/layers/attention/cuda.py
@@ -76,7 +76,7 @@ def paged_attention(
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flash_infer import decode_state
from text_generation_server.layers.attention.flashinfer import decode_state

return decode_state.get().forward(
query.contiguous(),
@@ -221,24 +221,27 @@ def paged_attention(
if ATTENTION == "flashinfer":

def attention(
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flash_infer import prefill_state
assert window_size_left == -1, "Windowing is not supported with flash infer"
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)

return prefill_state.get().forward(
q,
k,
v,
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
causal=causal,
window_left=window_size_left,
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
@@ -249,6 +252,8 @@ def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
@@ -289,6 +294,8 @@ def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
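The model files further below all repeat the same call-site change: the prefill attention call now also receives the two KV-cache tensors, matching the updated signatures above. A minimal sketch of the new parameter order, using a toy stand-in for the kernel (made-up shapes and a dummy body, not the real flash-attention or flashinfer call):

import torch

# Toy stand-in with the same parameter order as the updated `attention`
# wrapper; it only checks shapes instead of running a real kernel.
def attention(q, k, v, key_cache, value_cache, cu_seqlens, max_s,
              softmax_scale, window_size_left=-1, causal=True, softcap=0.0):
    assert key_cache.shape == value_cache.shape
    return torch.empty_like(q)

query = key = value = torch.randn(8, 4, 64)                 # (tokens, heads, head_dim)
kv_cache = (torch.zeros(32, 4, 64), torch.zeros(32, 4, 64))
cu_seqlen_prefill = torch.tensor([0, 8], dtype=torch.int32)
max_s, softmax_scale = 8, 64 ** -0.5

# New call shape, mirroring the model diffs below.
out = attention(query, key, value, kv_cache[0], kv_cache[1],
                cu_seqlen_prefill, max_s, softmax_scale)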
@@ -9,6 +9,10 @@
"prefill_state"
)

prefill_with_paged_kv_state: ContextVar[
flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
@@ -24,6 +28,78 @@ def get_workspace(device):
return workspace


def create_prefill_with_paged_kv_state(
*,
device: torch.device,
):
"""Create a prefill state that uses the KV cache."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)


@contextmanager
def use_prefill_with_paged_kv_state(
*,
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
block_tables: torch.Tensor,
cu_seqlens: torch.Tensor,
input_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""

indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)

# Get the lengths of the last page in a block.
if page_size == 1:
last_page_len = torch.ones(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
else:
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1

token = prefill_with_paged_kv_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
paged_kv_indptr=indptr,
paged_kv_indices=block_tables,
paged_kv_last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
page_size=page_size,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_with_paged_kv_state.reset(token)


def create_prefill_state(
*,
device: torch.device,
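The indptr/last_page_len arithmetic in use_prefill_with_paged_kv_state above can be sanity-checked in isolation. A small sketch with made-up values (page size 16 and three sequences of 1, 16 and 17 tokens; none of these numbers come from the PR):

import torch

page_size = 16
input_lengths = torch.tensor([1, 16, 17], dtype=torch.int32)

# Pages per sequence, rounded up, then cumulated into offsets into the
# flattened block table (the same round-up-and-cumsum as above).
pages = (input_lengths + page_size - 1) // page_size     # tensor([1, 1, 2])
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
indptr[1:] = torch.cumsum(pages, dim=0)                  # indptr = [0, 1, 2, 4]

# Tokens occupying the last page of each sequence: ((len - 1) % page_size) + 1.
last_page_len = (input_lengths - 1) % page_size + 1      # tensor([1, 16, 1])

print(indptr.tolist(), last_page_len.tolist())           # [0, 1, 2, 4] [1, 16, 1]

The set()/reset(token) pair around the try/finally ensures the previously active ContextVar state is restored even if begin_forward or the wrapped forward pass raises.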
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/medusa.py
@@ -32,6 +32,8 @@ def __init__(self, config, medusa_config, weights):
)

def forward(self, x):
if not self.heads:
return None
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits

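The early return added above guards the degenerate case of a Medusa config with zero heads: without it, torch.stack receives an empty list and raises. A minimal illustration (toy shapes, not taken from the PR):

import torch

heads = []                      # e.g. a config that defines no speculative heads
x = torch.randn(4, 8)

try:
    torch.stack([head(x) for head in heads], dim=1)
except RuntimeError as err:
    print("empty head list:", err)   # stack expects a non-empty TensorList

# With the guard in place, forward() simply returns None for this case.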
@@ -298,6 +298,8 @@ def forward(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -337,6 +337,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -365,6 +365,8 @@ def forward(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -238,6 +238,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -232,6 +232,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -232,6 +232,8 @@ def forward(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -220,6 +220,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -219,6 +219,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -276,6 +276,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -173,6 +173,8 @@ def forward(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -194,6 +194,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -137,6 +137,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -208,6 +208,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -326,6 +328,8 @@ def forward(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -293,6 +293,8 @@ def forward(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -242,6 +242,8 @@ def forward(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,