diff --git a/bitblas/__init__.py b/bitblas/__init__.py index cae8f9d5f..937a3c1c7 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -39,6 +39,7 @@ from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 +from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401 from .module import Linear # noqa: F401 import warnings diff --git a/bitblas/ops/general_flashatten/__init__.py b/bitblas/ops/general_flashatten/__init__.py new file mode 100644 index 000000000..d9db00fab --- /dev/null +++ b/bitblas/ops/general_flashatten/__init__.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.base.roller.hint import Hint +from tvm.target import Target +from .tilelang import select_scheduler as consistent_scheduler +from ..base_scheduler import BaseScheduler +from ..operator import OperatorConfig, Operator, BaseKernelNameGenerator +from ...base.arch.cuda import CUDA +from ...utils import auto_detect_nvidia_target +from dataclasses import dataclass +from typing import Union, Tuple, Literal, Optional +import logging +import torch + +logger = logging.getLogger(__name__) + +WORKSPACE_SIZE = 1024 * 1024 * 256 + + +def is_native_compute(Q_dtype, K_dtype, V_dtype) -> bool: + return Q_dtype == K_dtype and K_dtype == V_dtype + + +@dataclass(frozen=True) +class FlashAttenConfig(OperatorConfig): + batch: Union[int, Tuple[int]] = None + # TODO should distinguish from q_heads and kv_heads + heads: Optional[int] = None + kv_heads: Optional[int] = None + seq_len: Optional[int] = None + dim: Optional[int] = None + Q_dtype: str = "float16" + K_dtype: str = Q_dtype # for default + V_dtype: str = Q_dtype + Accu_dtype: str = "float32" + Out_dtype: str = "float16" + layout: Literal["nnn", "ntn"] = "nnn" + is_causal: bool = False + + +class FlashAttenKernelNameGenerator(BaseKernelNameGenerator): + + KERNEL_PREFIX = "flashatten" + + def is_valid_config(self, config: OperatorConfig) -> bool: + return isinstance(config, FlashAttenConfig) + + @staticmethod + def simplify_dtype(dtype: str) -> str: + if dtype.startswith("float"): + return f"f{dtype[5:]}" + elif dtype.startswith("bfloat"): + return f"bf{dtype[6:]}" + elif dtype.startswith("int"): + return f"i{dtype[3:]}" + elif dtype.startswith("uint"): + return f"u{dtype[4:]}" + else: + raise ValueError("Currently only support float, bfloat, int, uint") + + def generate(self, hint: Hint = None) -> str: + config = self.config + kernel_name = self.KERNEL_PREFIX + shape_str = f"batch{self.config.batch}heads{self.config.heads}seqlen{self.config.seq_len}dim{self.config.dim}" + Q_dtype = self.simplify_dtype(config.Q_dtype) + K_dtype = self.simplify_dtype(config.K_dtype) + V_dtype = self.simplify_dtype(config.V_dtype) + Accu_dtype = self.simplify_dtype(config.Accu_dtype) + Out_dtype = self.simplify_dtype(config.Out_dtype) + precision_str = f"Q{Q_dtype}_K{K_dtype}_V{V_dtype}_Accu{Accu_dtype}_Out{Out_dtype}" + kernel_name = "_".join([kernel_name, shape_str, precision_str]) + # TODO need to add hint + assert self.is_valid(kernel_name), "Kernel name invalid" + return kernel_name + + +class FlashAtten(Operator): + + BITBLAS_TRICK_DTYPE_MAP = { + "float32": ("fp", 32), + "float16": ("fp", 16), + "int8": ("int", 8), + "int4": ("int", 4), + } + + def __init__( + self, + config: FlashAttenConfig, + name: str = "flashatten", + target: Optional[Union[str, Target]] = None, + enable_tuning: bool = True, + from_database: bool = False, + backend: str = "tl", + ): + if target is None: + target = auto_detect_nvidia_target() + logger.info(f"Auto detected target: {target}") + + assert (config.Q_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.Q_dtype}" + assert (config.K_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.K_dtype}" + assert (config.V_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.V_dtype}" + assert backend == "tl", "FlashAttention only support TL compiler" + + source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.Q_dtype] + + self.source_format = source_format + self.bit = bit + self.backend = backend + super().__init__(name, config, target, backend) + + target = self.target + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + + self.dispatch_tl(target, from_database, source_format, enable_tuning) + + def dispatch_tl(self, + target: Target, + from_database: bool = False, + source_format: str = "fp16", + enable_tuning: bool = True): + self.arch = CUDA(target) + if not from_database: + self._build_default_module(target) + self.workspace = None + if enable_tuning: + self.hardware_aware_finetune() + self.torch_output_dtype = getattr(torch, self.Out_dtype) + + def get_kernel_name_generator(self): + return FlashAttenKernelNameGenerator(self.config) + + def _alloc_workspace(self): + return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + + def _free_workspace(self): + # release the workspace if it is None + if self.workspace is not None: + self.workspace = None + + def _select_scheduler(self) -> Optional[BaseScheduler]: + if is_native_compute(self.Q_dtype, self.K_dtype, self.V_dtype): + return consistent_scheduler( + batch=self.batch, + heads=self.heads, + seq_len=self.seq_len, + dim=self.dim, + layout=self.layout, + dtype_QKV=self.Q_dtype, + dtype_Out=self.Out_dtype, + dtype_Accu=self.Accu_dtype, + is_causal=self.is_causal, + ) + else: + raise ValueError("Currently only support native compute for scheduler") + + def cleanup(self): + self._free_workspace() + + @property + def batch(self): + return self.config.batch + + @property + def heads(self): + return self.config.heads + + @property + def seq_len(self): + return self.config.seq_len + + @property + def dim(self): + return self.config.dim + + @property + def Q_dtype(self): + return self.config.Q_dtype + + @property + def K_dtype(self): + return self.config.K_dtype + + @property + def V_dtype(self): + return self.config.V_dtype + + @property + def Accu_dtype(self): + return self.config.Accu_dtype + + @property + def Out_dtype(self): + return self.config.Out_dtype + + @property + def layout(self): + return self.config.layout + + @property + def is_causal(self): + return self.config.is_causal diff --git a/bitblas/ops/general_flashatten/tilelang/__init__.py b/bitblas/ops/general_flashatten/tilelang/__init__.py new file mode 100644 index 000000000..e3a4f0af9 --- /dev/null +++ b/bitblas/ops/general_flashatten/tilelang/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .flashatten import flashatten_blocked # noqa: F401 +from .flashatten import FlashAttenScheduler # noqa: F401 + + +def parse_layout(layout: str): + trans_Q = False + trans_K = layout[1] == 't' + trans_V = False + return trans_Q, trans_K, trans_V + + +def select_scheduler( + batch=None, + heads=None, + seq_len=None, + dim=None, + layout="nnn", + dtype_QKV="float16", + dtype_Out="float16", + dtype_Accu="float32", + is_causal=False, +): + trans_list = parse_layout(layout) + trans_K = trans_list[1] + return FlashAttenScheduler( + batch=batch, + heads=heads, + seq_len=seq_len, + dim=dim, + trans_K=trans_K, + dtype_QKV=dtype_QKV, + dtype_Out=dtype_Out, + dtype_Accu=dtype_Accu, + is_causal=is_causal, + ) diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py new file mode 100644 index 000000000..c333819c4 --- /dev/null +++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py @@ -0,0 +1,289 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +from bitblas.ops.base_scheduler import BaseScheduler +import tvm.tl.language as T +from dataclasses import dataclass +from typing import Optional +import logging + +logger = logging.Logger(__name__) + + +@dataclass +class FlashAttenScheduler(BaseScheduler): + # flashattention config + batch: Optional[int] = None + heads: Optional[int] = None + seq_len: Optional[int] = None + dim: Optional[int] = None + trans_K: bool = False + dtype_QKV: str = "float16" + dtype_Out: str = "float16" + dtype_Accu: str = "float32" + is_causal: bool = False + # block config + block_M: int = 64 + block_N: int = 64 + num_stages: int = 2 + threads: int = 128 + enable_rasterization: bool = False + + def choose_pipeline( + self, + iterable, + num_stages, + ): + enable_pipeline = num_stages > 1 + if enable_pipeline: + return T.Pipelined(iterable, num_stages=num_stages) + else: + return T.serial(iterable) + + def with_default_config(self): + block_M = getattr(self, "block_M", 64) + block_N = getattr(self, "block_N", 64) + num_stages = getattr(self, "num_stages", 2) + threads = getattr(self, "threads", 128) + enable_rasterization = getattr(self, "rasterization", False) + return self.apply_config( + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization) + + def apply_config( + self, + block_M=64, + block_N=64, + num_stages=2, + threads=128, + enable_rasterization=False, + ): + batch, heads, seq_len, dim = self.batch, self.heads, self.seq_len, self.dim + trans_K = self.trans_K + dtypeQKV, dtypeAccu, dtypeOut = self.dtype_QKV, self.dtype_Accu, self.dtype_Out + is_causal = self.is_causal + + Q_shape = (batch, seq_len, heads, dim) + K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len) + V_shape = (batch, seq_len, heads, dim) + Output_shape = (batch, seq_len, heads, dim) + + Q_shared_shape = (block_M, dim) + K_shared_shape = (block_N, dim) if not trans_K else (dim, block_N) + V_shared_shape = (block_N, dim) + + Q_local_shape = (block_M, dim) + + @T.prim_func + def main( + Q: T.Buffer(Q_shape, dtypeQKV), + K: T.Buffer(K_shape, dtypeQKV), + V: T.Buffer(V_shape, dtypeQKV), + Output: T.Buffer(Output_shape, dtypeOut), + ): + scale = (1.0 / dim)**0.5 * 1.44269504 + with T.Kernel( + T.ceildiv(seq_len, block_M), batch, heads, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared(Q_shared_shape, dtypeQKV) + Q_local = T.alloc_fragment(Q_local_shape, dtypeQKV) + K_shared = T.alloc_shared(K_shared_shape, dtypeQKV) + V_shared = T.alloc_shared(V_shared_shape, dtypeQKV) + score_QK = T.alloc_fragment((block_M, block_N), dtypeAccu) + score_QK_sum = T.alloc_fragment((block_M), dtypeAccu) + score_QK_qkvtype = T.alloc_fragment((block_M, block_N), dtypeQKV) + score_scale = T.alloc_fragment((block_M), dtypeAccu) + local_rowmax = T.alloc_fragment((block_M), dtypeAccu) + prev_rowmax = T.alloc_fragment((block_M), dtypeAccu) + global_l = T.alloc_fragment((block_M), dtypeAccu) + block_output = T.alloc_fragment((block_M, dim), dtypeOut) + + if enable_rasterization: + # rasterization factor + T.use_swizzle(10) + T.copy(Q[by, bx * block_M:(bx + 1) * block_M, bz, :], Q_shared) + T.copy(Q_shared, Q_local) + for i, j in T.Parallel(block_M, dim): + Q_local[i, j] *= scale + T.fill(block_output, 0) + T.fill(global_l, 0) + T.fill(local_rowmax, -T.infinity(dtypeAccu)) + + for k in self.choose_pipeline( + T.ceildiv((bx + 1) * + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N), + num_stages=num_stages): + if trans_K: + T.copy(K[by, :, bz, k * block_N:(k + 1) * block_N], K_shared) + else: + T.copy(K[by, k * block_N:(k + 1) * block_N, bz, :], K_shared) + T.copy(V[by, k * block_N:(k + 1) * block_N, bz, :], V_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + score_QK[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(dtypeAccu)) + else: + T.fill(score_QK, 0) + T.gemm( + Q_local, + K_shared, + score_QK, + transpose_B=(not trans_K), + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(local_rowmax, prev_rowmax) + T.reduce_max(score_QK, local_rowmax, dim=1, clear=False) + for i, j in T.Parallel(block_M, block_N): + score_QK[i, j] = T.exp2(score_QK[i, j] - local_rowmax[i]) + for i in T.Parallel(block_M): + score_scale[i] = T.exp2(prev_rowmax[i] - local_rowmax[i]) + T.reduce_sum(score_QK, score_QK_sum, dim=1) + for i in T.Parallel(block_M): + global_l[i] = global_l[i] * score_scale[i] + score_QK_sum[i] + for i, j in T.Parallel(block_M, dim): + block_output[i, j] *= score_scale[i] + T.copy(score_QK, score_QK_qkvtype) + T.gemm( + score_QK_qkvtype, + V_shared, + block_output, + policy=T.GemmWarpPolicy.FullRow, + ) + for i, j in T.Parallel(block_M, dim): + block_output[i, j] /= global_l[i] + T.copy(block_output, Output[by, bx * block_M:(bx + 1) * block_M, bz, :]) + + return self.maybe_simplify(main) + + +def maybe_pipeline( + iterable, + num_stages, +): + enable_pipeline = num_stages > 1 + if enable_pipeline: + return T.Pipelined(iterable, num_stages=num_stages) + else: + return T.serial(iterable) + + +def flashatten_blocked( + batch, + seq_len, + heads, + dim, + block_M_seq=64, + block_N_seq=64, + trans_Q=False, # (batch, seq_len, heads, dim) for default, (batch, dim, heads, seq_len) for trans + trans_K=False, # (batch, seq_len, heads, dim) for default, (batch, dim, heads, seq_len) for trans + trans_V=False, # (batch, seq_len, heads, dim) for default, (batch, dim, heads, seq_len) for trans + dtypeQKV="float16", + dtypeAccu="float32", + dtypeOut="float16", + num_stages=2, + threads=128, + is_causal=False, + enable_rasterization=False, # Enhance L2 Locality +): + Q_shape = (batch, seq_len, heads, dim) if not trans_Q else (batch, dim, heads, seq_len) + K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len) + V_shape = (batch, seq_len, heads, dim) if not trans_V else (batch, dim, heads, seq_len) + Output_shape = (batch, seq_len, heads, dim) + + Q_shared_shape = (block_M_seq, dim) if not trans_Q else (dim, block_M_seq) + K_shared_shape = (block_N_seq, dim) if not trans_K else (dim, block_N_seq) + V_shared_shape = (block_N_seq, dim) if not trans_V else (dim, block_N_seq) + + Q_local_shape = (block_M_seq, dim) if not trans_Q else (dim, block_M_seq) + + @T.prim_func + def main( + Q: T.Buffer(Q_shape, dtypeQKV), + K: T.Buffer(K_shape, dtypeQKV), + V: T.Buffer(V_shape, dtypeQKV), + Output: T.Buffer(Output_shape, dtypeOut), + ): + scale = (1.0 / dim)**0.5 * 1.44269504 + with T.Kernel( + T.ceildiv(seq_len, block_M_seq), batch, heads, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared(Q_shared_shape, dtypeQKV) + Q_local = T.alloc_fragment(Q_local_shape, dtypeQKV) + K_shared = T.alloc_shared(K_shared_shape, dtypeQKV) + V_shared = T.alloc_shared(V_shared_shape, dtypeQKV) + score_QK = T.alloc_fragment((block_M_seq, block_N_seq), dtypeAccu) + score_QK_sum = T.alloc_fragment((block_M_seq), dtypeAccu) + score_QK_qkvtype = T.alloc_fragment((block_M_seq, block_N_seq), dtypeQKV) + score_scale = T.alloc_fragment((block_M_seq), dtypeAccu) + local_rowmax = T.alloc_fragment((block_M_seq), dtypeAccu) + prev_rowmax = T.alloc_fragment((block_M_seq), dtypeAccu) + global_l = T.alloc_fragment((block_M_seq), dtypeAccu) + block_output = T.alloc_fragment((block_M_seq, dim), dtypeOut) + + if enable_rasterization: + # rasterization factor + T.use_swizzle(10) + if trans_Q: + T.copy(Q[by, :, bz, bx * block_M_seq:(bx + 1) * block_M_seq], Q_shared) + else: + T.copy(Q[by, bx * block_M_seq:(bx + 1) * block_M_seq, bz, :], Q_shared) + T.copy(Q_shared, Q_local) + for i, j in T.Parallel(block_M_seq, dim): + Q_local[i, j] *= scale + T.fill(block_output, 0) + T.fill(global_l, 0) + T.fill(local_rowmax, -T.infinity(dtypeAccu)) + + for k in maybe_pipeline( + T.ceildiv( + (bx + 1) * + block_M_seq, block_N_seq) if is_causal else T.ceildiv(seq_len, block_N_seq), + num_stages=num_stages): + if trans_K: + T.copy(K[by, :, bz, k * block_N_seq:(k + 1) * block_N_seq], K_shared) + else: + T.copy(K[by, k * block_N_seq:(k + 1) * block_N_seq, bz, :], K_shared) + if trans_V: + T.copy(V[by, :, bz, k * block_N_seq:(k + 1) * block_N_seq], V_shared) + else: + T.copy(V[by, k * block_N_seq:(k + 1) * block_N_seq, bz, :], V_shared) + if is_causal: + for i, j in T.Parallel(block_M_seq, block_N_seq): + score_QK[i, j] = T.if_then_else(bx * block_M_seq + i >= k * block_N_seq + j, + 0, -T.infinity(dtypeAccu)) + else: + T.fill(score_QK, 0) + T.gemm( + Q_local, + K_shared, + score_QK, + transpose_A=trans_Q, + transpose_B=(not trans_K), + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(local_rowmax, prev_rowmax) + T.reduce_max(score_QK, local_rowmax, dim=1, clear=False) + for i, j in T.Parallel(block_M_seq, block_N_seq): + score_QK[i, j] = T.exp2(score_QK[i, j] - local_rowmax[i]) + for i in T.Parallel(block_M_seq): + score_scale[i] = T.exp2(prev_rowmax[i] - local_rowmax[i]) + T.reduce_sum(score_QK, score_QK_sum, dim=1) + for i in T.Parallel(block_M_seq): + global_l[i] = global_l[i] * score_scale[i] + score_QK_sum[i] + for i, j in T.Parallel(block_M_seq, dim): + block_output[i, j] *= score_scale[i] + T.copy(score_QK, score_QK_qkvtype) + T.gemm( + score_QK_qkvtype, + V_shared, + block_output, + transpose_B=trans_V, + policy=T.GemmWarpPolicy.FullRow, + ) + for i, j in T.Parallel(block_M_seq, dim): + block_output[i, j] /= global_l[i] + T.copy(block_output, Output[by, bx * block_M_seq:(bx + 1) * block_M_seq, bz, :]) + + return main diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index b7b884443..126474f8a 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -52,8 +52,8 @@ def is_native_compute(A_dtype, W_dtype) -> bool: @dataclass(frozen=True) class MatmulConfig(OperatorConfig): M: Union[int, Tuple[int]] = None - N: int = None - K: int = None + N: Optional[int] = None + K: Optional[int] = None A_dtype: str = "float16" # is a wrapper for source_format and bit W_dtype: str = A_dtype # W_dtype is the same as A_dtype by default @@ -296,6 +296,7 @@ def generate(self, hint=None) -> str: # kernel_name += f"_pb{config.propagate_b.value}" kernel_name = "_".join([kernel_name, self.serialize_hint(hint)]) + assert self.is_valid(kernel_name), "Kernel name invalid" return kernel_name def is_valid_config(self, config: OperatorConfig) -> bool: diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index eb173352f..50d40f122 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -22,6 +22,7 @@ from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass import logging +import re logger = logging.getLogger(__name__) @@ -60,10 +61,16 @@ def generate(self, hint: Hint = None) -> str: """Generate the kernel name based on the config and hint""" pass + def is_valid(self, kernel_name: str = None) -> bool: + '''Validate kernel name after generation''' + pattern = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$') + return kernel_name.isidentifier() and pattern.match(kernel_name) + class DefaultKernelNameGenerator(BaseKernelNameGenerator): DEFAULT_PREFIX = "main" + kernel_name = None def __init__(self, config: OperatorConfig, name: str): self.DEFAULT_PREFIX = name diff --git a/testing/python/operators/test_general_flashatten_ops_backend_tl.py b/testing/python/operators/test_general_flashatten_ops_backend_tl.py new file mode 100644 index 000000000..13617c4ea --- /dev/null +++ b/testing/python/operators/test_general_flashatten_ops_backend_tl.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import FlashAttenConfig, FlashAtten +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +# fmt: off +def flashatten_codegen_default(batch, heads, seq_len, dim, Q_dtype, K_dtype, V_dtype, Accu_dtype, + Out_dtype, layout, is_causal): + + flashatten_config = FlashAttenConfig( + batch=batch, + heads=heads, + seq_len=seq_len, + dim=dim, + Q_dtype=Q_dtype, + K_dtype=K_dtype, + V_dtype=V_dtype, + Accu_dtype=Accu_dtype, + Out_dtype=Out_dtype, + layout=layout, + is_causal=is_causal) + flashatten = FlashAtten(config=flashatten_config, enable_tuning=False, backend="tl") + assert get_codegen_result(flashatten) + + +def test_matmul_codegen_default(): + flashatten_codegen_default(1, 4, 256, 256, "float16", "float16", "float16", "float32", + "float16", "nnn", False) + flashatten_codegen_default(1, 4, 256, 256, "float16", "float16", "float16", "float32", + "float16", "ntn", False) + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_flashatten_tilelang_scheduler.py b/testing/python/operators/test_general_flashatten_tilelang_scheduler.py new file mode 100644 index 000000000..bfe9f8964 --- /dev/null +++ b/testing/python/operators/test_general_flashatten_tilelang_scheduler.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +import bitblas.testing +from tvm.ir import structural_equal +from bitblas.ops.general_flashatten.tilelang.flashatten import FlashAttenScheduler + + +def assert_flashatten_scheduler_simplify(batch, + heads, + seq_len, + dim, + trans_K=False, + dtype_QKV="float16", + dtype_Out="float16", + dtype_Accu="float32", + is_causal=False): + flashatten = FlashAttenScheduler( + batch=batch, + heads=heads, + seq_len=seq_len, + dim=dim, + trans_K=trans_K, + dtype_QKV=dtype_QKV, + dtype_Out=dtype_Out, + dtype_Accu=dtype_Accu, + is_causal=is_causal, + ).deactivate_simplify().with_default_config() + + simplified_flashatten = FlashAttenScheduler.Simplify(flashatten) + + is_equal = structural_equal(flashatten, simplified_flashatten) + + assert is_equal is False, "Simplify should not return the same schedule" + + +def test_scheduler_simplify(): + assert_flashatten_scheduler_simplify(1, 4, 256, 256) + + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 10e9ade7c..6ecca21e3 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -4,7 +4,7 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import tl -from bitblas.ops.general_matmul.tilelang.dense.matmul import ( +from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( MatmulScheduler, MatmulFineGrainScheduler, MatmulWeightPropagationScheduler, diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index 87c685e08..b82e75cd4 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -4,7 +4,7 @@ from bitblas import tvm as tvm import bitblas.testing from tvm.ir import structural_equal -from bitblas.ops.general_matmul.tilelang.dense.matmul import ( +from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( MatmulScheduler,) diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index 4638f10bc..a7362c782 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -7,6 +7,7 @@ import bitblas import logging from bitblas import set_log_level +from bitblas.ops.general_flashatten.tilelang.flashatten import flashatten_blocked set_log_level(logging.DEBUG) @@ -48,6 +49,141 @@ def ref_flashattn_result(batch, heads, seq_len, dim, is_casual, dtype="float16") return res +def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, num_stages, + is_causal): + tl_prim_func = flashatten_blocked( + batch=batch, + seq_len=seq_len, + heads=heads, + dim=dim, + trans_K=trans_K, + dtypeQKV=dtypeQKV, + dtypeAccu=dtypeAccu, + num_stages=num_stages, + is_causal=is_causal, + ) + mod, params = tl.lower(tl_prim_func) + mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + from flash_attn.flash_attn_interface import flash_attn_func + # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils + ins = mod._get_inputs() + tilelang_res = mod(*ins) + Q, K, V = ins[0], ins[1], ins[2] + if trans_K: + K = K.transpose(1, 3).contiguous() + ref_res = flash_attn_func(Q, K, V, causal=is_causal) + torch.testing.assert_close(tilelang_res, ref_res, rtol=0.01, atol=0.01) + + +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_flashattn_blocked(): + flashattn_tilelang(1, 1, 256, 256, False, "float16", "float32", 1, False) + flashattn_tilelang(1, 4, 256, 256, False, "float16", "float32", 1, False) + flashattn_tilelang(1, 1, 512, 256, False, "float16", "float32", 1, False) + flashattn_tilelang(1, 4, 512, 256, False, "float16", "float32", 1, False) + flashattn_tilelang(1, 1, 512, 256, True, "float16", "float32", 1, False) + flashattn_tilelang(1, 4, 512, 256, True, "float16", "float32", 1, False) + + +def flashattn_ref(batch, heads, seq_len, dim, is_causal): + + def kernel(block_M=64, block_N=64, num_stages=1, thread_num=128): + scale = (1.0 / dim)**0.5 * 1.44269504 + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.copy(Q_shared, Q_local) + for i, j in T.Parallel(block_M, dim): + Q_local[i, j] *= scale + loop_range = ( + T.ceildiv( + (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + 0, + -T.infinity(acc_s.dtype), + ) + else: + T.clear(acc_s) + T.gemm( + Q_local, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(acc_s, acc_s_cast) + T.gemm( + acc_s_cast, + V_shared, + acc_o, + policy=T.GemmWarpPolicy.FullRow, + ) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + + return main + + mod, params = tl.lower(kernel()) + mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) + + +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_flashattn_ref(): + flashattn_ref(1, 4, 256, 256, False) + flashattn_ref(1, 8, 256, 256, False) + flashattn_ref(1, 4, 256, 256, True) + flashattn_ref(1, 8, 256, 256, True) + flashattn_ref(4, 4, 256, 256, True) + flashattn_ref(4, 8, 256, 256, True) + + def flashattn_autotune(batch, heads, seq_len, dim, is_causal): @autotune(