Commit

Switch to torch CUTLASS implementation (fairinternal/xformers#1121)
* Switch to torch CUTLASS implementation

* Fixed flake8 lint errors

* Better way to check whether a registered operator has a CUDA implementation

* Use get_operator to access the aten operator, and use hasattr

* Added a "-pt" suffix to the operator names when using torch CUTLASS

__original_commit__ = fairinternal/xformers@8c1d0bd
lvaleriu authored and xFormers Bot committed Jun 7, 2024
1 parent a40ca6e commit b3248b3
Showing 3 changed files with 104 additions and 5 deletions.
69 changes: 69 additions & 0 deletions setup.py
@@ -21,6 +21,7 @@

import setuptools
import torch
from torch._C import parse_schema
from torch.utils.cpp_extension import (
    CUDA_HOME,
    BuildExtension,
@@ -142,6 +143,54 @@ def get_hip_version(rocm_dir) -> str:
    return None


def is_pt_cutlass_compatible(force: bool) -> bool:
    compatible = True

    fwd_schema_str = (
        "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, "
        "Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, "
        "SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, "
        "float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> "
        "(Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, "
        "SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)"
    )
    expected_fwd_schema = parse_schema(fwd_schema_str)

    current_schema = torch.ops.aten._efficient_attention_forward.default._schema
    if not current_schema.is_backward_compatible_with(expected_fwd_schema):
        compatible = False

        if force:
            raise ImportError(
                f"Current Torch CUTLASS doesn't have a compatible aten::_efficient_attention_forward schema\n"
                f"EXPECTED:\n{expected_fwd_schema}\n"
                f"but GOT:\n{current_schema}"
            )

    bwd_schema_str = (
        "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, "
        "Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, "
        "SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, "
        "int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, "
        "int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)"
    )

    expected_bwd_schema = parse_schema(bwd_schema_str)

    current_schema = torch.ops.aten._efficient_attention_backward.default._schema
    if not current_schema.is_backward_compatible_with(expected_bwd_schema):
        compatible = False

        if force:
            raise ImportError(
                f"Current Torch CUTLASS doesn't have a compatible aten::_efficient_attention_backward schema\n"
                f"EXPECTED:\n{expected_bwd_schema}\n"
                f"but GOT:\n{current_schema}"
            )

    return compatible
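# Illustration (not part of this diff): what is_backward_compatible_with tolerates,
# using made-up toy schemas. Adding a trailing kwarg-only argument with a default
# value is the kind of change the check above is meant to accept, while renaming
# or removing arguments is not:
#
#   old = parse_schema("myns::toy_op(Tensor x, float p=0.5) -> Tensor")
#   new = parse_schema("myns::toy_op(Tensor x, float p=0.5, *, int? window=None) -> Tensor")
#   assert new.is_backward_compatible_with(old)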


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
    # XXX: Not supported on windows for cuda<12
    # https://github.com/Dao-AILab/flash-attention/issues/345
@@ -245,6 +294,16 @@ def get_extensions():

    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
    fmha_source_cuda = glob.glob(
        os.path.join(extensions_dir, "**", "fmha", "**", "*.cu"), recursive=True
    )
    exclude_files = ["small_k.cu", "decoder.cu", "attention_cutlass_rand_uniform.cu"]
    fmha_source_cuda = [
        c
        for c in fmha_source_cuda
        if not any(exclude_file in c for exclude_file in exclude_files)
    ]

    source_hip = glob.glob(
        os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cpp"),
        recursive=True,
@@ -258,6 +317,16 @@ def get_extensions():
    sources = list(set(sources) - set(source_hip))

    sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")

    xformers_pt_cutlass_attn = os.getenv("XFORMERS_PT_CUTLASS_ATTN")
    # By default, try to link against torch's internal CUTLASS attention
    # implementation, silently falling back to the local CUTLASS attention build
    # if the schemas are not compatible. If the torch path is forced
    # (XFORMERS_PT_CUTLASS_ATTN=1), setup fails when they are not compatible.
    if (
        xformers_pt_cutlass_attn is None or xformers_pt_cutlass_attn == "1"
    ) and is_pt_cutlass_compatible(force=xformers_pt_cutlass_attn == "1"):
        source_cuda = list(set(source_cuda) - set(fmha_source_cuda))
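    # In practice (illustration, not part of this diff):
    # - XFORMERS_PT_CUTLASS_ATTN unset: probe torch's schemas and silently fall back
    #   to building xformers' own CUTLASS attention kernels on a mismatch.
    # - XFORMERS_PT_CUTLASS_ATTN=1: require torch's kernels; the build raises the
    #   ImportError from is_pt_cutlass_compatible on a mismatch.
    # - Any other value (e.g. XFORMERS_PT_CUTLASS_ATTN=0): always build the local
    #   CUTLASS attention kernels.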

    cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
    cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
    if not os.path.exists(cutlass_dir):
16 changes: 16 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -1945,6 +1945,16 @@ def test_has_kernel_for(sm_shmem: Tuple[int, int], dtype_str: str) -> None:
    if sm < 80 and dtype_str == "bf16":
        return

    if not hasattr(torch.ops.xformers, "_has_cutlassF_kernel_for"):
        pytest.skip(
            "xformers doesn't have any _has_cutlassF_kernel_for implementation since it uses torch CUTLASS"
        )

    if not hasattr(torch.ops.xformers, "_has_cutlassB_kernel_for"):
        pytest.skip(
            "xformers doesn't have any _has_cutlassB_kernel_for implementation since it uses torch CUTLASS"
        )

    for k in [16, 32, 64, 128, 256]:
        assert torch.ops.xformers._has_cutlassF_kernel_for(
            dtype, sm, shmem_kbytes * 1024, k
@@ -2394,6 +2404,12 @@ def test_cutlassB_iter_order(
    the same block of dQ
    .. and we test this across variable causal masks+local attention combinations
    """

    if not hasattr(torch.ops.xformers, "_cutlassB_iteration_data"):
        pytest.skip(
            "xformers doesn't have any _cutlassB_iteration_data implementation since it uses torch CUTLASS"
        )

    if (
        window_size > 0
        and custom_mask_type == fmha.cutlass._CustomMaskType.NoCustomMask
24 changes: 19 additions & 5 deletions xformers/ops/fmha/cutlass.py
Expand Up @@ -11,7 +11,7 @@

import torch

from ..common import get_xformers_operator, register_operator
from ..common import get_operator, get_xformers_operator, register_operator
from . import attn_bias
from .attn_bias import (
    AttentionBias,
@@ -157,14 +157,23 @@ def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int
    return int(_CustomMaskType.NoCustomMask)


USE_TORCH_CUTLASS = not torch._C._dispatch_has_kernel_for_dispatch_key(
"xformers::efficient_attention_forward_cutlass", "CUDA"
)
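# Illustration (not part of this diff): the same primitive can probe any registered
# operator for a kernel under a given dispatch key; for example, on a CUDA-enabled
# PyTorch build, torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", "CUDA")
# is expected to return True (operator chosen purely for illustration).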


@register_operator
class FwOp(AttentionFwOpBase):
"""xFormers' MHA kernel based on CUTLASS.
Supports a large number of settings (including without TensorCores, f32 ...)
and GPUs as old as P100 (Sm60)
"""

OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
OPERATOR = (
get_operator("aten", "_efficient_attention_forward")
if USE_TORCH_CUTLASS
else get_xformers_operator("efficient_attention_forward_cutlass")
)
SUPPORTED_DEVICES: Set[str] = {"cuda"}
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
SUPPORTED_MAX_K = 65536
@@ -186,7 +195,7 @@ class FwOp(AttentionFwOpBase):
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = True
    SUPPORTS_BMGHK = True
    NAME = "cutlassF"
    NAME = "cutlassF-pt" if USE_TORCH_CUTLASS else "cutlassF"

    _TEST_K: List[int] = [
        32, # 64x64 kernel
@@ -341,7 +350,12 @@ def operator_flop(
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
    OPERATOR = (
        get_operator("aten", "_efficient_attention_backward")
        if USE_TORCH_CUTLASS
        else get_xformers_operator("efficient_attention_backward_cutlass")
    )

    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
@@ -364,7 +378,7 @@ class BwOp(AttentionBwOpBase):
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    NAME = "cutlassB"
    NAME = "cutlassB-pt" if USE_TORCH_CUTLASS else "cutlassB"

    _TEST_K: List[int] = [
        32, # 64x64 kernel
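To check which backend a given build ended up with, the flag and operator names defined in this file can be inspected at runtime. A minimal sketch, assuming an xformers install where the fmha module imports cleanly:

from xformers.ops import fmha

# USE_TORCH_CUTLASS is set at import time from the dispatch-key probe above.
print(fmha.cutlass.USE_TORCH_CUTLASS)
print(fmha.cutlass.FwOp.NAME)  # "cutlassF-pt" when torch's CUTLASS kernels are used
print(fmha.cutlass.BwOp.NAME)  # "cutlassB-pt" when torch's CUTLASS kernels are used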
