From 51c48ec3612fc484b76966384b93154f00390872 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Sun, 23 Feb 2025 22:44:31 +0000 Subject: [PATCH 1/4] create the test --- test/test_pallas.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 7d4372b5019..7c9aaeb4961 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -687,18 +687,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): num_queries_per_block=num_queries_per_block, use_kernel=True) - nonkernel_output = ragged_paged_attention( - q_xla, - k_pages_xla, - v_pages_xla, - kv_lens_xla, - page_indices_xla, - cu_q_lens_xla, - num_seqs=num_seqs, - num_kv_pages_per_block=num_kv_pages_per_block, - num_queries_per_block=num_queries_per_block, - use_kernel=False) - q_jax = jnp.array(q.numpy(), dtype=jnp.float32) k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32) v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32) @@ -723,9 +711,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): self.assertTrue( torch.allclose( output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)) - self.assertTrue( - torch.allclose( - output.cpu(), nonkernel_output.cpu(), atol=2e-1, rtol=1e-2)) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") From 3c4eda736c769a8d5b8e92b700d917069630c15c Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Sun, 23 Feb 2025 23:41:32 +0000 Subject: [PATCH 2/4] avoid cpu and tpu back and forth --- torch_xla/experimental/custom_kernel.py | 170 +++++++++++++++++++++--- 1 file changed, 148 insertions(+), 22 deletions(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 269bd9c0020..a420e9f61b9 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -764,6 +764,137 @@ def _ragged_paged_attention_nonkernel( output = torch.cat(outputs, dim=0) # [num_tokens, num_query_heads, head_dim] return output +# https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 +def _make_sequence_metadata( + *, + cu_q_lens: torch.Tensor, + m: int, + tm: int, + start_sequence: torch.Tensor, + num_sequences: int, +): + """Create the metadata needed for ragged paged attention computation. + + Args: + cu_q_lens: : A 1d, jnp.ndarray with shape [num_seqs+1] and jnp.int32 dtype. + The cumulative query lengths. + m: The number of query tokens. + tm: The m-dimension tile size being used. + start_sequence: The sequence in cu_q_lens to start computing from. This is useful for when num_seqs is sharded. + num_sequences: The number of sequences to compute on. + + Returns: + tuple of: + seq_ids: A 1d, jnp.ndarray with shape [m_tiles + num_seqs] and + jnp.int32 dtype. seq_ids[i] indicates which sequence the grid index (num_logical_tiles_q) will work on. + physical_q_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_seqs] and + jnp.int32. physical_q_tile_ids[i] indicates which query-dim physical tile the grid index (num_logical_tiles_q) will work on. + + num_logical_q_tiles: The number of query-dim logical tiles to execute. + """ + device = cu_q_lens.device + + end_sequence = start_sequence + num_sequences - 1 + + # We need the offset of each sequence from input, starting at zero. This metadata is + # similar to row offsets in a CSR matrix. 
The following properties hold: + # + # sequence_offsets.shape = [num_sequences + 1] + # sequence_offsets[0] = 0 + # sequence_offsets[num_sequences] = m + # + # The row at which sequence 'i' starts is sequence_offsets[i]. + sequence_ends = cu_q_lens[1:] + sequence_offsets = cu_q_lens + + # Assign a sequence id to each grid index. The grid index refers to the logical q tile index. + # + # If a sequence starts somewhere other than the start of a tile or ends somewhere + # other than the end of a tile we need to compute that full tile. Calculate + # the number of tiles for each sequence by rounding their end up to the nearest + # 'tm' and their start down to the nearest 'tm'. + + # (1) Round the sequence_ends up to the nearest multiple of 'tm'. + # + # NOTE: This does not change sequence_offsets[num_sequences], which is m + # (because we enforce m is divisible by tm). + rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).to(torch.int32) + + # (2) Round the sequence_starts down to the nearest multiple of 'tm'. + sequence_starts = torch.cat( + [torch.zeros(1, dtype=torch.int32).to(device), sequence_ends[:-1]]) + rounded_sequence_starts = sequence_starts // tm * tm + + # (3) Calculate the number of rows in each sequence. + rounded_sequence_sizes = rounded_sequence_ends - rounded_sequence_starts + + # (4) Convert the sequence sizes from units of rows to unit of 'tm' sized tiles. + # + # An m-dimension tile is 'owned' by sequence 'i' if the first row of the tile + # belongs to sequence 'i'. In addition to owned tiles, each sequence can have 0 or 1 + # initial partial tiles if it's first row does not occur in the first row of a + # tile. The '0-th' sequence never has a partial tile because it always starts at + # the 0-th row. + # + # If no sequence has a partial tile, the total number of tiles is equal to + # 'm // tm'. If every sequence has a partial except the 0-th sequence, the total + # number of tiles is equal to 'm // tm + num_sequences - 1'. Thus we know that + # + # tiles_m <= sequence_tiles.sum() <= tiles_m + num_sequences - 1 + # + # Where tiles_m = m // tm. + # + # NOTE: All sequence sizes are divisible by 'tm' because of the rounding in steps + # (1) and (2) so this division is exact. + sequence_tiles = rounded_sequence_sizes // tm + + # Create the sequence ids for each grid index based on the tile counts for each + # sequence. + # + # NOTE: This repeat(...) will pad sequence_ids with the final sequence id if + # sequence_tiles.sum() < tiles_m + num_sequences - 1. The kernel grid will be sized + # such that we only execute the necessary number of tiles. + tiles_m = _calculate_num_tiles(m, tm) + sequence_ids = repeat_with_fixed_output_size( + torch.arange(num_sequences, dtype=torch.int32).to(device), + sequence_tiles[:num_sequences], + total_repeat_length=tiles_m + num_sequences - 1, + ) + + # Assign an m-dimension tile id to each grid index. + # + # NOTE: Output tiles can only be re-visited consecutively. The following + # procedure guarantees that m-dimension tile indices respect this. + + # (1) Calculate how many times each m-dimension tile will be visited. + # + # Each tile is guaranteed to be visited once by the sequence that owns the tile. + # The remaining possible visits occur when a sequence starts inside of a tile at + # a position other than the first row. We can calculate which m-dimension tile + # each sequence starts in by floor-dividing its offset with `tm` and then count + # tile visits with a histogram. 
+ # + # To avoid double counting tile visits from the sequence that owns the tile, + # filter these out by assigning their tile id to `tile_m` (one beyond the max) + # such that they're ignored by the subsequent histogram. + # + partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) + + partial_tile_ids = torch.where(partial_tile_mask, tiles_m, + sequence_offsets[:-1] // tm) + + tile_visits = (_histogram(partial_tile_ids, min=0, max=tiles_m - 1) + 1) + + # Create the m-dimension tile ids for each grid index based on the visit + # counts for each tile. + m_tile_ids = repeat_with_fixed_output_size( + torch.arange(tiles_m, dtype=torch.int32).to(device), + tile_visits, + total_repeat_length=tiles_m + num_sequences - 1, + ) + num_tiles = sequence_tiles.sum(dtype=torch.int32) + return (sequence_ids, m_tile_ids + ), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles @requires_jax def ragged_paged_attention( @@ -794,7 +925,7 @@ def ragged_paged_attention( # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. - from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as ragged_attention, make_sequence_metadata + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as ragged_attention payload, tensor_args = trace_pallas( ragged_attention, q, @@ -814,47 +945,42 @@ def ragged_paged_attention( ], ) - sequence_metadata, num_logical_q_tiles = make_sequence_metadata( - cu_q_lens=cu_q_lens.cpu().numpy(), + q_device = q.device + sequence_metadata, num_logical_q_tiles = _make_sequence_metadata( + cu_q_lens=cu_q_lens, m=q.shape[0], tm=num_queries_per_block, - # TODO(jevinjiang, xiowei): pass start_sequence as input. - start_sequence=torch.tensor([0]).cpu().numpy(), + start_sequence=torch.tensor([0], dtype=torch.int32).to(q_device), num_sequences=num_seqs, ) assert len(sequence_metadata) == 2 - sequence_ids = torch.tensor( - sequence_metadata[0].tolist(), dtype=torch.int32).to("xla") - m_tile_ids = torch.tensor( - sequence_metadata[1].tolist(), dtype=torch.int32).to("xla") - num_q_tiles = torch.tensor( - num_logical_q_tiles.tolist(), dtype=torch.int32).to("xla") + sequence_ids, m_tile_ids = sequence_metadata q_dtype_for_kernel_launch = q.dtype page_indices_expanded = torch.unsqueeze(page_indices, 1) - buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla") - step = torch.zeros((1,), dtype=torch.int32).to("xla") + buffer_index = torch.zeros((1,), dtype=torch.int32).to(q_device) + step = torch.zeros((1,), dtype=torch.int32).to(q_device) # The jax checkify in ragged paged attention kernel will insert several scalar refs to both inputs # (end of prefetch) and outputs (begining of the original outputs). # TODO(jevinjiang, xiowei): consider seperate checkify from kernel! 
- s1 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - s2 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - s3 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - s4 = torch.zeros((1, 1), dtype=torch.int32).to("xla") + s1 = torch.zeros((1, 1), dtype=torch.int32).to(q_device) + s2 = torch.zeros((1, 1), dtype=torch.int32).to(q_device) + s3 = torch.zeros((1, 1), dtype=torch.int32).to(q_device) + s4 = torch.zeros((1, 1), dtype=torch.int32).to(q_device) q = q.permute(1, 0, 2) MIN_BLOCK_SIZE = 128 output_shape = torch.Size(list(q.shape[:-1]) + [MIN_BLOCK_SIZE]) - num_q_tiles_1d = torch.tensor([num_logical_q_tiles.tolist()], - dtype=torch.int32).to("xla") + num_logical_q_tiles_1d = torch.tensor([num_logical_q_tiles], + dtype=torch.int32).to(q_device) # TODO(jevinjiang, xiowei): check err returned by checkify! And add tests. _, _, _, _, output, _, _ = torch_xla._XLAC._xla_tpu_custom_call( [ - num_q_tiles, + num_logical_q_tiles, sequence_ids, m_tile_ids, - # Need num_q_tiles_1d to work around a Mosaic internal error. - num_q_tiles_1d, + # Need num_logical_q_tiles_1d to work around a Mosaic internal error. + num_logical_q_tiles_1d, kv_lens, cu_q_lens, buffer_index, From adcaa1be7a1d83b0fb45be60f51e3adb27490bfb Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 24 Feb 2025 00:37:25 +0000 Subject: [PATCH 3/4] add a few more tests. Somehow the dynamo test fails --- test/test_pallas.py | 146 ++++++++++++++++++++---- torch_xla/experimental/custom_kernel.py | 6 +- 2 files changed, 125 insertions(+), 27 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 7c9aaeb4961..6da8a44f438 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -636,35 +636,21 @@ def test_paged_attention_wrapper(self): atol=1e-5, rtol=1e-5)) - @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, - "This test only works on TPUv4+.") - def test_ragged_paged_attention_wrapper_without_dynamo(self): + def _verify_ragged_paged_attention_no_dynamo( + self, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ): from torch_xla.experimental.custom_kernel import ragged_paged_attention from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention - seq_lens = [ - (1, 1328), - (5, 18), - (1, 129), - (120, 229), - (1, 122), # first physical q block - (1, 64), - (32, 100), - (250, 463), - (1, 18), - (1, 17), - (99, 123) - ] # last 3 physical q blocks [(q_len, kv_len),...] - num_heads = (4, 4) - head_dim = 128 - dtype = torch.float32 - page_size = 16 - num_pages = 32768 num_seqs = len(seq_lens) num_kv_pages_per_block = 128 - num_queries_per_block = 8 - block_kv_size = 256 - + num_queries_per_block = 128 q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv( seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype) @@ -712,6 +698,117 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): torch.allclose( output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)) + + + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_ragged_paged_attention_mix_prefill_decode1_without_dynamo(self): + seq_lens = [ + (1, 1328), + (5, 18), + (1, 129), + (120, 229), + (1, 122), # first physical q block + (1, 64), + (32, 100), + (250, 463), + (1, 18), + (1, 17), + (99, 123) + ] # last 3 physical q blocks [(q_len, kv_len),...] 
+ num_heads = (4, 4) + head_dim = 128 + dtype = torch.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_no_dynamo( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_ragged_paged_attention_mix_prefill_decode2_without_dynamo(self): + seq_lens = [(1, 127), (120, 1328), (1, 64), (1, 64), (1, 64), (1, 64), + (256, 256), (131, 463)] # [(q_len, kv_len),...] + num_heads = (4, 4) + head_dim = 128 + dtype = torch.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_no_dynamo( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_ragged_paged_attention_mix_prefill_decode3_without_dynamo(self): + seq_lens = [(1, 1328), (5, 18), (506, 563)] # [(q_len, kv_len),...] + num_heads = (4, 4) + head_dim = 128 + dtype = torch.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_no_dynamo( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_ragged_paged_attention_all_tokens_belong_to_one_sequence_without_dynamo(self): + seq_lens = [(512, 1328)] # [(q_len, kv_len),...] + num_heads = (4, 4) + head_dim = 128 + dtype = torch.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_no_dynamo( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_ragged_paged_attention_one_tokens_per_sequence_without_dynamo(self): + seq_lens = [(512, 1328)] # [(q_len, kv_len),...] + num_heads = (4, 4) + head_dim = 128 + dtype = torch.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention_no_dynamo( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") def test_ragged_paged_attention_wrapper_with_dynamo(self): @@ -736,7 +833,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(self): num_seqs = len(seq_lens) num_kv_pages_per_block = 128 num_queries_per_block = 8 - block_kv_size = 256 q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv( seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index a420e9f61b9..75ec442a559 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -926,6 +926,7 @@ def ragged_paged_attention( # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. 
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as ragged_attention + import pdb; pdb.set_trace() payload, tensor_args = trace_pallas( ragged_attention, q, @@ -970,8 +971,9 @@ def ragged_paged_attention( q = q.permute(1, 0, 2) MIN_BLOCK_SIZE = 128 output_shape = torch.Size(list(q.shape[:-1]) + [MIN_BLOCK_SIZE]) - num_logical_q_tiles_1d = torch.tensor([num_logical_q_tiles], - dtype=torch.int32).to(q_device) + # num_logical_q_tiles_1d = torch.tensor([num_logical_q_tiles], + # dtype=torch.int32).to(q_device) + num_logical_q_tiles_1d = num_logical_q_tiles.unsqueeze(0) # TODO(jevinjiang, xiowei): check err returned by checkify! And add tests. _, _, _, _, output, _, _ = torch_xla._XLAC._xla_tpu_custom_call( From 44d00ab34338125856e6602ae4017ce322edbd7a Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 24 Feb 2025 19:46:17 +0000 Subject: [PATCH 4/4] add more tests --- test/test_pallas.py | 166 ++++++++++++++++++------ torch_xla/experimental/custom_kernel.py | 10 +- 2 files changed, 130 insertions(+), 46 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 6da8a44f438..a2d7ae8d50d 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -650,7 +650,7 @@ def _verify_ragged_paged_attention_no_dynamo( num_seqs = len(seq_lens) num_kv_pages_per_block = 128 - num_queries_per_block = 128 + num_queries_per_block = 16 q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv( seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype) @@ -661,7 +661,7 @@ def _verify_ragged_paged_attention_no_dynamo( page_indices_xla = page_indices.to("xla") cu_q_lens_xla = cu_q_lens.to("xla") - output = ragged_paged_attention( + kernel_output = ragged_paged_attention( q_xla, k_pages_xla, v_pages_xla, @@ -672,6 +672,20 @@ def _verify_ragged_paged_attention_no_dynamo( num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, use_kernel=True) + print(f'xw32 line675 {kernel_output=}') + + nonkernel_output = ragged_paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + kv_lens_xla, + page_indices_xla, + cu_q_lens_xla, + num_seqs=num_seqs, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + use_kernel=False) + print(f'xw32 line688 {nonkernel_output.shape=}') q_jax = jnp.array(q.numpy(), dtype=jnp.float32) k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32) @@ -696,10 +710,10 @@ def _verify_ragged_paged_attention_no_dynamo( self.assertTrue( torch.allclose( - output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)) - - - + kernel_output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)) + self.assertTrue( + torch.allclose( + nonkernel_output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") @@ -722,14 +736,14 @@ def test_ragged_paged_attention_mix_prefill_decode1_without_dynamo(self): dtype = torch.float32 page_size = 16 num_pages = 32768 - + self._verify_ragged_paged_attention_no_dynamo( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, @@ -742,14 +756,14 @@ def test_ragged_paged_attention_mix_prefill_decode2_without_dynamo(self): dtype = torch.float32 page_size = 16 num_pages = 32768 - + self._verify_ragged_paged_attention_no_dynamo( - seq_lens, - 
num_heads, - head_dim, - page_size, - dtype, - num_pages, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, @@ -761,33 +775,34 @@ def test_ragged_paged_attention_mix_prefill_decode3_without_dynamo(self): dtype = torch.float32 page_size = 16 num_pages = 32768 - + self._verify_ragged_paged_attention_no_dynamo( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") - def test_ragged_paged_attention_all_tokens_belong_to_one_sequence_without_dynamo(self): + def test_ragged_paged_attention_all_tokens_belong_to_one_sequence_without_dynamo( + self): seq_lens = [(512, 1328)] # [(q_len, kv_len),...] num_heads = (4, 4) head_dim = 128 dtype = torch.float32 page_size = 16 num_pages = 32768 - + self._verify_ragged_paged_attention_no_dynamo( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, @@ -799,14 +814,14 @@ def test_ragged_paged_attention_one_tokens_per_sequence_without_dynamo(self): dtype = torch.float32 page_size = 16 num_pages = 32768 - + self._verify_ragged_paged_attention_no_dynamo( - seq_lens, - num_heads, - head_dim, - page_size, - dtype, - num_pages, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, @@ -832,6 +847,7 @@ def test_ragged_paged_attention_wrapper_with_dynamo(self): num_pages = 32768 num_seqs = len(seq_lens) num_kv_pages_per_block = 128 + # TODO(xw32): consider test on various block sizes num_queries_per_block = 8 q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv( @@ -876,6 +892,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, num_queries_per_block=num_queries_per_block, use_kernel=True, ) + print(f"xw32 line877 {output=}") nonkernel_output = compiled_paged_attention( q_xla, @@ -889,10 +906,75 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, num_queries_per_block=num_queries_per_block, use_kernel=False, ) + print(f"xw32 line891 {nonkernel_output=}") + + from torch_xla.experimental.custom_kernel import ragged_paged_attention + ref_nondynamo_output = ragged_paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + kv_lens_xla, + page_indices_xla, + cu_q_lens_xla, + num_seqs=num_seqs, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + use_kernel=True) + print(f"xw32 line909 {ref_nondynamo_output=}") self.assertTrue( torch.allclose( output.cpu(), nonkernel_output.cpu(), atol=2e-1, rtol=1e-2)) + + + def _verify_make_sequence_metadata(self, seq_lens): + from torch_xla.experimental.custom_kernel import _make_sequence_metadata as torch_make_sequence_metadata + q_lens = [seq_len[0] for seq_len in seq_lens] + cu_q_lens_torch = torch.cumsum(torch.tensor([0] + q_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to("xla") + num_q_tokens = sum(q_lens) + num_queries_per_compute_block = 8 + start_group_torch = torch.tensor([0], dtype=torch.int32).to("xla") + num_seqs = len(seq_lens) + sequence_metadata_torch, num_logical_q_tiles_torch = torch_make_sequence_metadata( + cu_q_lens=cu_q_lens_torch, + m=num_q_tokens, + 
tm=num_queries_per_compute_block, + start_sequence=start_group_torch, + num_sequences=num_seqs, + ) + sequence_ids_torch, m_tile_ids_torch = sequence_metadata_torch + + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import make_sequence_metadata as jax_make_sequence_metadata + cu_q_lens_jax = jnp.cumsum(jnp.array([0] + q_lens)) + start_group_jax = jnp.array([0]) + sequence_metadata_jax, num_logical_q_tiles_jax = jax_make_sequence_metadata( + cu_q_lens=cu_q_lens_jax, + m=num_q_tokens, + tm=num_queries_per_compute_block, + start_sequence=start_group_jax, + num_sequences=num_seqs, + ) + sequence_ids_jax, m_tile_ids_jax = sequence_metadata_jax + + self.assertEqual(num_logical_q_tiles_torch.cpu(), torch.from_numpy(np.array(num_logical_q_tiles_jax)).cpu()) + self.assertTrue(torch.equal(sequence_ids_torch.cpu(), torch.from_numpy(np.array(sequence_ids_jax)).cpu())) + self.assertTrue(torch.equal(m_tile_ids_torch.cpu(), torch.from_numpy(np.array(m_tile_ids_jax)).cpu())) + + def test_make_sequence_metadata_for_ragged_paged_attn(self): + seq_lens = [ + (1, 1328), + (5, 18), + (1, 129), + (120, 229), + (1, 122), # first physical q block + (1, 64), + (32, 100), + (250, 463), + (1, 18), + (1, 17), + (99, 123) + ] # last 3 physical q blocks [(q_len, kv_len),...] + self. _verify_make_sequence_metadata(seq_lens) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 75ec442a559..ca210c9c4c9 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -764,6 +764,7 @@ def _ragged_paged_attention_nonkernel( output = torch.cat(outputs, dim=0) # [num_tokens, num_query_heads, head_dim] return output + # https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 def _make_sequence_metadata( *, @@ -881,21 +882,22 @@ def _make_sequence_metadata( partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) partial_tile_ids = torch.where(partial_tile_mask, tiles_m, - sequence_offsets[:-1] // tm) + sequence_offsets[:-1] // tm) tile_visits = (_histogram(partial_tile_ids, min=0, max=tiles_m - 1) + 1) # Create the m-dimension tile ids for each grid index based on the visit # counts for each tile. m_tile_ids = repeat_with_fixed_output_size( - torch.arange(tiles_m, dtype=torch.int32).to(device), + torch.arange(tiles_m, dtype=torch.int32).to(device), tile_visits, total_repeat_length=tiles_m + num_sequences - 1, ) - num_tiles = sequence_tiles.sum(dtype=torch.int32) + num_tiles = sequence_tiles[:num_sequences].sum(dtype=torch.int32) return (sequence_ids, m_tile_ids ), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles + @requires_jax def ragged_paged_attention( q, # [num_tokens, num_q_heads, head_dim] @@ -926,7 +928,6 @@ def ragged_paged_attention( # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. 
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as ragged_attention - import pdb; pdb.set_trace() payload, tensor_args = trace_pallas( ragged_attention, q, @@ -954,6 +955,7 @@ def ragged_paged_attention( start_sequence=torch.tensor([0], dtype=torch.int32).to(q_device), num_sequences=num_seqs, ) + # print(f'xw32 {num_logical_q_tiles=}') assert len(sequence_metadata) == 2 sequence_ids, m_tile_ids = sequence_metadata
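
Worked example of the metadata computed by the new `_make_sequence_metadata`: the sketch below is a standalone approximation, not part of the patches above. It substitutes plain `torch.repeat_interleave` and `torch.bincount` for the `repeat_with_fixed_output_size` and `_histogram` helpers referenced in the patch (those helpers are defined elsewhere in custom_kernel.py and their exact signatures are not shown in this diff), and the function name and sample lengths are illustrative only.

import torch

def sketch_make_sequence_metadata(cu_q_lens: torch.Tensor, m: int, tm: int):
  num_seqs = cu_q_lens.numel() - 1
  tiles_m = (m + tm - 1) // tm                    # number of physical q-dim tiles
  grid_len = tiles_m + num_seqs - 1               # padded logical grid length
  seq_ends = cu_q_lens[1:]
  seq_starts = torch.cat([torch.zeros(1, dtype=torch.int32), seq_ends[:-1]])
  # Round each sequence out to whole tiles and count the tiles it touches.
  rounded_ends = (seq_ends + tm - 1) // tm * tm
  rounded_starts = seq_starts // tm * tm
  seq_tiles = (rounded_ends - rounded_starts) // tm
  # seq_ids[i]: which sequence logical grid step i works on (padded with the
  # final id, mirroring the padding behavior described for repeat_with_fixed_output_size).
  seq_ids = torch.repeat_interleave(
      torch.arange(num_seqs, dtype=torch.int32), seq_tiles.long())
  seq_ids = torch.cat([seq_ids, seq_ids[-1:].repeat(grid_len - seq_ids.numel())])
  # m_tile_ids[i]: which physical q tile logical grid step i (re)visits. A tile
  # gets one extra visit for every sequence that starts mid-tile inside it.
  starts_mid_tile = (cu_q_lens[:-1] % tm) != 0
  partial_tiles = (cu_q_lens[:-1] // tm)[starts_mid_tile]
  tile_visits = torch.bincount(partial_tiles, minlength=tiles_m) + 1
  m_tile_ids = torch.repeat_interleave(
      torch.arange(tiles_m, dtype=torch.int32), tile_visits)
  m_tile_ids = torch.cat(
      [m_tile_ids, m_tile_ids[-1:].repeat(grid_len - m_tile_ids.numel())])
  return seq_ids, m_tile_ids, seq_tiles.sum()

# Two sequences with q_lens [5, 11]: m = 16 query tokens, tile size tm = 8.
cu_q_lens = torch.tensor([0, 5, 16], dtype=torch.int32)
seq_ids, m_tile_ids, num_logical_q_tiles = sketch_make_sequence_metadata(
    cu_q_lens, m=16, tm=8)
print(seq_ids)              # tensor([0, 1, 1], dtype=torch.int32)
print(m_tile_ids)           # tensor([0, 0, 1], dtype=torch.int32)
print(num_logical_q_tiles)  # tensor(3)

With q_lens [5, 11] and tm = 8 there are only 2 physical q tiles but 3 logical grid steps: step 1 revisits physical tile 0 because sequence 1 starts at row 5, mid-tile, which is exactly why the kernel grid is padded to tiles_m + num_seqs - 1 slots.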