From 8649a5b34506d8a51756bc0dba56fe4b2cfc20a7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 16:48:46 -0700 Subject: [PATCH 01/18] register --- vllm/attention/backends/flash_attn.py | 115 +++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 160bf2307fbf5..5a41349ac1bb1 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -18,6 +17,120 @@ if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder +from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func +from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache + + +@torch.library.custom_op("vllm::flash_attn_varlen_func", + mutates_args="unknown") +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Optional[List[int]] = None, + softcap: float = 0.0, + alibi_slopes: Optional[List[float]] = None, + block_table: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # custom op does not support tuple input + real_window_size: Tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + return _flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=causal, + window_size=real_window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + block_table=block_table, + ) + + +@flash_attn_varlen_func.register_fake # type: ignore +def _( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Optional[List[int]] = None, + softcap: float = 0.0, + alibi_slopes: Optional[List[float]] = None, + block_table: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # NOTE: shape can be incorrect. 
+ # just annotate the shape to pass Dynamo + return torch.empty(q.shape, + dtype=q.dtype, + layout=q.layout, + device=q.device) + + +@torch.library.custom_op("vllm::flash_attn_with_kvcache", + mutates_args="unknown") +def flash_attn_with_kvcache( + decode_query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cache_seqlens: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + alibi_slopes: Optional[List[int]] = None, + softcap: float = 0.0, +) -> torch.Tensor: + return _flash_attn_with_kvcache( + decode_query, + key_cache, + value_cache, + cache_seqlens=cache_seqlens, + block_table=block_table, + softmax_scale=softmax_scale, + causal=causal, + alibi_slopes=alibi_slopes, + softcap=softcap, + ) + + +@flash_attn_with_kvcache.register_fake # type: ignore +def _( + decode_query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cache_seqlens: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + alibi_slopes: Optional[List[int]] = None, + softcap: float = 0.0, +) -> torch.Tensor: + # NOTE: shape can be incorrect. + # just annotate the shape to pass Dynamo + return torch.empty(decode_query.shape, + dtype=decode_query.dtype, + layout=decode_query.layout, + device=decode_query.device) + class FlashAttentionBackend(AttentionBackend): From c2c8ca6d7bc1e1d8077f809e29eb61e1a06a3a04 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 16:50:33 -0700 Subject: [PATCH 02/18] use --- vllm/attention/backends/flash_attn.py | 52 ++++++++++++++------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5a41349ac1bb1..08b041938756f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -630,7 +630,7 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - out = flash_attn_varlen_func( + out = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, v=value, @@ -650,34 +650,36 @@ def forward( # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, + output[: + num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=self.logits_soft_cap, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. 
+ output[ + num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) From 679b18acb49bfd95fe1b31a0cd891442a73f4e2f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 16:53:58 -0700 Subject: [PATCH 03/18] use --- tests/kernels/test_flash_attn.py | 6 +++--- vllm/attention/backends/flashinfer.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 374ba0afb5f41..ec6abcb87eec1 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -2,7 +2,7 @@ import pytest import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +import vllm_flash_attn # noqa: F401 NUM_HEADS = [(16, 16), (32, 8), (64, 8)] HEAD_SIZES = [128, 256] @@ -105,7 +105,7 @@ def test_flash_attn_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = flash_attn_with_kvcache( + output = torch.ops.vllm.flash_attn_with_kvcache( q=query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, @@ -185,7 +185,7 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = flash_attn_varlen_func( + output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index fad873b448a34..5e5c7038c0c2b 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -2,11 +2,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: + import vllm_flash_attn # noqa from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - from vllm_flash_attn import flash_attn_varlen_func except ImportError: - flash_attn_varlen_func = None BatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None @@ -520,7 +519,7 @@ def forward( # This happens when vllm runs the profiling to # determine the number of blocks. 
if kv_cache is None: - output = flash_attn_varlen_func( + output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, v=value, From 94a39ccbfc58b6c943b3cb9cd1a2e090a9cad6e9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:01:06 -0700 Subject: [PATCH 04/18] manually mutate all --- vllm/attention/backends/flash_attn.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 08b041938756f..20680c2aae98c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -22,7 +22,12 @@ @torch.library.custom_op("vllm::flash_attn_varlen_func", - mutates_args="unknown") + mutates_args=[ + "q", "k", "v", "cu_seqlens_q", "cu_seqlens_k", + "max_seqlen_q", "max_seqlen_k", "softmax_scale", + "causal", "window_size", "softcap", + "alibi_slopes", "block_table" + ]) def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, @@ -87,7 +92,11 @@ def _( @torch.library.custom_op("vllm::flash_attn_with_kvcache", - mutates_args="unknown") + mutates_args=[ + "decode_query", "key_cache", "value_cache", + "cache_seqlens", "block_table", "softmax_scale", + "causal", "alibi_slopes", "softcap" + ]) def flash_attn_with_kvcache( decode_query: torch.Tensor, key_cache: torch.Tensor, From bdbbe763050a66e9e9bb7136af0d9d401e209c59 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:02:59 -0700 Subject: [PATCH 05/18] manually mutate all tensors --- vllm/attention/backends/flash_attn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 20680c2aae98c..7906925a16530 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -24,9 +24,7 @@ @torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[ "q", "k", "v", "cu_seqlens_q", "cu_seqlens_k", - "max_seqlen_q", "max_seqlen_k", "softmax_scale", - "causal", "window_size", "softcap", - "alibi_slopes", "block_table" + "block_table" ]) def flash_attn_varlen_func( q: torch.Tensor, @@ -93,9 +91,11 @@ def _( @torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[ - "decode_query", "key_cache", "value_cache", - "cache_seqlens", "block_table", "softmax_scale", - "causal", "alibi_slopes", "softcap" + "decode_query", + "key_cache", + "value_cache", + "cache_seqlens", + "block_table", ]) def flash_attn_with_kvcache( decode_query: torch.Tensor, From 5b64f2da72bc72eb24f652f8ca988b6a2d594ae7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:13:09 -0700 Subject: [PATCH 06/18] add tests --- tests/compile/test_full_graph.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/compile/test_full_graph.py diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py new file mode 100644 index 0000000000000..d5b59db8c7887 --- /dev/null +++ b/tests/compile/test_full_graph.py @@ -0,0 +1,20 @@ +import os + +import pytest + + +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_full_graph(model): + # make sure these models can be captured in full graph mode + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + + from vllm import LLM, SamplingParams + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + llm = 
LLM(model="meta-llama/Meta-Llama-3-8B") + llm.generate(prompts, sampling_params) From 9d97f7b0f6c1a676c3a12bd94105d49f71cc318e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:15:49 -0700 Subject: [PATCH 07/18] add tests --- .buildkite/test-pipeline.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 71afb1aa52883..116761aca7830 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -162,6 +162,13 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py +- label: torch.compile integration test + source_file_dependencies: + - vllm/ + commands: + - pytest -v -s ./compile/test_full_graph.py + + - label: Vision Language Models Test # 42min mirror_hardwares: [amd] source_file_dependencies: From 506eed5a19f543bb5f46668537d447096d073bbb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:20:48 -0700 Subject: [PATCH 08/18] change import --- tests/kernels/test_flash_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index ec6abcb87eec1..66ac4b6c72e72 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -2,7 +2,8 @@ import pytest import torch -import vllm_flash_attn # noqa: F401 + +import vllm.attention.backends.flash_attn # noqa: F401 NUM_HEADS = [(16, 16), (32, 8), (64, 8)] HEAD_SIZES = [128, 256] From d9105aa3d6fdc52d634ff417072a66944b9da293 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:21:23 -0700 Subject: [PATCH 09/18] update tests --- tests/kernels/test_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 66ac4b6c72e72..2dc485de7ed72 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -107,7 +107,7 @@ def test_flash_attn_with_paged_kv( dtype=torch.int32) output = torch.ops.vllm.flash_attn_with_kvcache( - q=query.unsqueeze(1), + query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, softmax_scale=scale, From f827ad3c0a350a4e30c44f2d8125622408912b2b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:23:40 -0700 Subject: [PATCH 10/18] change args --- tests/kernels/test_flash_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 2dc485de7ed72..e9ae8529cd659 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -107,9 +107,9 @@ def test_flash_attn_with_paged_kv( dtype=torch.int32) output = torch.ops.vllm.flash_attn_with_kvcache( - query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, softmax_scale=scale, causal=True, block_table=block_tables, From f0fe2885e17049ccf4986126bdb1392a6e8484d5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:25:40 -0700 Subject: [PATCH 11/18] change import --- vllm/attention/backends/flashinfer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5e5c7038c0c2b..3022fa70e2ca7 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -2,9 +2,10 @@ 
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: - import vllm_flash_attn # noqa from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + + import vllm.attention.backends.flash_attn # noqa except ImportError: BatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None From 8c322b04be71cef6a1c98354736e8d9e0266f584 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 14 Aug 2024 17:38:47 -0700 Subject: [PATCH 12/18] rename --- .buildkite/test-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 116761aca7830..749d9702a848f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -162,7 +162,7 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py -- label: torch.compile integration test +- label: torch compile integration test source_file_dependencies: - vllm/ commands: From 755dbaf656f9bb4b6af63dd1882561d1f6949493 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 15:23:19 -0700 Subject: [PATCH 13/18] fix register fake --- vllm/attention/backends/flash_attn.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 7906925a16530..5c148c76162b9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -81,12 +81,7 @@ def _( alibi_slopes: Optional[List[float]] = None, block_table: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # NOTE: shape can be incorrect. - # just annotate the shape to pass Dynamo - return torch.empty(q.shape, - dtype=q.dtype, - layout=q.layout, - device=q.device) + return torch.empty_like(q) @torch.library.custom_op("vllm::flash_attn_with_kvcache", @@ -133,12 +128,7 @@ def _( alibi_slopes: Optional[List[int]] = None, softcap: float = 0.0, ) -> torch.Tensor: - # NOTE: shape can be incorrect. 
- # just annotate the shape to pass Dynamo - return torch.empty(decode_query.shape, - dtype=decode_query.dtype, - layout=decode_query.layout, - device=decode_query.device) + return torch.empty_like(decode_query) class FlashAttentionBackend(AttentionBackend): From fc2a4c2898159885c21441217deb2373a9776512 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 15:40:25 -0700 Subject: [PATCH 14/18] add opcheck --- tests/kernels/test_flash_attn.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index e9ae8529cd659..c5b4ef698b0ba 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -117,6 +117,20 @@ def test_flash_attn_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ).squeeze(1) + torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache, + args=tuple(), + kwargs=dict( + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=("test_faketensor", )) + ref_output = ref_paged_attn( query=query, key_cache=key_cache, @@ -201,6 +215,24 @@ def test_varlen_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ) + torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func, + args=tuple(), + kwargs=dict( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=("test_faketensor", )) + ref_output = ref_paged_attn( query=query, key_cache=key_cache, From 495d2f02703d012f0d8e3047d191c2b9e8c9a521 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 16:37:02 -0700 Subject: [PATCH 15/18] fix alibi_slopes --- vllm/attention/backends/flash_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5c148c76162b9..3fa94808781e9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -38,7 +38,7 @@ def flash_attn_varlen_func( causal: bool = False, window_size: Optional[List[int]] = None, softcap: float = 0.0, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, ) -> torch.Tensor: # custom op does not support tuple input @@ -78,7 +78,7 @@ def _( causal: bool = False, window_size: Optional[List[int]] = None, softcap: float = 0.0, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.empty_like(q) @@ -100,7 +100,7 @@ def flash_attn_with_kvcache( block_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, - alibi_slopes: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, softcap: float = 0.0, ) -> torch.Tensor: return _flash_attn_with_kvcache( @@ -125,7 +125,7 @@ def _( block_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, - alibi_slopes: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, softcap: float 
= 0.0, ) -> torch.Tensor: return torch.empty_like(decode_query) From 76c5cecf4577cb10572fc699cc2d60585dc603d1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 17:11:22 -0700 Subject: [PATCH 16/18] update mutates_args --- vllm/attention/backends/flash_attn.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 3fa94808781e9..f230bb57e3177 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -21,11 +21,7 @@ from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache -@torch.library.custom_op("vllm::flash_attn_varlen_func", - mutates_args=[ - "q", "k", "v", "cu_seqlens_q", "cu_seqlens_k", - "block_table" - ]) +@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, @@ -84,14 +80,7 @@ def _( return torch.empty_like(q) -@torch.library.custom_op("vllm::flash_attn_with_kvcache", - mutates_args=[ - "decode_query", - "key_cache", - "value_cache", - "cache_seqlens", - "block_table", - ]) +@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[]) def flash_attn_with_kvcache( decode_query: torch.Tensor, key_cache: torch.Tensor, From 45bb13179bb46a522c22d703f24a23e4abd349ce Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 21:46:34 -0700 Subject: [PATCH 17/18] add schema tests --- tests/kernels/test_flash_attn.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index c5b4ef698b0ba..df32b8c19a516 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -9,7 +9,9 @@ HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] DTYPES = [torch.float16, torch.bfloat16] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. +# one value large enough to test overflow in index calculation. 
+# one value small enough to test the schema op check +NUM_BLOCKS = [32768, 2048] def ref_paged_attn( @@ -73,6 +75,7 @@ def ref_paged_attn( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @torch.inference_mode() def test_flash_attn_with_paged_kv( kv_lens: List[int], @@ -81,6 +84,7 @@ def test_flash_attn_with_paged_kv( dtype: torch.dtype, block_size: int, soft_cap: Optional[float], + num_blocks: int, ) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -92,7 +96,7 @@ def test_flash_attn_with_paged_kv( scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(NUM_BLOCKS, + key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size, @@ -102,7 +106,7 @@ def test_flash_attn_with_paged_kv( max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size block_tables = torch.randint(0, - NUM_BLOCKS, + num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) @@ -117,6 +121,11 @@ def test_flash_attn_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ).squeeze(1) + if num_blocks <= 2048: + test_utils = ["test_faketensor", "test_schema"] + else: + test_utils = ["test_faketensor"] + torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache, args=tuple(), kwargs=dict( @@ -129,7 +138,7 @@ def test_flash_attn_with_paged_kv( cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, ), - test_utils=("test_faketensor", )) + test_utils=test_utils) ref_output = ref_paged_attn( query=query, @@ -152,6 +161,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @torch.inference_mode() def test_varlen_with_paged_kv( seq_lens: List[Tuple[int, int]], @@ -161,6 +171,7 @@ def test_varlen_with_paged_kv( dtype: torch.dtype, block_size: int, soft_cap: Optional[float], + num_blocks: int, ) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -181,7 +192,7 @@ def test_varlen_with_paged_kv( num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(NUM_BLOCKS, + key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size, @@ -196,7 +207,7 @@ def test_varlen_with_paged_kv( max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size block_tables = torch.randint(0, - NUM_BLOCKS, + num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) @@ -215,6 +226,11 @@ def test_varlen_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ) + if num_blocks <= 2048: + test_utils = ["test_faketensor", "test_schema"] + else: + test_utils = ["test_faketensor"] + torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func, args=tuple(), kwargs=dict( @@ -231,7 +247,7 @@ def test_varlen_with_paged_kv( block_table=block_tables, softcap=soft_cap if soft_cap is not None else 0, ), - test_utils=("test_faketensor", )) + test_utils=test_utils) ref_output = ref_paged_attn( query=query, From ee8d42665fe76be12a818283df28e41a10538a49 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 21:46:58 -0700 Subject: [PATCH 18/18] reduce number of heads to avoid OOM --- tests/kernels/test_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index df32b8c19a516..9a3597e5d6d9d 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -5,7 +5,7 @@
 
 import vllm.attention.backends.flash_attn # noqa: F401
 
-NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
+NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
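
Taken together, patches 01-18 follow the torch.library pattern for making an opaque external kernel visible to torch.compile: wrap the kernel in torch.library.custom_op, attach a register_fake (meta) implementation so Dynamo and FakeTensor can propagate shapes without launching the kernel, and exercise the registration with torch.library.opcheck. The sketch below is a minimal, self-contained illustration of that pattern with a toy elementwise kernel, assuming PyTorch >= 2.4 (where custom_op and opcheck are available); the myops::scaled_add name, the _scaled_add_impl helper, and the fn wrapper are invented for illustration and are not part of vLLM or vllm_flash_attn.

import torch


def _scaled_add_impl(x: torch.Tensor, y: torch.Tensor,
                     scale: float) -> torch.Tensor:
    # The "real" kernel; stands in for the vllm_flash_attn entry points.
    return x + scale * y


@torch.library.custom_op("myops::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor,
               scale: float = 1.0) -> torch.Tensor:
    return _scaled_add_impl(x, y, scale)


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Fake (meta) implementation: only shape/dtype/device matter for tracing.
    return torch.empty_like(x)


if __name__ == "__main__":
    a = torch.randn(4)
    b = torch.randn(4)

    # Schema and fake-tensor checks, mirroring the opcheck calls added in
    # tests/kernels/test_flash_attn.py.
    torch.library.opcheck(torch.ops.myops.scaled_add, (a, b),
                          kwargs={"scale": 2.0},
                          test_utils=("test_schema", "test_faketensor"))

    # With the op registered, Dynamo can keep it in a full graph instead of
    # graph-breaking on an opaque Python call.
    @torch.compile(fullgraph=True)
    def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.ops.myops.scaled_add(x, y, scale=2.0)

    assert torch.allclose(fn(a, b), a + 2.0 * b)

As patches 04, 05, and 16 show, mutates_args is the subtle part of the registration: the flash-attention wrappers return fresh tensors rather than writing into their arguments, so the final registration declares no mutated arguments, which is what lets the schema check added in patch 17 pass and lets the compiler assume the inputs are left unchanged.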