register custom op for flash attn and use from torch.ops #7536

Merged: 18 commits, Aug 16, 2024
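For context (not part of the diff below): the pattern named in the PR title is PyTorch's torch.library.custom_op decorator plus a register_fake (meta) implementation, after which the wrapped kernel is callable as torch.ops.<namespace>.<op>. A minimal, CPU-only sketch using a hypothetical op mylib::scale_rows (names are illustrative, not vLLM's):

import torch


@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Eager implementation: this is where the real kernel would be invoked.
    return x * scale


@scale_rows.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake implementation: only shape/dtype/device matter, so tracing
    # (e.g. by torch.compile) never has to execute the real kernel.
    return torch.empty_like(x)


out = torch.ops.mylib.scale_rows(torch.randn(4, 8), 2.0)
print(out.shape)  # torch.Size([4, 8])

The diff applies the same pattern to vllm_flash_attn's flash_attn_varlen_func and flash_attn_with_kvcache, adds opcheck-based kernel tests, and adds a full-graph capture test.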
7 changes: 7 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -162,6 +162,13 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py


- label: Vision Language Models Test # 42min
mirror_hardwares: [amd]
source_file_dependencies:
20 changes: 20 additions & 0 deletions tests/compile/test_full_graph.py
@@ -0,0 +1,20 @@
import os

import pytest


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
# make sure these models can be captured in full graph mode
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=model)
llm.generate(prompts, sampling_params)
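For context (not part of the diff): a rough, CPU-only sketch of what full-graph capture means for a registered custom op. The hypothetical mylib::double op stands in for the flash-attention wrappers; fullgraph=True makes Dynamo raise if the op causes a graph break. How VLLM_TEST_DYNAMO_GRAPH_CAPTURE is consumed inside vLLM is not shown in this diff.

import torch


@torch.library.custom_op("mylib::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@double.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


def fwd(x: torch.Tensor) -> torch.Tensor:
    # A registered custom op with a fake impl is traceable, so no graph break.
    return torch.ops.mylib.double(x) + 1


compiled = torch.compile(fwd, fullgraph=True)
print(compiled(torch.randn(4)))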
73 changes: 61 additions & 12 deletions tests/kernels/test_flash_attn.py
@@ -2,13 +2,16 @@

import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
import vllm.attention.backends.flash_attn # noqa: F401

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]


def ref_paged_attn(
@@ -72,6 +75,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -91,7 +96,7 @@
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(NUM_BLOCKS,
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
@@ -101,21 +106,40 @@

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(NUM_BLOCKS,
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
@@ -181,11 +207,11 @@

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
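For context (not part of the diff): torch.library.opcheck takes the target op, positional args, keyword args, and a test_utils list, so individual checks such as "test_schema" can be skipped for very large inputs, as the tests above do. A minimal CPU-only sketch on a hypothetical op, mirroring the kwargs-style calls in the tests:

import torch


@torch.library.custom_op("mylib::add_bias", mutates_args=())
def add_bias(x: torch.Tensor, bias: float) -> torch.Tensor:
    return x + bias


@add_bias.register_fake
def _(x: torch.Tensor, bias: float) -> torch.Tensor:
    return torch.empty_like(x)


x = torch.randn(8, 8)
# "test_schema" re-runs the op under extra mutation/aliasing checking, which is
# presumably why the tests above skip it for the very large KV cache size.
test_utils = (["test_faketensor", "test_schema"]
              if x.numel() <= 1024 else ["test_faketensor"])
torch.library.opcheck(torch.ops.mylib.add_bias,
                      args=tuple(),
                      kwargs=dict(x=x, bias=1.0),
                      test_utils=test_utils)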
155 changes: 129 additions & 26 deletions vllm/attention/backends/flash_attn.py
@@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,6 +17,108 @@
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder

from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache


@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
[Review comment from the author: confirmed with @WoosukKwon, these two functions do not mutate the input.]

def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# custom op does not support tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return _flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=real_window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
block_table=block_table,
)


@flash_attn_varlen_func.register_fake # type: ignore
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(q)


@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return _flash_attn_with_kvcache(
decode_query,
key_cache,
value_cache,
cache_seqlens=cache_seqlens,
block_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
alibi_slopes=alibi_slopes,
softcap=softcap,
)


@flash_attn_with_kvcache.register_fake # type: ignore
def _(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return torch.empty_like(decode_query)


class FlashAttentionBackend(AttentionBackend):

@@ -517,7 +618,7 @@ def forward(
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
out = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -537,34 +638,36 @@ def forward(
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
            output[:num_prefill_tokens] = flash_attn_varlen_func(
                q=query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_k=max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                block_table=prefill_meta.block_tables,
                softcap=self.logits_soft_cap,
            )
            output[:
                   num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
                       q=query,
                       k=key_cache,
                       v=value_cache,
                       cu_seqlens_q=prefill_meta.query_start_loc,
                       max_seqlen_q=prefill_meta.max_query_len,
                       cu_seqlens_k=prefill_meta.seq_start_loc,
                       max_seqlen_k=max_seq_len,
                       softmax_scale=self.scale,
                       causal=True,
                       alibi_slopes=self.alibi_slopes,
                       block_table=prefill_meta.block_tables,
                       softcap=self.logits_soft_cap,
                   )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                decode_query.unsqueeze(1),
                key_cache,
                value_cache,
                block_table=decode_meta.block_tables,
                cache_seqlens=decode_meta.seq_lens_tensor,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                softcap=self.logits_soft_cap,
            ).squeeze(1)
            output[
                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
                    decode_query.unsqueeze(1),
                    key_cache,
                    value_cache,
                    block_table=decode_meta.block_tables,
                    cache_seqlens=decode_meta.seq_lens_tensor,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    softcap=self.logits_soft_cap,
                ).squeeze(1)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
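For reference (not part of the diff): once vllm.attention.backends.flash_attn has been imported, the wrappers are reachable through the torch.ops namespace, as in the tests above. A rough usage sketch, assuming a CUDA device, an installed vllm_flash_attn, and made-up toy shapes:

import torch
import vllm.attention.backends.flash_attn  # noqa: F401  (registers the custom ops)

num_seqs, num_heads, head_size, block_size, num_blocks = 2, 8, 128, 16, 64
query = torch.randn(num_seqs, 1, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key_cache = torch.randn(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
block_tables = torch.arange(num_seqs * 4, dtype=torch.int32,
                            device="cuda").reshape(num_seqs, 4)
cache_seqlens = torch.tensor([5, 9], dtype=torch.int32, device="cuda")

out = torch.ops.vllm.flash_attn_with_kvcache(
    decode_query=query,
    key_cache=key_cache,
    value_cache=value_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_tables,
    softmax_scale=head_size**-0.5,
    causal=True,
)
print(out.shape)  # torch.Size([2, 1, 8, 128])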