[Bugfix]: Fix paged attention unit tests of https://github.com/ROCm/vllm/pull/372 #389

Merged
4 commits merged on Jan 28, 2025
csrc/rocm/attention.cu (4 changes: 3 additions & 1 deletion)
@@ -701,13 +701,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(

__syncthreads();

// disable rtz conversion due to its impact on accuracy.
constexpr bool LOGITS_RTZ_CONVERSION = false;

// write logits to shared mem
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] *= inv_sum_scale;
if constexpr (LOGITS_RTZ_CONVERSION) {
// use rtz conversion for performance, with no visible impact on accuracy
// use rtz conversion for better performance, with negligible impact on
// accuracy.
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4_rtz<scalar_t>(dout[token_depth]);
} else {
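Note on the hunk above: "rtz" refers to round-toward-zero conversion of the fp32 logits (from_floatx4_rtz), as opposed to the default round-to-nearest. A rough Python illustration of the rounding difference, not vLLM code, assuming values stay in fp16's normal range:

```python
import torch

def to_half_rtz(x: torch.Tensor) -> torch.Tensor:
    """fp32 -> fp16 with round-toward-zero, done by truncating mantissa bits.

    Sketch only: assumes inputs are in fp16's normal range (no subnormal or
    overflow handling), which is enough to show the rounding difference.
    """
    bits = x.to(torch.float32).contiguous().view(torch.int32)
    truncated = bits & ~((1 << 13) - 1)  # fp32 keeps 23 mantissa bits, fp16 keeps 10
    return truncated.view(torch.float32).to(torch.float16)

x = torch.rand(8, dtype=torch.float32)
print(x.to(torch.float16))  # default conversion: round to nearest even
print(to_half_rtz(x))       # truncation: always rounds toward zero
```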
tests/kernels/test_attention.py (46 changes: 26 additions & 20 deletions)
@@ -7,7 +7,7 @@
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes
from vllm.utils import get_max_shared_memory_bytes, is_navi

from .allclose_default import get_default_atol, get_default_rtol

@@ -33,7 +33,7 @@

# This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
@@ -116,7 +116,8 @@ def ref_single_query_cached_kv_attention(


@pytest.mark.parametrize(
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"])
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -181,7 +182,11 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]

# Using default kv_scale
k_scale = v_scale = torch.tensor(0.3, dtype=torch.float)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32)

# additional argument for v1/v2 pa kernel
num_threads = 1024 if current_platform.is_rocm() \
and not is_navi() else 128

# Call the paged attention kernel.
output = torch.empty_like(query)
@@ -203,12 +208,12 @@
v_scale,
)

opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]))

elif version in ("v2", "rocm"):
if current_platform.is_rocm():
@@ -247,13 +252,14 @@ def test_paged_attention(
v_scale,
)

opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query, key_cache,
value_cache, num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

else:
ops.paged_attention_rocm(
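Note on the opcheck updates above: the v1/v2 kernels now take a trailing num_threads argument, so the positional tuples passed to opcheck have to include it or the schema check fails. A rough sketch of what the tests.kernels.utils helper is assumed to do (the real helper may differ):

```python
import torch

def opcheck(op, args, *, cond: bool = True):
    """Rough sketch of the behavior assumed for the tests.kernels.utils helper."""
    if not cond:
        return
    # torch.library.opcheck wants a concrete overload, not the packet.
    overload = op.default if isinstance(op, torch._ops.OpOverloadPacket) else op
    # Re-invokes the op under several test modes and validates the call
    # against the registered schema, so the argument tuple must include
    # every positional parameter (now including the trailing num_threads).
    torch.library.opcheck(overload, args)
```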
@@ -299,14 +305,14 @@ def test_paged_attention(
dtype=dtype,
device=device)
ops.convert_fp8(dequantized_key_cache, key_cache)
key_cache = k_scale * dequantized_key_cache
key_cache = dequantized_key_cache

value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(dequantized_value_cache, value_cache)
value_cache = v_scale * dequantized_value_cache
value_cache = dequantized_value_cache

ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
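Note on the fp8 path above: with the default k_scale/v_scale changed to 1.0, multiplying the dequantized caches by the scale would be a no-op, so the reference computation uses the dequantized values directly. For context, a per-tensor scaled-fp8 convention looks roughly like this (a sketch, not vLLM's actual convert_fp8 API):

```python
import torch

# Rough sketch of a per-tensor scaled-fp8 round trip (assumes
# torch.float8_e4m3fn is available; vLLM's ops.convert_fp8 and the real
# cache layout are more involved than this).
def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (x / scale).to(torch.float8_e4m3fn)

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor,
                   dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Dequantization multiplies back by the scale; with scale == 1.0 the
    # multiply is a no-op, which is why the reference path above can use
    # the dequantized caches directly.
    return q.to(dtype) * scale
```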
@@ -434,4 +440,4 @@ def test_multi_query_kv_attention(
)
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)