[TL] initial implement flashattention op in TL (#202)
* [TL] initial implement flashattention op in TL
* [TL] initial implement flashattention op in TL & kernel name check
* [BugFix] [TL] modify tilelang import path, add flashatten scheduler testscript
Showing 11 changed files with 769 additions and 4 deletions.
@@ -0,0 +1,205 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas.base.roller.hint import Hint
from tvm.target import Target
from .tilelang import select_scheduler as consistent_scheduler
from ..base_scheduler import BaseScheduler
from ..operator import OperatorConfig, Operator, BaseKernelNameGenerator
from ...base.arch.cuda import CUDA
from ...utils import auto_detect_nvidia_target
from dataclasses import dataclass
from typing import Union, Tuple, Literal, Optional
import logging
import torch

logger = logging.getLogger(__name__)

# Scratch workspace of 256M float16 elements (512 MB).
WORKSPACE_SIZE = 1024 * 1024 * 256


def is_native_compute(Q_dtype, K_dtype, V_dtype) -> bool:
    # The native compute path requires Q, K and V to share a single dtype.
    return Q_dtype == K_dtype and K_dtype == V_dtype


@dataclass(frozen=True)
class FlashAttenConfig(OperatorConfig):
    batch: Union[int, Tuple[int]] = None
    # TODO: should distinguish q_heads from kv_heads
    heads: Optional[int] = None
    kv_heads: Optional[int] = None
    seq_len: Optional[int] = None
    dim: Optional[int] = None
    Q_dtype: str = "float16"
    K_dtype: str = Q_dtype  # defaults to Q_dtype's default ("float16")
    V_dtype: str = Q_dtype
    Accu_dtype: str = "float32"
    Out_dtype: str = "float16"
    layout: Literal["nnn", "ntn"] = "nnn"
    is_causal: bool = False
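
As a minimal sketch of how such a config could be constructed (the shape values below are illustrative assumptions, not taken from the commit):

    # Hypothetical shapes for illustration only.
    example_config = FlashAttenConfig(
        batch=1,
        heads=32,
        seq_len=4096,
        dim=128,
        layout="nnn",
        is_causal=True,
    )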


class FlashAttenKernelNameGenerator(BaseKernelNameGenerator):

    KERNEL_PREFIX = "flashatten"

    def is_valid_config(self, config: OperatorConfig) -> bool:
        return isinstance(config, FlashAttenConfig)

    @staticmethod
    def simplify_dtype(dtype: str) -> str:
        # Shorten dtype strings, e.g. "float16" -> "f16", "bfloat16" -> "bf16".
        if dtype.startswith("float"):
            return f"f{dtype[5:]}"
        elif dtype.startswith("bfloat"):
            return f"bf{dtype[6:]}"
        elif dtype.startswith("int"):
            return f"i{dtype[3:]}"
        elif dtype.startswith("uint"):
            return f"u{dtype[4:]}"
        else:
            raise ValueError("Currently only float, bfloat, int and uint dtypes are supported")

    def generate(self, hint: Hint = None) -> str:
        config = self.config
        kernel_name = self.KERNEL_PREFIX
        shape_str = f"batch{self.config.batch}heads{self.config.heads}seqlen{self.config.seq_len}dim{self.config.dim}"
        Q_dtype = self.simplify_dtype(config.Q_dtype)
        K_dtype = self.simplify_dtype(config.K_dtype)
        V_dtype = self.simplify_dtype(config.V_dtype)
        Accu_dtype = self.simplify_dtype(config.Accu_dtype)
        Out_dtype = self.simplify_dtype(config.Out_dtype)
        precision_str = f"Q{Q_dtype}_K{K_dtype}_V{V_dtype}_Accu{Accu_dtype}_Out{Out_dtype}"
        kernel_name = "_".join([kernel_name, shape_str, precision_str])
        # TODO: incorporate the hint into the generated name
        assert self.is_valid(kernel_name), "Kernel name invalid"
        return kernel_name
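
To make the naming scheme concrete: for the illustrative config above (batch=1, heads=32, seq_len=4096, dim=128 with float16 tensors and a float32 accumulator), and assuming the base-class name validation passes, the generator would produce a name along these lines:

    # "flashatten" + shape string + per-tensor dtype string, joined with "_":
    # flashatten_batch1heads32seqlen4096dim128_Qf16_Kf16_Vf16_Accuf32_Outf16
    name = FlashAttenKernelNameGenerator(example_config).generate()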


class FlashAtten(Operator):

    BITBLAS_TRICK_DTYPE_MAP = {
        "float32": ("fp", 32),
        "float16": ("fp", 16),
        "int8": ("int", 8),
        "int4": ("int", 4),
    }

    def __init__(
        self,
        config: FlashAttenConfig,
        name: str = "flashatten",
        target: Optional[Union[str, Target]] = None,
        enable_tuning: bool = True,
        from_database: bool = False,
        backend: str = "tl",
    ):
        if target is None:
            target = auto_detect_nvidia_target()
            logger.info(f"Auto detected target: {target}")

        assert (config.Q_dtype
                in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.Q_dtype}"
        assert (config.K_dtype
                in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.K_dtype}"
        assert (config.V_dtype
                in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.V_dtype}"
        assert backend == "tl", "FlashAttention only supports the TL compiler"

        source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.Q_dtype]

        self.source_format = source_format
        self.bit = bit
        self.backend = backend
        super().__init__(name, config, target, backend)

        target = self.target
        if target.kind.name != "cuda":
            raise ValueError("Currently only CUDA targets are supported")

        self.dispatch_tl(target, from_database, source_format, enable_tuning)

    def dispatch_tl(self,
                    target: Target,
                    from_database: bool = False,
                    source_format: str = "fp16",
                    enable_tuning: bool = True):
        self.arch = CUDA(target)
        if not from_database:
            self._build_default_module(target)
        self.workspace = None
        if enable_tuning:
            self.hardware_aware_finetune()
        self.torch_output_dtype = getattr(torch, self.Out_dtype)

    def get_kernel_name_generator(self):
        return FlashAttenKernelNameGenerator(self.config)

    def _alloc_workspace(self):
        return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda()

    def _free_workspace(self):
        # release the workspace if it is not None
        if self.workspace is not None:
            self.workspace = None

    def _select_scheduler(self) -> Optional[BaseScheduler]:
        if is_native_compute(self.Q_dtype, self.K_dtype, self.V_dtype):
            return consistent_scheduler(
                batch=self.batch,
                heads=self.heads,
                seq_len=self.seq_len,
                dim=self.dim,
                layout=self.layout,
                dtype_QKV=self.Q_dtype,
                dtype_Out=self.Out_dtype,
                dtype_Accu=self.Accu_dtype,
                is_causal=self.is_causal,
            )
        else:
            raise ValueError("Currently only native compute is supported for the scheduler")

    def cleanup(self):
        self._free_workspace()

    @property
    def batch(self):
        return self.config.batch

    @property
    def heads(self):
        return self.config.heads

    @property
    def seq_len(self):
        return self.config.seq_len

    @property
    def dim(self):
        return self.config.dim

    @property
    def Q_dtype(self):
        return self.config.Q_dtype

    @property
    def K_dtype(self):
        return self.config.K_dtype

    @property
    def V_dtype(self):
        return self.config.V_dtype

    @property
    def Accu_dtype(self):
        return self.config.Accu_dtype

    @property
    def Out_dtype(self):
        return self.config.Out_dtype

    @property
    def layout(self):
        return self.config.layout

    @property
    def is_causal(self):
        return self.config.is_causal
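
A rough end-to-end sketch, assuming FlashAttenConfig and FlashAtten are exposed through the bitblas package (the exact public import path and the forward-call interface are not part of this diff) and that a CUDA device is available for compilation:

    # Illustrative shapes; enable_tuning=False skips hardware-aware finetuning.
    op = FlashAtten(
        FlashAttenConfig(batch=1, heads=32, seq_len=4096, dim=128, is_causal=True),
        enable_tuning=False,
    )
    print(op.get_kernel_name_generator().generate())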
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .flashatten import flashatten_blocked  # noqa: F401
from .flashatten import FlashAttenScheduler  # noqa: F401


def parse_layout(layout: str):
    # Only the K matrix may be transposed: "nnn" -> no transpose, "ntn" -> transposed K.
    trans_Q = False
    trans_K = layout[1] == 't'
    trans_V = False
    return trans_Q, trans_K, trans_V
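
For instance, the two layouts admitted by FlashAttenConfig map as follows (a quick check, assuming the function is called directly):

    assert parse_layout("nnn") == (False, False, False)
    assert parse_layout("ntn") == (False, True, False)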


def select_scheduler(
    batch=None,
    heads=None,
    seq_len=None,
    dim=None,
    layout="nnn",
    dtype_QKV="float16",
    dtype_Out="float16",
    dtype_Accu="float32",
    is_causal=False,
):
    trans_list = parse_layout(layout)
    trans_K = trans_list[1]
    return FlashAttenScheduler(
        batch=batch,
        heads=heads,
        seq_len=seq_len,
        dim=dim,
        trans_K=trans_K,
        dtype_QKV=dtype_QKV,
        dtype_Out=dtype_Out,
        dtype_Accu=dtype_Accu,
        is_causal=is_causal,
    )
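
A minimal sketch of calling the selector directly, mirroring how FlashAtten._select_scheduler forwards its config fields (the values are illustrative assumptions):

    # Returns a FlashAttenScheduler configured for a causal fp16 attention kernel.
    scheduler = select_scheduler(
        batch=1,
        heads=32,
        seq_len=4096,
        dim=128,
        layout="ntn",  # transposed K
        dtype_QKV="float16",
        dtype_Out="float16",
        dtype_Accu="float32",
        is_causal=True,
    )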