
register custom op for flash attn and use from torch.ops #7536

Merged: 18 commits, Aug 16, 2024
7 changes: 7 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -162,6 +162,13 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py


- label: Vision Language Models Test # 42min
mirror_hardwares: [amd]
source_file_dependencies:
20 changes: 20 additions & 0 deletions tests/compile/test_full_graph.py
@@ -0,0 +1,20 @@
import os

import pytest


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
# make sure these models can be captured in full graph mode
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model)
llm.generate(prompts, sampling_params)
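For local debugging, the same capture path can be exercised outside pytest. A minimal sketch of the intended usage, assuming (as the test does) that the environment variable has to be set before vLLM builds the model:

```python
import os

# Enable full-graph Dynamo capture before the model is constructed.
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)
```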
13 changes: 7 additions & 6 deletions tests/kernels/test_flash_attn.py
@@ -2,7 +2,8 @@

import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

import vllm.attention.backends.flash_attn # noqa: F401

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
@@ -105,10 +106,10 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
@@ -185,7 +186,7 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
176 changes: 150 additions & 26 deletions vllm/attention/backends/flash_attn.py
@@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,6 +17,129 @@
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder

from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache


@torch.library.custom_op("vllm::flash_attn_varlen_func",
mutates_args=[
"q", "k", "v", "cu_seqlens_q", "cu_seqlens_k",
"block_table"
])
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[List[float]] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# custom op does not support tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return _flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=real_window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
block_table=block_table,
)


@flash_attn_varlen_func.register_fake # type: ignore
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[List[float]] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# NOTE: shape can be incorrect.
# just annotate the shape to pass Dynamo
return torch.empty(q.shape,
dtype=q.dtype,
layout=q.layout,
device=q.device)


@torch.library.custom_op("vllm::flash_attn_with_kvcache",
mutates_args=[
"decode_query",
"key_cache",
"value_cache",
"cache_seqlens",
"block_table",
])
def flash_attn_with_kvcache(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[List[int]] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return _flash_attn_with_kvcache(
decode_query,
key_cache,
value_cache,
cache_seqlens=cache_seqlens,
block_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
alibi_slopes=alibi_slopes,
softcap=softcap,
)


@flash_attn_with_kvcache.register_fake # type: ignore
def _(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[List[int]] = None,
softcap: float = 0.0,
) -> torch.Tensor:
# NOTE: shape can be incorrect.
# just annotate the shape to pass Dynamo
return torch.empty(decode_query.shape,
dtype=decode_query.dtype,
layout=decode_query.layout,
device=decode_query.device)


class FlashAttentionBackend(AttentionBackend):

@@ -517,7 +639,7 @@ def forward(
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
out = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
@@ -537,34 +659,36 @@
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func(
    q=query,
    k=key_cache,
    v=value_cache,
    cu_seqlens_q=prefill_meta.query_start_loc,
    max_seqlen_q=prefill_meta.max_query_len,
    cu_seqlens_k=prefill_meta.seq_start_loc,
    max_seqlen_k=max_seq_len,
    softmax_scale=self.scale,
    causal=True,
    alibi_slopes=self.alibi_slopes,
    block_table=prefill_meta.block_tables,
    softcap=self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
    # Decoding run.
    output[num_prefill_tokens:] = flash_attn_with_kvcache(
        decode_query.unsqueeze(1),
        key_cache,
        value_cache,
        block_table=decode_meta.block_tables,
        cache_seqlens=decode_meta.seq_lens_tensor,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        softcap=self.logits_soft_cap,
    ).squeeze(1)

output[:
       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
           q=query,
           k=key_cache,
           v=value_cache,
           cu_seqlens_q=prefill_meta.query_start_loc,
           max_seqlen_q=prefill_meta.max_query_len,
           cu_seqlens_k=prefill_meta.seq_start_loc,
           max_seqlen_k=max_seq_len,
           softmax_scale=self.scale,
           causal=True,
           alibi_slopes=self.alibi_slopes,
           block_table=prefill_meta.block_tables,
           softcap=self.logits_soft_cap,
       )

if decode_meta := attn_metadata.decode_metadata:
    # Decoding run.
    output[
        num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
            decode_query.unsqueeze(1),
            key_cache,
            value_cache,
            block_table=decode_meta.block_tables,
            cache_seqlens=decode_meta.seq_lens_tensor,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            softcap=self.logits_soft_cap,
        ).squeeze(1)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
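The registrations above are the heart of the change: the opaque vllm_flash_attn kernels are wrapped with torch.library.custom_op so that torch.compile/Dynamo treats each call as a single graph node, and register_fake supplies a shape-only implementation so tracing can proceed without running the real kernel. A minimal, self-contained sketch of the same pattern (the demo namespace, op name, and matmul body are illustrative, not part of vLLM):

```python
from typing import Optional

import torch


@torch.library.custom_op("demo::scaled_matmul", mutates_args=[])
def scaled_matmul(a: torch.Tensor,
                  b: torch.Tensor,
                  scale: Optional[float] = None) -> torch.Tensor:
    # Real implementation; Dynamo does not trace into this body.
    out = a @ b
    return out * scale if scale is not None else out


@scaled_matmul.register_fake
def _(a: torch.Tensor,
      b: torch.Tensor,
      scale: Optional[float] = None) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing/compiling.
    return torch.empty((a.shape[0], b.shape[1]),
                       dtype=a.dtype,
                       device=a.device)


@torch.compile(fullgraph=True)
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Calling through torch.ops keeps the op as one opaque node in the graph.
    return torch.ops.demo.scaled_matmul(x, y, scale=0.5)


print(f(torch.randn(4, 8), torch.randn(8, 2)).shape)  # torch.Size([4, 2])
```

In the registrations above, the tensor arguments are listed in mutates_args, which tells the compiler the kernels may modify them in place and keeps scheduling around the opaque ops conservative.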
6 changes: 3 additions & 3 deletions vllm/attention/backends/flashinfer.py
@@ -4,9 +4,9 @@
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

import vllm.attention.backends.flash_attn # noqa
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None

@@ -520,7 +520,7 @@ def forward(
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
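On the flashinfer side, the kernels are no longer imported from vllm_flash_attn directly; importing vllm.attention.backends.flash_attn runs the registrations as a side effect, after which the ops resolve through torch.ops. A quick sanity check of that assumption (requires vllm and vllm_flash_attn to be installed):

```python
import torch

import vllm.attention.backends.flash_attn  # noqa: F401  (registers the custom ops)

# Both wrappers are now resolvable through the torch.ops namespace.
print(torch.ops.vllm.flash_attn_varlen_func)
print(torch.ops.vllm.flash_attn_with_kvcache)
```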