Follow up on ragged kernel wrapper #8737
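
This draft follows up on the ragged paged attention kernel wrapper. It factors the non-dynamo wrapper test into a reusable _verify_ragged_paged_attention_no_dynamo helper, adds mixed prefill/decode and single-sequence cases on top of it, and adds a _verify_make_sequence_metadata check that compares the torch _make_sequence_metadata implementation against the JAX make_sequence_metadata reference.

Each new case follows this pattern (an excerpt-style sketch inside the existing test class; the parameter values are illustrative and taken from the cases in the diff below):

seq_lens = [(1, 1328), (5, 18), (506, 563)]  # [(q_len, kv_len), ...]
self._verify_ragged_paged_attention_no_dynamo(
    seq_lens,
    num_heads=(4, 4),
    head_dim=128,
    page_size=16,
    dtype=torch.float32,
    num_pages=32768,
)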

Draft · wants to merge 4 commits into master
219 changes: 191 additions & 28 deletions test/test_pallas.py
@@ -636,35 +636,21 @@ def test_paged_attention_wrapper(self):
atol=1e-5,
rtol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_without_dynamo(self):
def _verify_ragged_paged_attention_no_dynamo(
self,
seq_lens,
num_heads,
head_dim,
page_size,
dtype,
num_pages,
):
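# Shared helper (no dynamo): runs ragged_paged_attention with use_kernel=True
# and use_kernel=False and checks both outputs against the JAX ragged paged
# attention kernel reference.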
from torch_xla.experimental.custom_kernel import ragged_paged_attention
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention

seq_lens = [
(1, 1328),
(5, 18),
(1, 129),
(120, 229),
(1, 122), # first physical q block
(1, 64),
(32, 100),
(250, 463),
(1, 18),
(1, 17),
(99, 123)
] # last 3 physical q blocks [(q_len, kv_len),...]
num_heads = (4, 4)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768
num_seqs = len(seq_lens)
num_kv_pages_per_block = 128
num_queries_per_block = 8
block_kv_size = 256

num_queries_per_block = 16
q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)

@@ -675,7 +661,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
page_indices_xla = page_indices.to("xla")
cu_q_lens_xla = cu_q_lens.to("xla")

output = ragged_paged_attention(
kernel_output = ragged_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
@@ -686,6 +672,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
use_kernel=True)
print(f'xw32 line675 {kernel_output=}')

nonkernel_output = ragged_paged_attention(
q_xla,
@@ -698,6 +685,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
use_kernel=False)
print(f'xw32 line688 {nonkernel_output.shape=}')

q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
@@ -722,10 +710,119 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):

self.assertTrue(
torch.allclose(
output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
kernel_output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=2e-1, rtol=1e-2))
nonkernel_output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_mix_prefill_decode1_without_dynamo(self):
seq_lens = [
(1, 1328),
(5, 18),
(1, 129),
(120, 229),
(1, 122), # first physical q block
(1, 64),
(32, 100),
(250, 463),
(1, 18),
(1, 17),
(99, 123)
] # last 3 physical q blocks [(q_len, kv_len),...]
num_heads = (4, 4)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768

self._verify_ragged_paged_attention_no_dynamo(
seq_lens,
num_heads,
head_dim,
page_size,
dtype,
num_pages,
)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_mix_prefill_decode2_without_dynamo(self):
seq_lens = [(1, 127), (120, 1328), (1, 64), (1, 64), (1, 64), (1, 64),
(256, 256), (131, 463)] # [(q_len, kv_len),...]
num_heads = (4, 4)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768

self._verify_ragged_paged_attention_no_dynamo(
seq_lens,
num_heads,
head_dim,
page_size,
dtype,
num_pages,
)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_mix_prefill_decode3_without_dynamo(self):
seq_lens = [(1, 1328), (5, 18), (506, 563)] # [(q_len, kv_len),...]
num_heads = (4, 4)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768

self._verify_ragged_paged_attention_no_dynamo(
seq_lens,
num_heads,
head_dim,
page_size,
dtype,
num_pages,
)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_all_tokens_belong_to_one_sequence_without_dynamo(
self):
seq_lens = [(512, 1328)] # [(q_len, kv_len),...]
num_heads = (4, 4)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768

self._verify_ragged_paged_attention_no_dynamo(
seq_lens,
num_heads,
head_dim,
page_size,
dtype,
num_pages,
)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_one_token_per_sequence_without_dynamo(self):
seq_lens = [(1, 127), (1, 1328), (1, 64), (1, 64), (1, 256), (1, 463)]  # [(q_len, kv_len),...]
num_heads = (4, 4)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768

self._verify_ragged_paged_attention_no_dynamo(
seq_lens,
num_heads,
head_dim,
page_size,
dtype,
num_pages,
)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@@ -750,8 +847,8 @@ def test_ragged_paged_attention_wrapper_with_dynamo(self):
num_pages = 32768
num_seqs = len(seq_lens)
num_kv_pages_per_block = 128
# TODO(xw32): consider test on various block sizes
num_queries_per_block = 8
block_kv_size = 256

q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)
@@ -795,6 +892,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
num_queries_per_block=num_queries_per_block,
use_kernel=True,
)
print(f"xw32 line877 {output=}")

nonkernel_output = compiled_paged_attention(
q_xla,
@@ -808,10 +906,75 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
num_queries_per_block=num_queries_per_block,
use_kernel=False,
)
print(f"xw32 line891 {nonkernel_output=}")

from torch_xla.experimental.custom_kernel import ragged_paged_attention
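# Also run the kernel path without dynamo; its output is printed alongside
# the dynamo-compiled results above for comparison while debugging.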
ref_nondynamo_output = ragged_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_lens_xla,
page_indices_xla,
cu_q_lens_xla,
num_seqs=num_seqs,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
use_kernel=True)
print(f"xw32 line909 {ref_nondynamo_output=}")

self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=2e-1, rtol=1e-2))


def _verify_make_sequence_metadata(self, seq_lens):
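# Helper: builds cu_q_lens from the (q_len, kv_len) pairs and checks that the
# torch _make_sequence_metadata output matches the JAX make_sequence_metadata
# reference (sequence ids, m-tile ids, and number of logical q tiles).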
from torch_xla.experimental.custom_kernel import _make_sequence_metadata as torch_make_sequence_metadata
q_lens = [seq_len[0] for seq_len in seq_lens]
cu_q_lens_torch = torch.cumsum(torch.tensor([0] + q_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to("xla")
num_q_tokens = sum(q_lens)
num_queries_per_compute_block = 8
start_group_torch = torch.tensor([0], dtype=torch.int32).to("xla")
num_seqs = len(seq_lens)
sequence_metadata_torch, num_logical_q_tiles_torch = torch_make_sequence_metadata(
cu_q_lens=cu_q_lens_torch,
m=num_q_tokens,
tm=num_queries_per_compute_block,
start_sequence=start_group_torch,
num_sequences=num_seqs,
)
sequence_ids_torch, m_tile_ids_torch = sequence_metadata_torch

from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import make_sequence_metadata as jax_make_sequence_metadata
cu_q_lens_jax = jnp.cumsum(jnp.array([0] + q_lens))
start_group_jax = jnp.array([0])
sequence_metadata_jax, num_logical_q_tiles_jax = jax_make_sequence_metadata(
cu_q_lens=cu_q_lens_jax,
m=num_q_tokens,
tm=num_queries_per_compute_block,
start_sequence=start_group_jax,
num_sequences=num_seqs,
)
sequence_ids_jax, m_tile_ids_jax = sequence_metadata_jax

self.assertEqual(num_logical_q_tiles_torch.cpu(), torch.from_numpy(np.array(num_logical_q_tiles_jax)).cpu())
self.assertTrue(torch.equal(sequence_ids_torch.cpu(), torch.from_numpy(np.array(sequence_ids_jax)).cpu()))
self.assertTrue(torch.equal(m_tile_ids_torch.cpu(), torch.from_numpy(np.array(m_tile_ids_jax)).cpu()))

def test_make_sequence_metadata_for_ragged_paged_attn(self):
seq_lens = [
(1, 1328),
(5, 18),
(1, 129),
(120, 229),
(1, 122), # first physical q block
(1, 64),
(32, 100),
(250, 463),
(1, 18),
(1, 17),
(99, 123)
] # last 3 physical q blocks [(q_len, kv_len),...]
self._verify_make_sequence_metadata(seq_lens)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")