Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register custom op for flash attn and use from torch.ops #7536

Merged
merged 18 commits into from
Aug 16, 2024
7 changes: 7 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py


- label: Vision Language Models Test # 42min
mirror_hardwares: [amd]
source_file_dependencies:
Expand Down
20 changes: 20 additions & 0 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import pytest


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
# make sure these models can be captured in full graph mode
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
llm.generate(prompts, sampling_params)
45 changes: 39 additions & 6 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

import vllm.attention.backends.flash_attn # noqa: F401

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
Expand Down Expand Up @@ -105,17 +106,31 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
output = torch.ops.vllm.flash_attn_with_kvcache(
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)

torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=("test_faketensor", ))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can add test_schema later, after solving the OOM issue.

currently, I will get OOM when I test the schema, even though I'm using H100 80GB.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think part of testing the schema involves copying all the inputs and doubling checking that they are (or aren't) mutated in agreement with the op schema. I don't know if that's what is causing the OOMs here tho.


ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
Expand Down Expand Up @@ -185,7 +200,7 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -200,6 +215,24 @@ def test_varlen_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
)

torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=("test_faketensor", ))

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
Expand Down
155 changes: 129 additions & 26 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
Expand All @@ -18,6 +17,108 @@
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder

from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache


@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirmed with @WoosukKwon , these two functions do not mutate the input.

def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# custom op does not support tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return _flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=real_window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
block_table=block_table,
)


@flash_attn_varlen_func.register_fake # type: ignore
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(q)


@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return _flash_attn_with_kvcache(
decode_query,
key_cache,
value_cache,
cache_seqlens=cache_seqlens,
block_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
alibi_slopes=alibi_slopes,
softcap=softcap,
)


@flash_attn_with_kvcache.register_fake # type: ignore
def _(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return torch.empty_like(decode_query)


class FlashAttentionBackend(AttentionBackend):

Expand Down Expand Up @@ -517,7 +618,7 @@ def forward(
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
out = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -537,34 +638,36 @@ def forward(
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
output[:
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
).squeeze(1)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
6 changes: 3 additions & 3 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

import vllm.attention.backends.flash_attn # noqa
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None

Expand Down Expand Up @@ -520,7 +520,7 @@ def forward(
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand Down
Loading