Skip to content

Commit

Permalink
register custom op for flash attn and use from torch.ops (vllm-projec…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and fialhocoelho committed Aug 22, 2024
1 parent 4d3843b commit 93678a7
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 41 deletions.
7 changes: 7 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py


- label: Vision Language Models Test # 42min
mirror_hardwares: [amd]
source_file_dependencies:
Expand Down
20 changes: 20 additions & 0 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import pytest


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
# make sure these models can be captured in full graph mode
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
llm.generate(prompts, sampling_params)
73 changes: 61 additions & 12 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
import vllm.attention.backends.flash_attn # noqa: F401

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]


def ref_paged_attn(
Expand Down Expand Up @@ -72,6 +75,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
Expand All @@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
Expand All @@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv(
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(NUM_BLOCKS,
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
Expand All @@ -101,21 +106,40 @@ def test_flash_attn_with_paged_kv(

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
Expand All @@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
Expand All @@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
Expand All @@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(NUM_BLOCKS,
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
Expand All @@ -181,11 +207,11 @@ def test_varlen_with_paged_kv(

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
Expand Down
155 changes: 129 additions & 26 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
Expand All @@ -18,6 +17,108 @@
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder

from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache


@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# custom op does not support tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return _flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=real_window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
block_table=block_table,
)


@flash_attn_varlen_func.register_fake # type: ignore
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(q)


@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return _flash_attn_with_kvcache(
decode_query,
key_cache,
value_cache,
cache_seqlens=cache_seqlens,
block_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
alibi_slopes=alibi_slopes,
softcap=softcap,
)


@flash_attn_with_kvcache.register_fake # type: ignore
def _(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return torch.empty_like(decode_query)


class FlashAttentionBackend(AttentionBackend):

Expand Down Expand Up @@ -517,7 +618,7 @@ def forward(
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
out = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -537,34 +638,36 @@ def forward(
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
output[:
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
).squeeze(1)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
Loading

0 comments on commit 93678a7

Please sign in to comment.