feat: non-contiguous query with paged kv cache (#553)
## Motivation

Previously, only the ragged version of the prefill kernel supported non-contiguous query tensors (#404). With a paged KV cache, the query tensor had to be made contiguous first, so libraries like vLLM and SGLang must copy the query tensor to contiguous memory before calling flashinfer kernels ([vLLM call of flashinfer](https://github.com/vllm-project/vllm/blob/b7df53cd42f3eab007b4f287c151960858e949df/vllm/attention/backends/flashinfer.py#L839), [SGLang call of flashinfer](https://github.com/sgl-project/sglang/blob/87a7cfa080cec3f123618c1429b5f998bf5d99cb/python/sglang/srt/layers/attention/flashinfer_backend.py#L236)). This PR removes that restriction: the prefill and decode kernels with paged KV cache now support non-contiguous query tensors.

## Main Changes

1. Add the query tensor's strides to `BatchPrefillPagedParams` and `BatchDecodeParams`.
2. Set the stride parameters before calling those kernels.
3. Update the JIT compilation templates to support the new kernel parameters.
4. Add tests.

The Python interfaces remain unchanged; the only difference is that they now accept non-contiguous query tensors.

---------

Signed-off-by: LinHeLurking <[email protected]>
1 parent f6e0010 · commit 89f2c4a · 10 changed files with 196 additions and 17 deletions.
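For context, here is a minimal standalone sketch (not part of this commit; the shapes are arbitrary) of how a non-contiguous query tensor typically arises: slicing the query heads out of a fused QKV projection yields a view whose per-token stride is the full packed width, which previously forced a `.contiguous()` copy before calling the paged-KV wrappers.

```python
import torch

# Illustrative only: arbitrary example shapes, and the wrapper call is elided.
num_qo_heads, num_kv_heads, head_dim, nnz = 8, 2, 128, 16
qkv = torch.randn(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim)

# Slice q out of the fused projection; this is a view, not a copy.
q, _, _ = qkv.split(
    [num_qo_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1
)
q = q.view(-1, num_qo_heads, head_dim)

print(q.is_contiguous())  # False: the token stride is the packed width
print(q.stride())         # (1536, 128, 1) here, instead of (1024, 128, 1)

# Before this commit, callers had to pass q.contiguous() to the paged-KV wrappers;
# after it, the non-contiguous view can be passed to wrapper.run directly.
```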
```python
# New test file (77 lines added in this commit).
import torch
import pytest
import flashinfer


@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("page_size", [1, 5])
@pytest.mark.parametrize("seq_len", [1])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_batch_paged_decode_packed_input(
    batch_size,
    page_size,
    seq_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")
    nnz = batch_size * seq_len
    num_pages_per_req = (seq_len + page_size - 1) // page_size
    num_pages = batch_size * num_pages_per_req
    last_page_len = (seq_len - 1) % page_size + 1
    k_cache = torch.randn(
        size=(num_pages, page_size, num_kv_heads, head_dim),
        dtype=torch.float16,
        device="cuda:0",
    )
    v_cache = torch.randn_like(k_cache)
    paged_kv_cache = (k_cache, v_cache)
    workspace_buffer = torch.empty(
        (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
    )
    paged_kv_indptr = torch.tensor(
        [i * num_pages_per_req for i in range(batch_size + 1)],
        dtype=torch.int32,
        device="cuda:0",
    )
    paged_kv_indices = torch.tensor(
        list(range(num_pages)), dtype=torch.int32, device="cuda:0"
    )
    paged_kv_last_page_len = torch.tensor(
        [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
    )

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)
    wrapper.plan(
        indptr=paged_kv_indptr,
        indices=paged_kv_indices,
        last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=page_size,
    )

    # Build a packed QKV tensor and slice q out of it without making it contiguous.
    qkv_packed = torch.randn(
        size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
        dtype=torch.float16,
        device="cuda:0",
    )
    qkv_split_idx = (
        num_qo_heads * head_dim,
        num_kv_heads * head_dim,
        num_kv_heads * head_dim,
    )
    q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
    q = q.view(-1, num_qo_heads, head_dim)
    # The non-contiguous view and a contiguous copy must produce the same output.
    o_packed = wrapper.run(q, paged_kv_cache)
    o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    test_batch_paged_decode_packed_input(37, 127, 1, 4, 64, 128)
```
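The test above feeds exactly such a packed view into `wrapper.run`. As a rough, flashinfer-independent illustration of why forwarding the query strides (items 1 and 2 of the main changes) is sufficient for a kernel to read a non-contiguous view correctly, the following sketch, using only assumed example shapes and no flashinfer code, recomputes an element's storage offset from the tensor's strides:

```python
import torch

# Hypothetical illustration, independent of the flashinfer kernels: addressing a
# non-contiguous query through its strides, which is what carrying the query
# strides in the kernel parameters makes possible.
num_qo_heads, num_kv_heads, head_dim, nnz = 4, 2, 64, 8
qkv = torch.randn(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim)
q = qkv[:, : num_qo_heads * head_dim].view(nnz, num_qo_heads, head_dim)

token, head, dim = 3, 1, 5
# Element offset in the underlying storage, computed from strides the way a
# stride-aware kernel would compute its load address.
offset = (
    q.storage_offset()
    + token * q.stride(0)
    + head * q.stride(1)
    + dim * q.stride(2)
)
flat = qkv.view(-1)  # same storage as q
assert flat[offset] == q[token, head, dim]
```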