From a5a75274bc5dc3df59d8a8e28a0a841a438002ad Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 15 Oct 2024 00:21:22 -0700 Subject: [PATCH] FA3 kvcache + split kv + gqa parallelization (#1236) --- hopper/benchmark_flash_attention_fp8.py | 10 +- hopper/benchmark_split_kv.py | 325 +++++++++++ hopper/block_info.h | 46 -- hopper/combine.h | 248 ++++++++ hopper/epilogue_fwd_sm90_tma.hpp | 399 ++++++++----- hopper/flash.h | 15 +- hopper/flash_api.cpp | 598 ++++++++++++++++++-- hopper/flash_attn_interface.py | 189 ++++++- hopper/flash_fwd_hdim128_bf16_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim128_bf16_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim128_bf16_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim128_bf16_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim128_bf16_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim128_e4m3_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim128_e4m3_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim128_e4m3_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim128_e4m3_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim128_e4m3_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim128_fp16_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim128_fp16_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim128_fp16_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim128_fp16_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim128_fp16_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim256_bf16_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim256_bf16_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim256_bf16_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim256_bf16_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim256_bf16_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim256_e4m3_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim256_e4m3_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim256_e4m3_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim256_e4m3_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim256_e4m3_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim256_fp16_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim256_fp16_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim256_fp16_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim256_fp16_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim256_fp16_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim64_bf16_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim64_bf16_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim64_bf16_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim64_bf16_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim64_bf16_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim64_e4m3_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim64_e4m3_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim64_e4m3_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim64_e4m3_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim64_e4m3_gqa8_sm90.cu | 9 + hopper/flash_fwd_hdim64_fp16_gqa16_sm90.cu | 9 + hopper/flash_fwd_hdim64_fp16_gqa2_sm90.cu | 9 + hopper/flash_fwd_hdim64_fp16_gqa32_sm90.cu | 9 + hopper/flash_fwd_hdim64_fp16_gqa4_sm90.cu | 9 + hopper/flash_fwd_hdim64_fp16_gqa8_sm90.cu | 9 + hopper/flash_fwd_kernel.h | 216 +++---- hopper/flash_fwd_launch_template.h | 531 +++++++++++++---- hopper/kernel_traits.h | 157 ++++- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 567 ++++++++++--------- hopper/seq_len.h | 247 +++++++- hopper/setup.py | 47 +- hopper/static_switch.h | 125 +++- hopper/test_attn_kvcache.py | 486 ++++++++++++++++ hopper/test_flash_attn.py | 158 +++++- hopper/test_kvcache.py | 234 ++++++++ hopper/tile_scheduler.hpp | 72 ++- hopper/utils.h | 115 +++- 65 files changed, 4375 insertions(+), 815 deletions(-) create mode 100644 hopper/benchmark_split_kv.py delete mode 100644 hopper/block_info.h create mode 100644 hopper/combine.h create mode 100644 hopper/flash_fwd_hdim128_bf16_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_bf16_gqa2_sm90.cu create mode 100644 
hopper/flash_fwd_hdim128_bf16_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_bf16_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_bf16_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_e4m3_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_e4m3_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_e4m3_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_e4m3_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_e4m3_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_fp16_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_fp16_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_fp16_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_fp16_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim128_fp16_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_bf16_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_bf16_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_bf16_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_bf16_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_bf16_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_e4m3_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_e4m3_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_e4m3_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_e4m3_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_e4m3_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_fp16_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_fp16_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_fp16_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_fp16_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_fp16_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_bf16_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_bf16_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_bf16_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_bf16_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_bf16_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_e4m3_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_e4m3_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_e4m3_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_e4m3_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_e4m3_gqa8_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_fp16_gqa16_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_fp16_gqa2_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_fp16_gqa32_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_fp16_gqa4_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_fp16_gqa8_sm90.cu create mode 100644 hopper/test_attn_kvcache.py create mode 100644 hopper/test_kvcache.py diff --git a/hopper/benchmark_flash_attention_fp8.py b/hopper/benchmark_flash_attention_fp8.py index efd007dd2..8490d00b3 100644 --- a/hopper/benchmark_flash_attention_fp8.py +++ b/hopper/benchmark_flash_attention_fp8.py @@ -229,7 +229,8 @@ def time_fwd(func, *args, **kwargs): # dim = 256 dropout_p = 0.0 -methods = (["Pytorch", "Flash3", "cuDNN"] +methods = (["Pytorch", "Flash3"] + + (["cuDNN"] if cudnn is not None else []) # + (["Triton"] if attention_triton is not None else []) # + (["xformers.c"] if xops is not None else []) # + (["xformers.f"] if xops is not None else []) @@ -247,10 +248,10 @@ def time_fwd(func, *args, **kwargs): torch.cuda.empty_cache() config = (causal, headdim, batch_size, seqlen) nheads = dim // headdim - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, 
dtype=torch.float16, requires_grad=False) for _ in range(3)] + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)] qkv = torch.stack([q, k, v], dim=2) - qkv = qkv.to(torch.float16) + qkv = qkv.to(torch.bfloat16) f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False) time_f[config, "Pytorch"] = f res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) @@ -289,7 +290,8 @@ def time_fwd(func, *args, **kwargs): k, v, softmax_scale, - causal=causal, + causal=causal, + window_size=(-1,-1), descale_q=descale_q, descale_k=descale_k, descale_v=descale_v, diff --git a/hopper/benchmark_split_kv.py b/hopper/benchmark_split_kv.py new file mode 100644 index 000000000..d3d83590a --- /dev/null +++ b/hopper/benchmark_split_kv.py @@ -0,0 +1,325 @@ +import torch +import flash_attn +import flash_attn_interface +import itertools +import time +import math + +import torch.utils.benchmark as benchmark + +def round_up_to_power_of_2(x): + if x <= 1: + return 1 + return 1 << (x - 1).bit_length() + +def timeit(fn, *args, **kwargs): + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + fn(*args, **kwargs) + + # Benchmark using PyTorch Timer + t = benchmark.Timer( + stmt='fn(*args, **kwargs)', + globals={'fn': fn, 'args': args, 'kwargs': kwargs} + ) + + # Measure execution time + measurement = t.timeit(20) # Runs the function 20 times + # measurement = t.blocked_autorange(min_run_time=1) + avg_time = measurement.mean # Average time in seconds + + return avg_time + +def main(): + num_sms = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + + max_splits = 129 + check_all_splits = False + + causal = True + # causal = False + # dtype=torch.float16 + dtype=torch.bfloat16 + + torch.manual_seed(42) + + model_configs = [ + # ("Gemma-2-2B", 8, 4, 256), + # ("Gemma-2-9B", 16, 8, 256), + # ("Gemma-2-27B", 32, 16, 128), + # ("Qwen-2.5-0.5B", 14, 2, 64), + # ("Qwen-2.5-1.5B", 12, 2, 128), + # ("Qwen-2.5-7B", 28, 4, 128), + # ("Llama-3.1-8B", 32, 8, 128), + ("Llama-3.1-70B", 64, 8, 128), + # ("Llama-3.1-405B", 128, 8, 128), + # ("Llama-3.2-1B", 32, 8, 64), + # ("Llama-3.2-3B", 24, 8, 128), + # ("Nemotron-4-15B", 48, 8, 128), + ] + + all_batch_configs = [] + + all_batch_configs.extend(itertools.product( + # [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen + [4096, 16384, 65536], # context_seqlen + # [131072], # context_seqlen + # [i for i in range(1, (num_sms) + 1)], # num_requests + [1, 4, 8, 16], # num_requests + # [1], # num_requests + [1, 4, 8, 16], # query_seqlen + # [1], # query_seqlen + )) + + num_caches = max(reqs for _, reqs, _ in all_batch_configs) + cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs) + + for model_name, nheads_q, nheads_kv, headdim in model_configs: + k_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype + ) + v_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype + ) + print(f"***{model_name}***") + print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}") + + if check_all_splits is False: + print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}") + + for context_seqlen, num_requests, query_seqlen in all_batch_configs: + bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4) + bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4) + blockH = 
round_up_to_power_of_2(nheads_q//nheads_kv) + blockM = 128 # true for hdim 128 causal and hdim 64 + blockM_div_H = blockM//blockH + num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H) + + q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype) + cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] + cache_seqlens = torch.tensor( + [context_seqlen] * num_requests, dtype=torch.int32, device="cuda" + ) + + fa2_time_heuristic = timeit( + flash_attn.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + ) * 1000. * 1000. + # fastest_splitk_time = float("inf") + # fastest_splitk = 0 + # for i in range(1, max_splits): + # t = timeit( + # flash_attn.flash_attn_with_kvcache, + # q=q, + # k_cache=k_cache, + # v_cache=v_cache, + # cache_seqlens=cache_seqlens, + # cache_batch_idx=cache_idxs, + # causal=causal, + # num_splits=i, + # ) * 1000. * 1000. + # if t < fastest_splitk_time: + # fastest_splitk_time = t + # fastest_splitk = i + + fa3_time_one_split = timeit( + flash_attn_interface.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=False, + num_splits=1, + ) * 1000. * 1000. + + fa3_time_gqa_heuristic = timeit( + flash_attn_interface.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=True, + num_splits=0, + max_seqlen_k_hint=context_seqlen + ) * 1000. * 1000. + + if check_all_splits: + + fa3_fastest_num_splits = 0 + fa3_fastest_splitk_time = float("inf") + + for num_splits in range(1, max_splits): + t = timeit( + flash_attn_interface.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=False, + num_splits=num_splits + ) * 1000. * 1000. + + out0 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=False, + num_splits=num_splits + ) + + out1 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=False, + num_splits=1 + ) + + max_diff = (out0 - out1).abs().max().item() + mean_diff = (out0 - out1).abs().mean().item() + # print (f"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}") + # print (f"splits {num_splits}, time {t:.2f}") + + if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4: + print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}") + + if t < fa3_fastest_splitk_time: + fa3_fastest_splitk_time = t + fa3_fastest_num_splits = num_splits + + fa3_fastest_num_splits_gqa = 0 + fa3_fastest_splitk_time_gqa = float("inf") + for num_splits in range(1, max_splits): + + t = timeit( + flash_attn_interface.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=True, + num_splits=num_splits + ) * 1000. * 1000. 
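+                # Validate the split-KV path: out0 (this split count) is compared against the
+                # num_splits=1 reference out1 below to bound accumulation error for gqa_parallel=True.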
+ + out0 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=True, + num_splits=num_splits + ) + + out1 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=True, + num_splits=1 + ) + + max_diff = (out0 - out1).abs().max().item() + mean_diff = (out0 - out1).abs().mean().item() + # print (f"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}") + # print (f"gqa splits {num_splits}, time {t:.2f}") + + if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4: + print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}") + + if t < fa3_fastest_splitk_time_gqa: + fa3_fastest_splitk_time_gqa = t + fa3_fastest_num_splits_gqa = num_splits + + efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms + heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa + # remeasure to smooth anomalies + if heuristic_ratio > 1.1: + + fa3_time_gqa_heuristic = timeit( + flash_attn_interface.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=True, + # num_splits=num_splits_select, + # num_splits=1, + num_splits=0, + max_seqlen_k_hint=context_seqlen + ) * 1000. * 1000. + + fa3_fastest_splitk_time_gqa = timeit( + flash_attn_interface.flash_attn_with_kvcache, + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + gqa_parallel=True, + num_splits=fa3_fastest_num_splits_gqa + ) * 1000. * 1000. + + if check_all_splits is True: + print( + f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, " + f"FA2:{fa2_time_heuristic:.2f}, " + # f"FA2 MANUAL:{fastest_splitk_time:.2f}, " + # f"FA2 NUM SPLITS:{fastest_splitk}, " + # f"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, " + # f"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, " + # f"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, " + f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, " + f"FA3:{fa3_time_gqa_heuristic:.2f}, " + # f"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, " + # f"FA2 NUM SPLITS:{fastest_splitk}, " + # f"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, " + f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, " + # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, " + f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, " + f"EFF:{efficiency:.2f}, " + f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}" + ) + + if check_all_splits is False: + print( + f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}" + f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}" + f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}" + f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}" + ) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/hopper/block_info.h b/hopper/block_info.h deleted file mode 100644 index 3a23a1e1f..000000000 --- a/hopper/block_info.h +++ /dev/null @@ -1,46 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockInfo { - - template - __device__ BlockInfo(const Params ¶ms, const int bidb) - : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) - , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. - // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) - { - } - - template - __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - } - - template - __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; - } - - const int sum_s_q; - const int sum_s_k; - const int actual_seqlen_q; - // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. - const int seqlen_k_cache; - const int actual_seqlen_k; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/hopper/combine.h b/hopper/combine.h new file mode 100644 index 000000000..a3317631b --- /dev/null +++ b/hopper/combine.h @@ -0,0 +1,248 @@ + +#pragma once + +#include + +#include +#include "cutlass/layout/layout.h" +#include +#include + +#include "kernel_traits.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SharedStorageLSE { + cute::array_aligned> smem_lse; + cute::array_aligned> smem_valid_splits; +}; + +// DONT use Kernel_traits here to avoid redundant compilation. +// template +template +__global__ void combine_attn_seqk_parallel(Params const params) { + // using Element = typename Kernel_traits::OutputType; + // using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = int64_t; // Kernel_traits::index_t + constexpr int kMaxSplits = 1 << Log_max_splits; + // constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = 128; //Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. 
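+    // sLSE caches the per-split LSE values for this block of kBlockM rows; sValidSplits marks
+    // splits whose LSE is finite for at least one row, so Oaccum loads for empty splits are skipped.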
+ //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1]; + extern __shared__ char smem_[]; + using SharedStorage = SharedStorageLSE, Int>, Shape>>; + SharedStorage &shared_storage = + *reinterpret_cast(smem_); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape, Int>{}); + Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape>{}); + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t lse_size = params.b * params.h * params.seqlen_q; + //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(lse_size, _1{})); + + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then transpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE(row,col) = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } + } + __syncthreads(); + + // Reduce along the kBlockM dimension to determine valid splits (store in SMEM) + // One thread per split. Know NumThreads = 128 >= NumMaxSplits + if (tidx < kMaxSplits) { + bool is_valid_split = false; + #pragma unroll + for (int col = 0; col < kBlockM; ++col) { + if(sLSE(tidx,col) != -INFINITY) { + is_valid_split = true; + } + } + sValidSplits(tidx) = is_valid_split; + } + __syncthreads(); + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. 
If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY; + + } + //return; + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } + } + //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum); + + // Store the scales exp(lse - lse_logsum) in shared memory. 
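+    // With lse_logsum = log(sum_s exp(lse_s)) computed above, the combined output is
+    // O = sum_s exp(lse_s - lse_logsum) * Oaccum_s; the per-split scales are staged in sLSE
+    // and applied when accumulating Oaccum below.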
+ #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + //if (cute::thread0()) print_tensor (cOaccum); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + // DONT copy in Oaccum if lse(split) = -inf for all kBlockM. + if(sValidSplits(split)) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(split,row); + if (lse_scale != 0.f) { + #pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { + #pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + //tOrO(i, m, k) += tOrOaccum(i, m, k); + } + } + } + //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); } + } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + //if (cute::thread0()) { print_tensor(tOrO); } + + Tensor rO = flash::convert_type(tOrO); + // Write to gO + #pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); + if (idx < params.b * params.h * params.seqlen_q) { + //print ("final2\n"); + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; + #pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = 
make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + //if (cute::thread0()) { print ("final\n"); print_tensor(gO); } + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp index 993f2e239..26664c104 100644 --- a/hopper/epilogue_fwd_sm90_tma.hpp +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -20,24 +20,60 @@ using namespace cute; template struct CollectiveEpilogueFwd { + using InputType = typename Ktraits::Element; using Element = typename Ktraits::OutputType; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kBlockH = Ktraits::kBlockH; static constexpr int kHeadDim = Ktraits::kHeadDim; - using TileShape_MNK = Shape, Int, Int>; + using TileShape_MNK = Shape, Int, Int>; static constexpr int kNWarps = Ktraits::kNWarps; static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; - static constexpr bool Is_WS = kNWarps >= 12; + static constexpr bool Is_WS = Ktraits::Is_WS; static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; static constexpr int NumMmaThreads = kNThreads - NumCopyThreads; + static constexpr bool Is_split = Ktraits::Is_split; + static constexpr bool No_smem_O = Ktraits::No_smem_O; + +#ifndef NO_FP8_COLUMN_PERMUTE + static constexpr bool epi_column_permute = is_same_v; +#else + static constexpr bool epi_column_permute = false; +#endif + + using GmemShapeOT = std::conditional_t< + Is_split, + typename Seqlen_traits::ShapeOAccumT, + typename Seqlen_traits::ShapeT + >; + using GmemStrideOT = std::conditional_t< + Is_split, + typename Seqlen_traits::StrideOAccumT, + typename Seqlen_traits::StrideT + >; + using GmemLayoutOT = std::conditional_t< + Is_split, + typename Seqlen_traits::LayoutOAccumT, + typename Seqlen_traits::LayoutT + >; + + using GmemLayoutLseT = std::conditional_t< + Is_split, + typename Seqlen_traits::LayoutLseAccumT, + typename Seqlen_traits::LayoutLseT + >; + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOCopy = typename Ktraits::SmemLayoutOCopy; + using TileShapeOCopy = typename Ktraits::TileShapeOCopy; - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = std::conditional_t, Element>, Copy_Atom>; using SharedStorage = cute::array_aligned>; using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; @@ -45,11 +81,11 @@ struct CollectiveEpilogueFwd { GmemTiledCopyOTMA{}, make_tensor( make_gmem_ptr(static_cast(nullptr)), - typename Seqlen_traits::ShapeT{}, - typename Seqlen_traits::StrideT{} + GmemShapeOT{}, + GmemStrideOT{} ), - SmemLayoutO{}, - select<0, 2>(TileShape_MNK{}), + SmemLayoutOCopy{}, + TileShapeOCopy{}, _1{})); // no mcast for O // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len) @@ -76,7 +112,7 @@ struct CollectiveEpilogueFwd { Stride<_4, _32, _1, _0>>; using ValueLayoutrO = Layout, Int>, Stride<_0, _2, Stride<_4, _1>, _8>>; - using TiledCopyrO = 
decltype(make_tiled_copy(Copy_Atom, Element>{}, + using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom, Element>{}, ThreadLayoutrO{}, ValueLayoutrO{})); using TiledCopyShaperO = Shape<_8, Int, _16, Int>; using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout{})); @@ -84,17 +120,17 @@ struct CollectiveEpilogueFwd { // Host side kernel arguments struct Arguments { Element* ptr_O; - typename Seqlen_traits::LayoutT const layout_O; + GmemLayoutOT const layout_O; float* ptr_LSE; - typename Seqlen_traits::LayoutLseT const layout_LSE; + GmemLayoutLseT const layout_LSE; }; // Device side kernel params struct Params { Element* ptr_O; - typename Seqlen_traits::LayoutT const layout_O; + GmemLayoutOT const layout_O; float* ptr_LSE; - typename Seqlen_traits::LayoutLseT const layout_LSE; + GmemLayoutLseT const layout_LSE; TMA_O tma_store_O; }; @@ -104,8 +140,8 @@ struct CollectiveEpilogueFwd { TMA_O tma_store_O = make_tma_copy( GmemTiledCopyOTMA{}, mO, - SmemLayoutO{}, - select<0, 2>(TileShape_MNK{}), + SmemLayoutOCopy{}, + TileShapeOCopy{}, _1{}); // no mcast for O return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O}; } @@ -113,7 +149,7 @@ struct CollectiveEpilogueFwd { /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& epilogue_params) { - if constexpr (!Seqlen_traits::kUseVarSeqLen) { + if constexpr (!Seqlen_traits::UseVarSeqLen && !No_smem_O) { cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); } } @@ -126,169 +162,254 @@ struct CollectiveEpilogueFwd { SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, - cute::tuple const& block_coord, - const Seqlen_traits& seqlen_traits_q + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q, + const cutlass::FastDivmod& qhead_per_khead_divmod ) { - auto [m_block, bidh, bidb] = block_coord; - Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); - auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + const int bidh_kv = qhead_per_khead_divmod.divide(bidh); + const int h_block = bidh % int(qhead_per_khead_divmod); Tensor tOrO_out = flash::convert_type(tOrO); - Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // Make sure all WGs have finished reading V - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if constexpr(!No_smem_O) { + if constexpr (!epi_column_permute) { + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading V + 
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } else { + TiledCopyrO rmem_tiled_copy_O; + Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{}); + auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc); + Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO)); + + // Make sure all WGs have finished reading V + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); + cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + } Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); - Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( - mLSE, Shape>{}, bidh, bidb)(_, m_block); Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. 
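+        // For GQA-packed layouts (second branch below), each row of the kBlockM tile encodes
+        // (query position, head-within-group) with the head index varying fastest (row % kBlockH).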
Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{}); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(_0{})) == 0) { + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // 2 * MMA_M + + if constexpr(!Seqlen_traits::UseGQAPacking) { + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, bidh, bidb, n_split_idx)(_, m_block); + if (get<1>(taccOcO_row(_0{})) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + } else { + // shape<1>(epilogue_params.layout_O) == h/h_k + // In common case where ceil_div(h/h_k, kBlockH) == 1, + // int(qhead_per_khead_divmod) == 1, bidh_kv == bidh, h_block == 0 + const int h_offset = shape<1>(epilogue_params.layout_O) * bidh_kv + + h_block * kBlockH; + const int m_bound = seqlen_traits_q.actual_seq_len - m_block * (kBlockM/kBlockH); + const int h_bound = shape<1>(epilogue_params.layout_O) - h_block * kBlockH; #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); } + const int h_local = row % kBlockH; + const int m_local = row/kBlockH; + if(h_local < h_bound && m_local < m_bound) { + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(mLSE, + Shape>{}, h_offset + h_local, bidb, n_split_idx) + (_, m_block); + gLSE(m_local) = lse(mi); + } } } - - int write_warp_idx = kNWarps - 1; - if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { - cutlass::arch::NamedBarrier::sync( - NumMmaThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier - ); + + if constexpr (No_smem_O) { + flash::write_rmem_to_gmem( + tOrO_out, epilogue_params.ptr_O, epilogue_params.layout_O, TileShapeOCopy{}, + m_block, h_block, bidh, bidh_kv, bidb, n_split_idx, + tiled_mma, seqlen_traits_q, thread_idx); + } else { + int write_warp_idx = kNWarps - 1; + if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { + cutlass::arch::NamedBarrier::sync( + NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ); + } + TiledCopyO gmem_tiled_copy_O; + Tensor sO_out = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutOCopy{}); + if constexpr(!Seqlen_traits::UseGQAPacking) { + flash::write_O( + epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, + epilogue_params.layout_O, TileShapeOCopy{}, sO_out, + m_block, bidh, bidb, n_split_idx, seqlen_traits_q, write_warp_idx, tiled_mma, tOrO_out + ); + } else { + Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.layout_O.shape()); + Tensor gO = seqlen_traits_q.get_o_local_tile_tensor( + mO, TileShapeOCopy{}, bidh_kv, bidb, n_split_idx) + (_, _, _, m_block, h_block); // (bM/bH, bH, K) + auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO_out); // (TMA, TMA_M, TMA_K) + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == write_warp_idx && lane_predicate) { + cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + } + } } - TiledCopyO gmem_tiled_copy_O; - flash::write_O( - epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, - 
epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, - m_block, bidh, bidb, seqlen_traits_q, write_warp_idx - ); } - template CUTLASS_DEVICE void - store_fp8(Params const& epilogue_params, - FrgTensorO const& tOrO, - FrgTensorLSE const& lse, + store_tail() { + if constexpr(!No_smem_O) { tma_store_wait<0>(); } + } + + // Write 0 to output and -inf to LSE + template + CUTLASS_DEVICE void + store_zero( + Params const& epilogue_params, SharedStorage& shared_storage, - TiledMma tiled_mma, int thread_idx, - cute::tuple const& block_coord, + cute::tuple const& block_coord, const Seqlen_traits& seqlen_traits_q ) { - // using SmemLayoutrO = typename Ktraits::SmemLayoutrO; - // using TiledCopyrO = typename Ktraits::TiledCopyrO; - auto [m_block, bidh, bidb] = block_coord; - - TiledCopyrO rmem_tiled_copy_O; - Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{}); - auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx); - - Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc); - Tensor tOrO_out = flash::convert_type(tOrO); // Element is Ktraits::OutputType - Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO)); - - // Make sure all WGs have finished reading V - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); - cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO); - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - - Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); - Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( - mLSE, Shape>{}, bidh, bidb)(_, m_block); - Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(size<0, 0>(taccOcO))::value == 2); - static_assert(decltype(size<0, 1>(taccOcO))::value == 2); - // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. 
- Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{}); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - int const seqlen_q = [&] { - if constexpr(Seqlen_traits::kUseVarSeqLen) { return seqlen_traits_q.actual_seq_len; } - else { return shape<2>(epilogue_params.layout_LSE); } - }(); - if (get<1>(taccOcO_row(_0{})) == 0) { + static_assert(!Seqlen_traits::UseGQAPacking, "Don't call store_zero for gqa packed layouts."); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + + if constexpr(!Is_split) { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O); + Tensor gO = seqlen_traits_q.get_o_local_tile_tensor( + mO, select<0, 2>(TileShape_MNK{}), bidh, bidb, n_split_idx + )(_, _, m_block); // (M, K) + + TiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } - } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM + ); } - int write_warp_idx = kNWarps - 1; - if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, bidh, bidb, n_split_idx)(_, m_block); + static_assert(kBlockM <= NumMmaThreads); + if (thread_idx < min(kBlockM, seqlen_traits_q.actual_seq_len - m_block * kBlockM)) { + gLSE(thread_idx) = !Is_split ? 
INFINITY : -INFINITY; } - TiledCopyO gmem_tiled_copy_O; - Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); - flash::write_O( - epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, - epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, - m_block, bidh, bidb, seqlen_traits_q, write_warp_idx - ); - } - - CUTLASS_DEVICE void - store_tail() { - tma_store_wait<0>(); } // Write 0 to output and -inf to LSE template CUTLASS_DEVICE void - store_zero( + store_zero_gqa( Params const& epilogue_params, SharedStorage& shared_storage, int thread_idx, - cute::tuple const& block_coord, - const Seqlen_traits& seqlen_traits_q + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q, + const cutlass::FastDivmod& qhead_per_khead_divmod ) { - auto [m_block, bidh, bidb] = block_coord; - Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O); - Tensor gO = seqlen_traits_q.get_local_tile_tensor( - mO, select<0, 2>(TileShape_MNK{}), bidh, bidb - )(_, _, m_block); // (M, K) + static_assert(Seqlen_traits::UseGQAPacking, "Special store_zero method for GQA packed layouts."); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + const int bidh_kv = qhead_per_khead_divmod.divide(bidh); + const int h_block = bidh % int(qhead_per_khead_divmod); + const int h_bound = min(shape<1>(epilogue_params.layout_O) - h_block * kBlockH, kBlockH); + const int m_bound = min(seqlen_traits_q.actual_seq_len - m_block * (kBlockM/kBlockH), kBlockM/kBlockH); + + if constexpr(!Is_split) { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O); + Tensor gO = seqlen_traits_q.get_o_local_tile_tensor( + mO, TileShapeOCopy{}, bidh_kv, bidb, n_split_idx) + (_, _, _, m_block, h_block); // (bM/bH, bH, K) + TiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + if constexpr(kNumRows <= kBlockH) { + // slice into bM/bH and write out zero tiles (bH, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO(0,_,_)); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + Tensor cO = cute::make_identity_tensor(select<1, 2>(TileShapeOCopy{})); + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + // dummy predicate, unused since Is_even_K=true + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for(int m = 0; m < m_bound; ++m) { + tOgO = gmem_thr_copy_O.partition_D(gO(m,_,_)); + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, h_bound + ); + } + } else { + // slice into bH and write out zero tiles (bM/bH, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO(_,0,_)); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShapeOCopy{})); + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + // dummy predicate, unused since Is_even_K=true + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for(int h = 0; h < h_bound; ++h) { + tOgO = gmem_thr_copy_O.partition_D(gO(_,h,_)); + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, m_bound + ); + } + } + } + + const int h_offset = shape<1>(epilogue_params.layout_O) * bidh_kv + h_block * kBlockH; + const int thread_idx_h = thread_idx % kBlockH; + const int thread_idx_m = thread_idx / kBlockH; + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); - Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( - mLSE, Shape>{}, bidh, bidb)(_, m_block); - - TiledCopyO 
gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_fragment_like(tOgO); - clear(tOrO); - // Construct identity layout for sO - Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM - ); - static_assert(kBlockM <= NumMmaThreads); - if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; } + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, h_offset + thread_idx_h, bidb, n_split_idx)(_, m_block); + if(thread_idx_h < h_bound && thread_idx_m < m_bound) { + gLSE(thread_idx_m) = !Is_split ? INFINITY : -INFINITY; + } } }; diff --git a/hopper/flash.h b/hopper/flash.h index 24ca27f69..bc1790dc2 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -49,6 +49,12 @@ struct Flash_fwd_params : public Qkv_params { index_t o_row_stride; index_t o_head_stride; + // The stride between rows of Oaccum. + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + index_t oaccum_split_stride; + // The pointer to the P matrix. void * __restrict__ p_ptr; @@ -58,6 +64,7 @@ struct Flash_fwd_params : public Qkv_params { // The dimensions. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q, total_k; + int b_k; // The scaling factors for the kernel. float scale_softmax; @@ -119,10 +126,8 @@ struct Flash_fwd_params : public Qkv_params { bool is_e4m3; bool is_causal; bool is_local; - - // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. - // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - bool is_seqlens_k_cumulative; + bool is_kv_cache; + bool use_gqa_packing; bool is_rotary_interleaved; @@ -132,6 +137,7 @@ struct Flash_fwd_params : public Qkv_params { index_t alibi_slopes_batch_stride; bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). 
int * __restrict__ tile_count_semaphore; float * __restrict__ descale_q_ptr; @@ -187,4 +193,5 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 638752e4d..a4ffa254f 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -21,6 +21,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, + const size_t b_k, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, @@ -52,6 +53,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.is_bf16 = q.dtype() == torch::kBFloat16; params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + params.is_kv_cache = false; // Set the pointers and strides. params.q_ptr = q.data_ptr(); @@ -97,6 +99,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Set the dimensions. params.b = b; + params.b_k = b_k; params.h = h; params.h_k = h_k; params.h_h_k_ratio = h / h_k; @@ -137,8 +140,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.window_size_left = window_size_left; params.window_size_right = window_size_right; - params.is_causal = window_size_left == seqlen_k && window_size_right == 0; - if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) { + params.is_causal = window_size_left == int(seqlen_k) && window_size_right == 0; + if ((window_size_left < int(seqlen_k) || window_size_right < int(seqlen_k)) && !params.is_causal) { params.is_local = true; } @@ -147,13 +150,12 @@ void set_params_fprop(Flash_fwd_params ¶ms, "This flash attention build does not support local attention."); #endif - params.is_seqlens_k_cumulative = true; - #ifdef FLASHATTENTION_DISABLE_UNEVEN_K TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -192,7 +194,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, bool deterministic) { set_params_fprop(params, - b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + b, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, @@ -236,11 +238,198 @@ void set_params_dgrad(Flash_bwd_params ¶ms, params.deterministic = deterministic; } -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { - // HEADDIM_SWITCH(params.d, [&] { - // run_mha_fwd_(params, stream); - // }); - if (!params.is_e4m3) { + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 80% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int batch_nheads, int num_SMs, int num_n_blocks, + int max_splits, int head_size, bool use_one_mma_wg) { + // Goal of the starting threshold is to determine whether to split or not. 
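+    // "Efficiency" of a candidate split count s is n_waves / ceil(n_waves), where
+    // n_waves = batch_nheads_mblocks * s / num_SMs; e.g. 48 tiles on 108 SMs gives
+    // 96/108 = 0.89 for 2 splits but 144/216 = 0.67 for 3 splits.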
+ // Empirically, the efficiency threshold can be much lower than 80% depending on num_n_blocks. + int num_m_blocks = batch_nheads_mblocks/batch_nheads; + float start_threshold; + float num_n_blocksf = float(num_n_blocks); + if (head_size == 128) { + if (std::log2f(num_n_blocksf) <= 4) { // 2048 -- .25 + start_threshold = .20f + (std::log2f(num_n_blocksf) - 3) * .05f; + } else if (std::log2f(num_n_blocksf) <= 5) { // 4096 -- .25 + start_threshold = .25f; + } else if (std::log2f(num_n_blocksf) <= 6) { // 8192 -- .36 + start_threshold = .28f + (std::log2f(num_n_blocksf) - 5) * .08f; + } else if (std::log2f(num_n_blocksf) <= 7) { // 16K -- .42 + start_threshold = .36f + (std::log2f(num_n_blocksf) - 6) * .06f; + } else { + // Just split freely + start_threshold = .8f; + } + if (num_m_blocks > 1 && start_threshold < .5f) + start_threshold += .05f * (std::log2f(num_n_blocksf) - 2); + } else if (head_size == 256) { + // TODO for hdim 256 + if (num_n_blocks <= 40) { + start_threshold = .24f; + } else if (std::log2f(num_n_blocksf) <= 8) { + start_threshold = .33f + std::max(0.f, (std::log2f(num_n_blocksf) - std::log2f(50)) * 0.02971f); + } else { + // Just split freely + start_threshold = .8f; + } + } else if (head_size == 64) { + if (use_one_mma_wg) { + if (std::log2f(num_n_blocksf) <= 4) { // 2K -- .33 + start_threshold = .33f; + } else if (std::log2f(num_n_blocksf) <= 5) { // 4K -- .37 + start_threshold = .33f + (std::log2f(num_n_blocksf) - 4) * .04f; + } else if (std::log2f(num_n_blocksf) <= 6) { // 8K -- .40 + start_threshold = .37f + (std::log2f(num_n_blocksf) - 5) * .03f; + } else if (std::log2f(num_n_blocksf) <= 7) { // 16K -- .43 + start_threshold = .4f + (std::log2f(num_n_blocksf) - 6) * .03f; + } else if (std::log2f(num_n_blocksf) <= 8) { // 32K -- .46 + start_threshold = .43f + (std::log2f(num_n_blocksf) - 7) * .03f; + } else { + start_threshold = .8f; + } + } else { + if (std::log2f(num_n_blocksf) <= 6) { // 8K -- .5 + start_threshold = .5f; + } else { + start_threshold = .8f; + } + } + } else { + // placeholder for other hdims + start_threshold = .8f; + } + + float first_wave = float(batch_nheads_mblocks) / num_SMs; + // printf("Start threshold and wave = %f, %f.\n", start_threshold, first_wave); + // Only use start_threshold if initial work doesn't exceed one wave + if ((first_wave/ceil(first_wave) > start_threshold && first_wave <= 1.f) || + (first_wave/ceil(first_wave) > .8f)) { + return 1; + } + // if (first_wave_batch_nheads > start_threshold) { return 1; } + // if (first_wave_batch_nheads > start_threshold || first_wave > .8f) { return 1; } + // if (float(batch_nheads)/num_SMs > start_threshold) { return 1; } + + // If num_n_blocks is too small, use 1 split + // For example, we never split for hdim = 128 and seqlen_k = 512, + // or for hdim = 128, seqlen_k = 1024, and one MMA warpgroup. + if (num_n_blocks < 8 || (use_one_mma_wg && num_n_blocks < 10)) { return 1; } + + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + + // NOTE: disable split eligibility check for FA3 since we have dynamic tile scheduler + // for exiting splits with no work early, and check leads to efficiency quantization issues. + // Comment from FA2: + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). 
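To make the occupancy arithmetic described in the comment above concrete, here is a small Python sketch of the wave-efficiency selection: pick the smallest split count whose efficiency n_waves / ceil(n_waves) reaches a fraction of the best achievable. The start-threshold tuning, the num_m_blocks correction, and the minimum-block cutoffs of the real heuristic are deliberately left out.

import math

def pick_num_splits(batch_nheads_mblocks, num_sms, max_splits=128, frac=0.8):
    # Efficiency of s splits = fraction of the final wave of CTAs doing useful work.
    def efficiency(s):
        n_waves = batch_nheads_mblocks * s / num_sms
        return n_waves / math.ceil(n_waves)
    best = max(efficiency(s) for s in range(1, max_splits + 1))
    # Smallest split count within `frac` of the best efficiency.
    for s in range(1, max_splits + 1):
        if efficiency(s) >= frac * best:
            return s
    return 1

# batch * n_heads = 48 tiles on 108 SMs: 2 splits -> 96/108 = 0.89, 3 splits -> 144/216 = 0.67,
# so 2 splits is chosen, matching the example in the comment above.
print(pick_num_splits(48, 108))  # 2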
+ // So we check if the number of blocks per split is the same as the previous num_splits. + // auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + // return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + // }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + // if (!is_split_eligible(num_splits)) { + // efficiency.push_back(0.f); + // } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, n_waves = %f, ceil(n_waves) = %f, eff = %f\n", num_splits, n_waves, ceil(n_waves), eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + // } + } + // Correct for excessive splitting with e.g. 1 bsz*nheads*mblocks + // Empirically, efficiency threshold in these cases is about 40% for 64K seqlen_k + float threshold = num_m_blocks == 1 ? std::min(0.3f + batch_nheads * 0.1f, 0.8f) : 0.8f; + threshold = threshold * max_efficiency; + // printf("Max efficiency = %f. Threshold = %f.\n", max_efficiency, threshold); + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + // if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] > threshold) { + // printf("num_splits chosen = %d, threshold = %f, efficiency = %f.\n", num_splits, threshold, efficiency[num_splits - 1]); + return num_splits; + } + } + return 1; +} + +std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, + const int num_heads, const int num_heads_k, const int head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const float p_dropout, + const int num_splits, cudaDeviceProp *dprops, bool use_gqa_packing, bool is_causal, struct c10::TensorOptions opts) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + + params.num_splits = num_splits; + at::Tensor softmax_lse_accum; + at::Tensor out_accum; + + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + const int gqa_ratio = num_heads / num_heads_k; + const int block_h = 1 << static_cast(std::ceil(std::log2(std::clamp(gqa_ratio, 1, 32)))); + const int block_m = head_size == 64 ? 192 : 128; + const bool use_one_mma_wg = max_seqlen_q <= 64/block_h; + + int block_n = 128; + if (head_size == 128 && !is_causal) { + block_n = 176; + } else if (head_size == 256) { + block_n = use_one_mma_wg ? 96 : 80; + } + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + const int batch_nheads = use_gqa_packing ? batch_size * num_heads_k : batch_size * num_heads; + const int batch_nheads_mblocks = use_gqa_packing + ? 
ceildiv(max_seqlen_q, block_m / block_h) * batch_nheads + : ceildiv(max_seqlen_q, block_m) * batch_nheads; + params.num_splits = num_splits_heuristic(batch_nheads_mblocks, batch_nheads, + dprops->multiProcessorCount, num_n_blocks, 128, head_size, use_one_mma_wg); + // printf("Num splits heuristic = %d.\n", params.num_splits); + } + if (params.num_splits > 1) { + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + params.oaccum_batch_stride = out_accum.stride(-4); + params.oaccum_split_stride = out_accum.stride(0); + } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + } + + return std::make_tuple(softmax_lse_accum, out_accum); +} + + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + + int dtype = 1; + if (params.is_bf16) { dtype = 2; } + else if (params.is_e4m3) { dtype = 3; } + PREC_SWITCH(dtype, Element, [&] { + HEADDIM_SWITCH(params.d, kHeadSize, [&] { + if(!params.use_gqa_packing) { + run_mha_fwd_(params, stream); + } else { + QUERYHEAD_SWITCH(params.h_h_k_ratio, kBlockH, [&] { + run_mha_fwd_gqa_(params, stream); + }); + } + }); + }); + +#if 0 + if (!params.is_e4m3) { if (params.is_bf16) { if (params.d == 64) { run_mha_fwd_(params, stream); @@ -265,8 +454,9 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split run_mha_fwd_(params, stream); } else if (params.d == 256) { run_mha_fwd_(params, stream); - } + } } +#endif } std::vector @@ -280,19 +470,17 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size c10::optional &descale_v_, // 1 bool is_causal, int window_size_left, - int window_size_right) { + int window_size_right, + bool use_gqa_packing = false + ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); + TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs or newer."); auto q_dtype = q.dtype(); - // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - // "FlashAttention only support fp16 and bf16 data type for now"); - // TODO: will add e4m3 later - // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, - // "FlashAttention only support fp16 and bf16 data type"); - // "FlashAttention only support fp16 and fp8 (e4m3) data type for now"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention-3 only support fp16, bf16, or fp8 e4m3 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -313,6 +501,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + // Guard against mistaken setting of gqa flag 
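set_params_splitkv above sizes the packed query-head block (block_h) from the GQA ratio, and the guard that follows simply turns GQA packing off when there is no grouping (num_heads == num_heads_k). A few lines of Python illustrate the rounding rule; the example head counts are hypothetical.

import math

def block_h(num_heads, num_heads_k):
    # GQA ratio, clamped to [1, 32] and rounded up to the next power of two,
    # gives the number of packed query heads handled per kv-head tile.
    gqa_ratio = num_heads // num_heads_k
    return 1 << math.ceil(math.log2(min(max(gqa_ratio, 1), 32)))

assert block_h(32, 8) == 4   # ratio 4 is already a power of two
assert block_h(28, 4) == 8   # ratio 7 rounds up to 8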
+ if (num_heads == num_heads_k) { use_gqa_packing = false; } TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now"); @@ -336,7 +526,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size out = out_.value(); // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn - ? (out.dtype() == at::kHalf) + ? (out.dtype() == at::kBFloat16) : (out.dtype() == q_dtype), "Output must have the same dtype as input dtype if dtype is " "not fp8, or fp16 for fp8 input."); @@ -346,7 +536,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { if (q_dtype == at::ScalarType::Float8_e4m3fn) - out = torch::empty_like(q_padded, at::kHalf); + out = torch::empty_like(q_padded, at::kBFloat16); else out = torch::empty_like(q_padded); } @@ -370,7 +560,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size Flash_fwd_params params; set_params_fprop(params, - batch_size, + batch_size, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, @@ -387,26 +577,27 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size /*window_size_left=*/window_size_left, /*window_size_right=*/window_size_right); - auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + auto tile_count_semaphore = is_causal || params.is_local + ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + at::Tensor descale_q, descale_k, descale_v; if(q_dtype == at::ScalarType::Float8_e4m3fn) { - at::Tensor descale_q, descale_k, descale_v; - if (descale_q_.has_value() && descale_k_.has_value() && descale_k_.has_value()) { + if (descale_q_.has_value()) { descale_q = descale_q_.value(); - descale_k = descale_k_.value(); - descale_v = descale_v_.value(); CHECK_DEVICE(descale_q); - CHECK_DEVICE(descale_k); - CHECK_DEVICE(descale_v); CHECK_SHAPE(descale_q, 1); + } else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_k_.has_value()) { + descale_k = descale_k_.value(); + CHECK_DEVICE(descale_k); CHECK_SHAPE(descale_k, 1); + } else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_v_.has_value()) { + descale_v = descale_v_.value(); + CHECK_DEVICE(descale_v); CHECK_SHAPE(descale_v, 1); - } else { - descale_q = torch::ones({1}, opts.dtype(at::kFloat)); - descale_k = torch::ones({1}, opts.dtype(at::kFloat)); - descale_v = torch::ones({1}, opts.dtype(at::kFloat)); - } + } else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); } params.descale_q_ptr = descale_q.data_ptr(); params.descale_k_ptr = descale_k.data_ptr(); params.descale_v_ptr = descale_v.data_ptr(); @@ -415,6 +606,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size params.descale_k_ptr = nullptr; params.descale_v_ptr = nullptr; } + + params.use_gqa_packing = use_gqa_packing; if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -550,7 +743,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s Flash_fwd_params params; set_params_fprop(params, - batch_size, + batch_size, batch_size, max_seqlen_q, max_seqlen_k, seqlen_q_rounded, seqlen_k_rounded, num_heads, 
num_heads_k, @@ -1022,6 +1215,340 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 }; } +std::vector +mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size + c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size + c10::optional &seqlens_k_, // batch_size + c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + c10::optional &cache_batch_idx_, // indices to index into the KV cache + c10::optional &leftpad_k_, // batch_size + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + c10::optional &descale_q_, // 1 + c10::optional &descale_k_, // 1 + c10::optional &descale_v_, // 1 + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + int num_splits, + int max_seqlen_k_hint, + bool use_gqa_packing + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + // bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention-3 only support fp16, bf16, or fp8 e4m3 data type"); + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : kcache.size(0); + const int page_block_size = !paged_KV ? 
1 : kcache.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; + const int num_heads_k = kcache.size(2); + const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + // Guard against mistaken setting of gqa flag + if (num_heads == num_heads_k) { use_gqa_packing = false; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = + seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && + window_size_right < 0 && head_size_og % 8 == 0 && + !alibi_slopes_.has_value() && !use_gqa_packing; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + if (!paged_KV) { + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + } else { + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + at::Tensor q_padded, kcache_padded, vcache_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + kcache_padded = kcache; + vcache_padded = vcache; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn + ? (out.dtype() == at::kBFloat16) + : (out.dtype() == q_dtype), + "Output must have the same dtype as input dtype if dtype is " + "not fp8, or fp16 for fp8 input."); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + if (q_dtype == at::ScalarType::Float8_e4m3fn) { + out = torch::empty_like(q_padded, at::kBFloat16); + } + else + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = head_size <= 192 ? 
round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, batch_size_c, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, kcache_padded, vcache_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right + ); + + at::Tensor descale_q, descale_k, descale_v; + if(q_dtype == at::ScalarType::Float8_e4m3fn) { + if (descale_q_.has_value()) { + descale_q = descale_q_.value(); + CHECK_DEVICE(descale_q); + CHECK_SHAPE(descale_q, 1); + } else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_k_.has_value()) { + descale_k = descale_k_.value(); + CHECK_DEVICE(descale_k); + CHECK_SHAPE(descale_k, 1); + } else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_v_.has_value()) { + descale_v = descale_v_.value(); + CHECK_DEVICE(descale_v); + CHECK_SHAPE(descale_v, 1); + } else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); } + params.descale_q_ptr = descale_q.data_ptr(); + params.descale_k_ptr = descale_k.data_ptr(); + params.descale_v_ptr = descale_v.data_ptr(); + } else { + params.descale_q_ptr = nullptr; + params.descale_k_ptr = nullptr; + params.descale_v_ptr = nullptr; + } + + params.is_kv_cache = true; + + params.use_gqa_packing = use_gqa_packing; + + at::Tensor k, v, k_padded, v_padded; + if (k_.has_value()) { + TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); + TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); + TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); + k = k_.value(); + v = v_.value(); + TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query"); + TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query"); + CHECK_DEVICE(k); CHECK_DEVICE(v); + TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); + int seqlen_knew = k.size(1); + CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og); + if (head_size_og % 8 != 0) { + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + k_padded = k; + v_padded = v; + } + params.seqlen_knew = seqlen_knew; + params.knew_ptr = k_padded.data_ptr(); + params.vnew_ptr = v_padded.data_ptr(); + // All stride are in elements, not bytes. 
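The knew/vnew pointers and strides set in this block drive the in-kernel append of new keys/values into the cache. Purely as a semantic reference (ignoring rotary embedding, paged layouts, and head-dim padding), the effect is equivalent to this PyTorch loop; the helper name is illustrative.

def append_to_kvcache(k_cache, v_cache, k_new, v_new, cache_seqlens):
    # Write the fresh tokens of each batch element at its current cache length,
    # as recorded in cache_seqlens (int32, shape (batch,)).
    seqlen_new = k_new.shape[1]
    for b in range(k_cache.shape[0]):
        s = int(cache_seqlens[b])
        k_cache[b, s:s + seqlen_new] = k_new[b]
        v_cache[b, s:s + seqlen_new] = v_new[b]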
+ params.knew_batch_stride = k_padded.stride(0); + params.vnew_batch_stride = v_padded.stride(0); + params.knew_row_stride = k_padded.stride(-3); + params.vnew_row_stride = v_padded.stride(-3); + params.knew_head_stride = k_padded.stride(-2); + params.vnew_head_stride = v_padded.stride(-2); + } + + if (seqlens_k_.has_value()) { + auto seqlens_k = seqlens_k_.value(); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + CHECK_CONTIGUOUS(seqlens_k); + CHECK_SHAPE(seqlens_k, batch_size); + params.seqused_k = static_cast(seqlens_k.data_ptr()); + } + if (leftpad_k_.has_value()) { + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + TORCH_CHECK(false, "Left Padding K is not supported"); + //params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + if (rotary_cos_.has_value()) { + TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + CHECK_CONTIGUOUS(rotary_cos); + TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + CHECK_CONTIGUOUS(rotary_sin); + TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + } else { + params.rotary_dim = 0; + } + + if (cache_batch_idx_.has_value()) { + auto cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + CHECK_CONTIGUOUS(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); + params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); + } + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, num_heads_k, head_size, max_seqlen_k_hint, seqlen_q, + head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, use_gqa_packing, is_causal, opts); + + auto tile_count_semaphore = is_causal || params.is_local || params.num_splits != 1 + ? 
torch::zeros({1}, opts.dtype(torch::kInt32)) + : torch::empty({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + } + params.page_block_size = page_block_size; + + TORCH_CHECK(!alibi_slopes_.has_value(), "Alibi Slopes are not supported yet"); + //set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, + // or paged KV cache + //run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); + run_mha_fwd(params, stream); + + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + if (k_.has_value()) { + // It's expensive to copy the KV cache here for the case where head size not divisible by 8, + // but we don't expect to get this case in practice. This is just so that the code works for that case. + kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + } + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + + return {out, softmax_lse}; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; @@ -1029,4 +1556,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bwd", &mha_bwd, "Backward pass"); m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass"); + m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache"); } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index db2bd1a11..e1e4ffd30 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -14,7 +14,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None): +def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None, gqa_parallel=False): q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( q, @@ -28,6 +28,7 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = causal, window_size[0], window_size[1], + gqa_parallel ) return out, q, k, v, out_padded, softmax_lse, S_dmask @@ -175,6 +176,7 @@ def forward( descale_q=None, descale_k=None, descale_v=None, + gqa_parallel=False, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -187,13 +189,15 @@ def forward( window_size, descale_q=descale_q, descale_k=descale_k, - descale_v=descale_v, + descale_v=descale_v, + gqa_parallel=gqa_parallel, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.deterministic = deterministic + ctx.gqa_parallel = gqa_parallel return out, softmax_lse 
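With the plumbing above, the new gqa_parallel flag is threaded from the public API down to the kernel dispatch. A hypothetical call is shown below; tensor shapes are made up for illustration, the import path depends on how the hopper extension is installed, and in this version of the interface flash_attn_func returns both the output and the LSE.

import torch
from flash_attn_interface import flash_attn_func  # assumed install path

q = torch.randn(2, 1024, 32, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# gqa_parallel=True routes to the GQA-packed kernels, which parallelize over kv heads.
out, lse = flash_attn_func(q, k, v, causal=True, gqa_parallel=True)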
@staticmethod @@ -218,7 +222,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -309,6 +313,7 @@ def flash_attn_func( descale_q=None, descale_k=None, descale_v=None, + gqa_parallel=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -372,6 +377,7 @@ def flash_attn_func( descale_q, descale_k, descale_v, + gqa_parallel ) @@ -445,3 +451,180 @@ def flash_attn_varlen_func( seqused_q, seqused_k, ) + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + # k=None, + # v=None, + # rotary_cos=None, + # rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + # cache_leftpad: Optional[torch.Tensor] = None, + # block_table: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + # softcap=0.0, # 0.0 means deactivated + # rotary_interleaved=True, + # alibi_slopes=None, + num_splits=0, + return_softmax_lse=False, + gqa_parallel=None, + max_seqlen_k_hint=None, + descale_q=None, + descale_k=None, + descale_v=None, +): + """ + NOTE: The KV cache API for FlashAttention-3 is a work in progress. We reproduce the description + from the FlashAttention-2 method of the same name below. + + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. 
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
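The mask conventions spelled out above (bottom-right-aligned causal mask, inclusive sliding window) can be reproduced with a few lines of PyTorch. This is only a reference construction of the boolean mask for checking semantics; the kernel never materializes it.

import torch

def bottom_right_causal_local_mask(seqlen_q, seqlen_k, window_size=(-1, -1), causal=True):
    # 1 = keep, 0 = masked out; row i of q is treated as key position i + seqlen_k - seqlen_q.
    i = torch.arange(seqlen_q)[:, None]
    j = torch.arange(seqlen_k)[None, :]
    shift = seqlen_k - seqlen_q
    keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    if causal:
        keep &= j <= i + shift
    if window_size[0] >= 0:
        keep &= j >= i + shift - window_size[0]
    if window_size[1] >= 0:
        keep &= j <= i + shift + window_size[1]
    return keep.int()

# Reproduces the examples in the docstring above:
print(bottom_right_causal_local_mask(2, 5))  # [[1,1,1,1,0],[1,1,1,1,1]]
print(bottom_right_causal_local_mask(5, 2))  # rows 0-2 all zero, then [1,0], [1,1]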
+ """ + + # unimplemented kwargs + k=None + v=None + rotary_cos=None + rotary_sin=None + cache_leftpad=None + block_table=None + softcap=0.0 + rotary_interleaved=True + alibi_slopes=None + + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + cache_batch_idx = maybe_contiguous(cache_batch_idx) + # block_table = maybe_contiguous(block_table) + if gqa_parallel is None: + gqa_parallel = True if q.shape[1] <= 64 else False + # not in gqa/mqa setup + if q.shape[2] == k_cache.shape[2]: + gqa_parallel = False + if max_seqlen_k_hint is None: + max_seqlen_k_hint = k_cache.shape[1] + out, softmax_lse = flashattn_hopper_cuda.fwd_kvcache( + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, + rotary_sin, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + None, + softmax_scale, + descale_q, + descale_k, + descale_v, + causal, + window_size[0], + window_size[1], + softcap, + rotary_interleaved, + num_splits, + max_seqlen_k_hint, + gqa_parallel + ) + return (out, softmax_lse) if return_softmax_lse else out diff --git a/hopper/flash_fwd_hdim128_bf16_gqa16_sm90.cu b/hopper/flash_fwd_hdim128_bf16_gqa16_sm90.cu new file mode 100644 index 000000000..d839721b1 --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_bf16_gqa2_sm90.cu b/hopper/flash_fwd_hdim128_bf16_gqa2_sm90.cu new file mode 100644 index 000000000..85d328151 --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_bf16_gqa32_sm90.cu b/hopper/flash_fwd_hdim128_bf16_gqa32_sm90.cu new file mode 100644 index 000000000..4bf5525c7 --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_bf16_gqa4_sm90.cu b/hopper/flash_fwd_hdim128_bf16_gqa4_sm90.cu new file mode 100644 index 000000000..486c762ff --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_bf16_gqa8_sm90.cu b/hopper/flash_fwd_hdim128_bf16_gqa8_sm90.cu new file mode 100644 index 000000000..157081389 --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_e4m3_gqa16_sm90.cu b/hopper/flash_fwd_hdim128_e4m3_gqa16_sm90.cu new file mode 100644 index 000000000..45ce0357d --- /dev/null +++ b/hopper/flash_fwd_hdim128_e4m3_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_e4m3_gqa2_sm90.cu b/hopper/flash_fwd_hdim128_e4m3_gqa2_sm90.cu new file mode 100644 index 000000000..1941fe4a2 --- /dev/null +++ b/hopper/flash_fwd_hdim128_e4m3_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_e4m3_gqa32_sm90.cu b/hopper/flash_fwd_hdim128_e4m3_gqa32_sm90.cu new file mode 100644 index 000000000..c3c2d5e2f --- /dev/null +++ b/hopper/flash_fwd_hdim128_e4m3_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_e4m3_gqa4_sm90.cu b/hopper/flash_fwd_hdim128_e4m3_gqa4_sm90.cu new file mode 100644 index 000000000..834109070 --- /dev/null +++ b/hopper/flash_fwd_hdim128_e4m3_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_e4m3_gqa8_sm90.cu b/hopper/flash_fwd_hdim128_e4m3_gqa8_sm90.cu new file mode 100644 index 000000000..98cdac676 --- /dev/null +++ b/hopper/flash_fwd_hdim128_e4m3_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_fp16_gqa16_sm90.cu b/hopper/flash_fwd_hdim128_fp16_gqa16_sm90.cu new file mode 100644 index 000000000..988041bf6 --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_fp16_gqa2_sm90.cu b/hopper/flash_fwd_hdim128_fp16_gqa2_sm90.cu new file mode 100644 index 000000000..92936c1d7 --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_fp16_gqa32_sm90.cu b/hopper/flash_fwd_hdim128_fp16_gqa32_sm90.cu new file mode 100644 index 000000000..103931349 --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_fp16_gqa4_sm90.cu b/hopper/flash_fwd_hdim128_fp16_gqa4_sm90.cu new file mode 100644 index 000000000..2d369fcb3 --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_fp16_gqa8_sm90.cu b/hopper/flash_fwd_hdim128_fp16_gqa8_sm90.cu new file mode 100644 index 000000000..e556921af --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_gqa16_sm90.cu b/hopper/flash_fwd_hdim256_bf16_gqa16_sm90.cu new file mode 100644 index 000000000..2c9c35652 --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_gqa2_sm90.cu b/hopper/flash_fwd_hdim256_bf16_gqa2_sm90.cu new file mode 100644 index 000000000..5e72b41c4 --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_gqa32_sm90.cu b/hopper/flash_fwd_hdim256_bf16_gqa32_sm90.cu new file mode 100644 index 000000000..90ae2162a --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_gqa4_sm90.cu b/hopper/flash_fwd_hdim256_bf16_gqa4_sm90.cu new file mode 100644 index 000000000..b7c6345b2 --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_gqa8_sm90.cu b/hopper/flash_fwd_hdim256_bf16_gqa8_sm90.cu new file mode 100644 index 000000000..566760319 --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_e4m3_gqa16_sm90.cu b/hopper/flash_fwd_hdim256_e4m3_gqa16_sm90.cu new file mode 100644 index 000000000..9c0f7d626 --- /dev/null +++ b/hopper/flash_fwd_hdim256_e4m3_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_e4m3_gqa2_sm90.cu b/hopper/flash_fwd_hdim256_e4m3_gqa2_sm90.cu new file mode 100644 index 000000000..c41ac3d4e --- /dev/null +++ b/hopper/flash_fwd_hdim256_e4m3_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_e4m3_gqa32_sm90.cu b/hopper/flash_fwd_hdim256_e4m3_gqa32_sm90.cu new file mode 100644 index 000000000..b486e1a39 --- /dev/null +++ b/hopper/flash_fwd_hdim256_e4m3_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_e4m3_gqa4_sm90.cu b/hopper/flash_fwd_hdim256_e4m3_gqa4_sm90.cu new file mode 100644 index 000000000..2b9701786 --- /dev/null +++ b/hopper/flash_fwd_hdim256_e4m3_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_e4m3_gqa8_sm90.cu b/hopper/flash_fwd_hdim256_e4m3_gqa8_sm90.cu new file mode 100644 index 000000000..ebe0f92ca --- /dev/null +++ b/hopper/flash_fwd_hdim256_e4m3_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp16_gqa16_sm90.cu b/hopper/flash_fwd_hdim256_fp16_gqa16_sm90.cu new file mode 100644 index 000000000..91fc6200e --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp16_gqa2_sm90.cu b/hopper/flash_fwd_hdim256_fp16_gqa2_sm90.cu new file mode 100644 index 000000000..21a81044a --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp16_gqa32_sm90.cu b/hopper/flash_fwd_hdim256_fp16_gqa32_sm90.cu new file mode 100644 index 000000000..502a66281 --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp16_gqa4_sm90.cu b/hopper/flash_fwd_hdim256_fp16_gqa4_sm90.cu new file mode 100644 index 000000000..e6dc49dc6 --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp16_gqa8_sm90.cu b/hopper/flash_fwd_hdim256_fp16_gqa8_sm90.cu new file mode 100644 index 000000000..046c9e304 --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_gqa16_sm90.cu b/hopper/flash_fwd_hdim64_bf16_gqa16_sm90.cu new file mode 100644 index 000000000..0381c601e --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_gqa2_sm90.cu b/hopper/flash_fwd_hdim64_bf16_gqa2_sm90.cu new file mode 100644 index 000000000..6be1d9c58 --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_gqa32_sm90.cu b/hopper/flash_fwd_hdim64_bf16_gqa32_sm90.cu new file mode 100644 index 000000000..154efcac5 --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_gqa4_sm90.cu b/hopper/flash_fwd_hdim64_bf16_gqa4_sm90.cu new file mode 100644 index 000000000..b8fe56a32 --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_gqa8_sm90.cu b/hopper/flash_fwd_hdim64_bf16_gqa8_sm90.cu new file mode 100644 index 000000000..cda356c26 --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_e4m3_gqa16_sm90.cu b/hopper/flash_fwd_hdim64_e4m3_gqa16_sm90.cu new file mode 100644 index 000000000..74e61967a --- /dev/null +++ b/hopper/flash_fwd_hdim64_e4m3_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_e4m3_gqa2_sm90.cu b/hopper/flash_fwd_hdim64_e4m3_gqa2_sm90.cu new file mode 100644 index 000000000..ff8213c05 --- /dev/null +++ b/hopper/flash_fwd_hdim64_e4m3_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_e4m3_gqa32_sm90.cu b/hopper/flash_fwd_hdim64_e4m3_gqa32_sm90.cu new file mode 100644 index 000000000..22ce8ed06 --- /dev/null +++ b/hopper/flash_fwd_hdim64_e4m3_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_e4m3_gqa4_sm90.cu b/hopper/flash_fwd_hdim64_e4m3_gqa4_sm90.cu new file mode 100644 index 000000000..b0f09e780 --- /dev/null +++ b/hopper/flash_fwd_hdim64_e4m3_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_e4m3_gqa8_sm90.cu b/hopper/flash_fwd_hdim64_e4m3_gqa8_sm90.cu new file mode 100644 index 000000000..16775723d --- /dev/null +++ b/hopper/flash_fwd_hdim64_e4m3_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp16_gqa16_sm90.cu b/hopper/flash_fwd_hdim64_fp16_gqa16_sm90.cu new file mode 100644 index 000000000..cbe5159d1 --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp16_gqa2_sm90.cu b/hopper/flash_fwd_hdim64_fp16_gqa2_sm90.cu new file mode 100644 index 000000000..f18c68b23 --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp16_gqa32_sm90.cu b/hopper/flash_fwd_hdim64_fp16_gqa32_sm90.cu new file mode 100644 index 000000000..a4cf2813d --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp16_gqa4_sm90.cu b/hopper/flash_fwd_hdim64_fp16_gqa4_sm90.cu new file mode 100644 index 000000000..8e9932dbd --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp16_gqa8_sm90.cu b/hopper/flash_fwd_hdim64_fp16_gqa8_sm90.cu new file mode 100644 index 000000000..79cbce7d0 --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index 9517c5e0c..4c5a109ad 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -24,31 +24,31 @@ namespace flash { using namespace cute; -template +template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, - CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, - Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k + Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k ) { using Element = typename Ktraits::Element; - using ElementAccum = typename Ktraits::ElementAccum; - using SoftType = ElementAccum; using TileShape_MNK = typename Ktraits::TileShape_MNK; using ClusterShape = typename Ktraits::ClusterShape_MNK; static_assert(Ktraits::Is_WS); static constexpr bool Is_WS = Ktraits::Is_WS; + static constexpr bool No_smem_O = Ktraits::No_smem_O; static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); static constexpr int NumCopyThreads = !Is_WS ? 
0 : cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockH = Ktraits::kBlockH; // static constexpr int kBlockN = Ktraits::kBlockN; - // constexpr int kHeadDim = Ktraits::kHeadDim; + // static constexpr int kHeadDim = Ktraits::kHeadDim; - using CollectiveMainloop = CollectiveMainloopFwd; - using CollectiveEpilogue = CollectiveEpilogueFwd; + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; @@ -80,7 +80,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_Q.init(1 /*numThreads*/); - shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); + if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); } } // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); @@ -97,10 +97,10 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, __syncthreads(); } - static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + // static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); - // cutlass::arch::warpgroup_reg_dealloc<56>(); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); if (warp_idx_in_warpgroup == 0) { // Load Q, K, V @@ -114,32 +114,37 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [m_block, bidh, bidb] = block_coord; + auto [m_block, n_split_idx, bidh, bidb] = block_coord; seqlen_traits_q.init(bidb); seqlen_traits_k.init(bidb); - if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { - continue; + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { + continue; + } } - const int n_block_max = collective_mainloop.get_n_block_max( - mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - const int n_block_min = collective_mainloop.get_n_block_min( - mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) { - scheduler.prefetch_next_work(scheduler_params, work_tile_info); - scheduler.broadcast_next_work(work_tile_info); - continue; + int n_block_min = 0, n_block_max; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } } - collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, - 
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, - seqlen_traits_q, seqlen_traits_k); + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, + seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max); ++work_idx; } collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); } } else { // Consumer - cutlass::arch::warpgroup_reg_alloc(); - // cutlass::arch::warpgroup_reg_alloc(); + cutlass::arch::warpgroup_reg_alloc(); TileScheduler scheduler(&shared_storage.tile_count_semaphore); // Initialize matmul objects. @@ -162,28 +167,41 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax(mainloop_params.softmax_scale_log2); auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [m_block, bidh, bidb] = block_coord; + auto [m_block, n_split_idx, bidh, bidb] = block_coord; seqlen_traits_q.init(bidb); seqlen_traits_k.init(bidb); - if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { - continue; - } - const int n_block_max = collective_mainloop.get_n_block_max( - mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - const int n_block_min = collective_mainloop.get_n_block_min( - mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. - collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); - continue; + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { + continue; + } } - - collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, - tOrO, softmax, n_block_max, n_block_min, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage, - seqlen_traits_q, seqlen_traits_k); - // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); - collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); + int n_block_max, n_block_min = 0; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. 
+ if constexpr(!Seqlen_traits_Q::UseGQAPacking) { + collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q); + } else { + collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + } + continue; + } + } + + collective_mainloop.mma( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, + tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx, + m_block, shared_storage, seqlen_traits_q, seqlen_traits_k); + // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); + collective_epilogue.store( + epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); ++work_idx; } @@ -192,35 +210,34 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, } -template +template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, - CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, - Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k + Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k ) { using Element = typename Ktraits::Element; static_assert(cutlass::sizeof_bits_v == 8); - using ElementAccum = typename Ktraits::ElementAccum; - using SoftType = ElementAccum; using TileShape_MNK = typename Ktraits::TileShape_MNK; using ClusterShape = typename Ktraits::ClusterShape_MNK; static_assert(Ktraits::Is_WS); static constexpr bool Is_WS = Ktraits::Is_WS; - static constexpr bool kUseVarSeqLen = Seqlen_traits::kUseVarSeqLen; + static constexpr bool No_smem_O = Ktraits::No_smem_O; static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); static constexpr int NumCopyThreads = !Is_WS ? 
0 : cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockH = Ktraits::kBlockH; // static constexpr int kBlockN = Ktraits::kBlockN; // static constexpr int kHeadDim = Ktraits::kHeadDim; - static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128; + static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128 && Ktraits::kNWarps != 8; static constexpr bool Use_max_offset = true; - using CollectiveMainloop = CollectiveMainloopFwd; - using CollectiveEpilogue = CollectiveEpilogueFwd; + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA; @@ -260,7 +277,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_Q.init(1 /*numThreads*/); - shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); + if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); } } // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); @@ -277,6 +294,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, float descale_v = *mainloop_params.descale_v_ptr; shared_storage.softmax_scale_qk_log2 = mainloop_params.softmax_scale_log2 * descale_q * descale_k; shared_storage.descale_v = descale_v; + shared_storage.seqlen_init_k = seqlen_traits_k.UseVarSeqLen || bool(seqlen_traits_k.seq_used); // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { @@ -286,10 +304,10 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, __syncthreads(); } - static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); if (warp_group_idx == 0) { // Producer - cutlass::arch::warpgroup_reg_dealloc(); - + cutlass::arch::warpgroup_reg_dealloc(); + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState smem_pipe_read, smem_pipe_release; @@ -300,19 +318,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [m_block, bidh, bidb] = block_coord; + auto [m_block, n_split_idx, bidh, bidb] = block_coord; - if constexpr(kUseVarSeqLen) { - seqlen_traits_q.init(bidb); - seqlen_traits_k.init(bidb); - if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { + if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); } + if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); } + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { continue; } } - int n_block_max = collective_mainloop.get_n_block_max( - mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if constexpr(Is_causal) { - if(n_block_max <= 0) { + int n_block_min = 0, n_block_max; + 
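// For orientation, a sketch of what get_n_block_min_max is expected to produce on the
// split-KV path (the rounding and clamping details here are assumptions, not copied
// from the mainloop): the KV blocks of this query tile are divided into num_splits
// contiguous ranges and split n_split_idx owns [n_block_min, n_block_max); causal /
// local masking then shrinks the range further based on m_block.
//
//   int n_blocks_total   = cute::ceil_div(seqlen_k, kBlockN);
//   int blocks_per_split = cute::ceil_div(n_blocks_total, num_splits);
//   n_block_min = n_split_idx * blocks_per_split;
//   n_block_max = std::min(n_blocks_total, n_block_min + blocks_per_split);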
collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local ||seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.broadcast_next_work(work_tile_info); // need to sync producer warpgroup @@ -321,10 +342,9 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, } } collective_mainloop.load_fp8( - mainloop_params, pipeline_k, pipeline_v, pipeline_vt, - smem_pipe_write, smem_pipe_read, shared_storage, - scheduler, scheduler_params, work_tile_info, block_coord, work_idx, - seqlen_traits_q, seqlen_traits_k); + mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, + seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max); ++work_idx; // don't need to sync producer warpgroup here // if constexpr (Is_causal) { @@ -332,7 +352,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, } collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write); } else { // Consumer - cutlass::arch::warpgroup_reg_alloc(); + cutlass::arch::warpgroup_reg_alloc(); TileScheduler scheduler(&shared_storage.tile_count_semaphore); // Initialize matmul objects. @@ -344,6 +364,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, scheduler.init_consumer(); int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); @@ -353,37 +374,42 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax(shared_storage.softmax_scale_qk_log2); auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [m_block, bidh, bidb] = block_coord; + auto [m_block, n_split_idx, bidh, bidb] = block_coord; - if constexpr(kUseVarSeqLen) { - seqlen_traits_q.init(bidb); - seqlen_traits_k.init(bidb); - if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { + if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); } + if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); } + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { continue; } } - int n_block_max = collective_mainloop.get_n_block_max( - mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if constexpr(Is_causal) { - if(n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. - collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); + int n_block_max, n_block_min = 0; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. 
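// Why zero-filled O and LSE = -inf are safe for an empty split: the combine kernel
// added in combine.h renormalizes the per-split partials by their LSEs, so a split
// with lse = -inf receives weight exp(-inf) = 0. A scalar reference for one output
// row is sketched below; the names and layouts are assumptions, and the real kernel
// parallelizes this over rows and reduces the LSEs through shared memory.

#include <algorithm>
#include <cmath>
#include <vector>

void combine_row_reference(const std::vector<float>& lse_accum,             // [num_splits]
                           const std::vector<std::vector<float>>& o_accum,  // [num_splits][head_dim]
                           std::vector<float>& o_final, float& lse_final) {
    float lse_max = -INFINITY;
    for (float l : lse_accum) lse_max = std::fmax(lse_max, l);
    if (std::isinf(lse_max) && lse_max < 0.f) {   // every split exited early
        lse_final = -INFINITY;
        std::fill(o_final.begin(), o_final.end(), 0.f);
        return;
    }
    float sum_exp = 0.f;
    for (float l : lse_accum) sum_exp += std::exp(l - lse_max);
    lse_final = lse_max + std::log(sum_exp);      // log-sum-exp over splits
    for (size_t d = 0; d < o_final.size(); ++d) {
        float acc = 0.f;
        for (size_t s = 0; s < lse_accum.size(); ++s) {
            acc += std::exp(lse_accum[s] - lse_final) * o_accum[s][d];   // -inf split contributes 0
        }
        o_final[d] = acc;
    }
}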
+ if constexpr(!Seqlen_traits_Q::UseGQAPacking) { + collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q); + } else { + collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + } continue; } } collective_mainloop.mma_fp8( mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release, - tOrO, softmax, n_block_max, - threadIdx.x - NumCopyThreads, work_idx, m_block, - shared_storage, seqlen_traits_q, seqlen_traits_k); - - #ifndef NO_FP8_COLUMN_PERMUTE - collective_epilogue.store_fp8(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); - #else - collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); - #endif + tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, + shared_storage, seqlen_traits_q, seqlen_traits_k); + + collective_epilogue.store( + epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + ++work_idx; } collective_epilogue.store_tail(); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 0c0790e4b..46363b18f 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -16,46 +16,55 @@ #include "kernel_traits.h" #include "seq_len.h" #include "utils.h" +#include "combine.h" - -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; using OutputType = typename Kernel_traits::OutputType; using TileShape_MNK = typename Kernel_traits::TileShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK; - // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); - using CollectiveMainloop = flash::CollectiveMainloopFwd; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + constexpr static bool Is_split = Kernel_traits::Is_split; + static_assert(Seqlen_traits_Q::UseGQAPacking == (Kernel_traits::kBlockH > 1), "If kBlockH > 1, use gqa packed layouts"); + static_assert(!(Is_split && Seqlen_traits::UseVarSeqLen), "Split KV not yet supported for variable seqlen."); + + using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; using Scheduler = std::conditional_t< - Seqlen_traits::kUseVarSeqLen || Is_local, + Seqlen_traits::UseVarSeqLen, flash::SingleTileScheduler, - std::conditional_t + std::conditional_t, + flash::DynamicPersistentTileScheduler< + Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, + Kernel_traits::NumProducerThreads, + Is_split + > >>; // using Scheduler = flash::SingleTileScheduler; - Seqlen_traits seqlen_traits_q( + Seqlen_traits_Q seqlen_traits_q( params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q); Seqlen_traits seqlen_traits_k( params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k); + typename CollectiveMainloop::Params mainloop_params = 
CollectiveMainloop::to_underlying_arguments({ - static_cast(params.q_ptr), + static_cast(params.q_ptr), seqlen_traits_q.get_gmem_layout( - params.seqlen_q, params.d, params.h, params.b, + params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, params.q_row_stride, params.q_head_stride, params.q_batch_stride ), // layout_Q static_cast(params.k_ptr), seqlen_traits_k.get_gmem_layout( - params.seqlen_k, params.d, params.h_k, params.b, + params.seqlen_k, params.d, params.h_k, params.b_k, params.k_row_stride, params.k_head_stride, params.k_batch_stride ), // layout_K static_cast(params.v_ptr), seqlen_traits_k.get_gmem_layout( - params.seqlen_k, params.d, params.h_k, params.b, + params.seqlen_k, params.d, params.h_k, params.b_k, params.v_row_stride, params.v_head_stride, params.v_batch_stride ), // layout_V params.scale_softmax_log2, @@ -63,32 +72,53 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.descale_k_ptr, params.descale_v_ptr, params.window_size_left, - params.window_size_right - }); - typename CollectiveEpilogue::Params epilogue_params = - CollectiveEpilogue::to_underlying_arguments({ - static_cast(params.o_ptr), - seqlen_traits_q.get_gmem_layout( - params.seqlen_q, params.d, params.h, params.b, - params.o_row_stride, params.o_head_stride, params.o_batch_stride - ), // layout_O - static_cast(params.softmax_lse_ptr), - seqlen_traits_q.get_lse_gmem_layout( - params.seqlen_q, params.h, params.b - ) // layout_LSE + params.window_size_right, + ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH), + params.cache_batch_idx, + Is_split ? params.num_splits : 1 }); + typename CollectiveEpilogue::Params epilogue_params = [&] { + if constexpr(!Is_split) { + return CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.o_ptr), + seqlen_traits_q.get_gmem_layout( + params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, + params.o_row_stride, params.o_head_stride, params.o_batch_stride + ), // layout_O + static_cast(params.softmax_lse_ptr), + seqlen_traits_q.get_lse_gmem_layout( + params.seqlen_q, params.h, params.b + ) // layout_LSE + }); + } else { + return CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.oaccum_ptr), + seqlen_traits_q.get_oaccum_gmem_layout( + params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, params.num_splits, + params.oaccum_row_stride, params.oaccum_head_stride, params.oaccum_batch_stride, + params.oaccum_split_stride + ), // layout_O + static_cast(params.softmax_lseaccum_ptr), + seqlen_traits_q.get_lseaccum_gmem_layout( + params.seqlen_q, params.h, params.b, params.num_splits + ), // layout_LSE + }); + } + }(); - int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); - num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); - typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore}; - typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM/Kernel_traits::kBlockH); + num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); + int num_blocks_h = params.h_k * ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH); + typename Scheduler::Arguments scheduler_args = + {num_blocks_m, Is_split ? 
params.num_splits : 1, num_blocks_h, params.b, params.tile_count_semaphore}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); // Get the ptr to kernel function. void *kernel; if constexpr(cutlass::sizeof_bits_v == 8) - kernel = (void *)flash::compute_attn_ws_fp8; + kernel = (void *)flash::compute_attn_ws_fp8; else - kernel = (void *)flash::compute_attn_ws; + kernel = (void *)flash::compute_attn_ws; int smem_size = sizeof(typename Kernel_traits::SharedStorage); // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); @@ -106,148 +136,407 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); static constexpr int ctaSize = Kernel_traits::kNWarps * 32; dim3 block_dims(ctaSize); - dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); - cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; - cutlass::launch_kernel_on_cluster( - launch_params, kernel, mainloop_params, epilogue_params, - scheduler_params, seqlen_traits_q, seqlen_traits_k); + if constexpr(size(ClusterShape{}) > 1) { + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster( + launch_params, kernel, mainloop_params, epilogue_params, + scheduler_params, seqlen_traits_q, seqlen_traits_k); + } else { + if constexpr(cutlass::sizeof_bits_v == 8) { + flash::compute_attn_ws_fp8 + <<>> + (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k); + } else { + flash::compute_attn_ws + <<>> + (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k); + } + + } CHECK_CUDA_KERNEL_LAUNCH(); + + if constexpr (Is_split) { + using FinalOutputType = typename Kernel_traits::FinalOutputType; + static_assert(is_same_v, "Assume OutputType of main kernel is float."); + static_assert(is_same_v, "ElementAccum must be float."); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kHeadDim = Kernel_traits::kHeadDim; + constexpr static int kBlockM = kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 
8 : 16); + constexpr static bool Is_even_K = true; // always true for our current setting + void *kernel_combine; + int smem_size_combine; + NUM_SPLITS_SWITCH(params.num_splits, kLogMaxSplits, [&] { + constexpr static int kMaxSplits = 1 << kLogMaxSplits; + kernel_combine = (void *) flash::combine_attn_seqk_parallel< + FinalOutputType, ElementAccum, kHeadDim, kBlockM, kLogMaxSplits, Is_even_K, Flash_fwd_params>; + smem_size_combine = sizeof( + flash::SharedStorageLSE, Int>, Shape>>); + }); + if (smem_size_combine >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel_combine, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_combine)); + } + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + dim3 block_dims_combine(128); + dim3 cluster_dims_combine(1, 1, 1); + cutlass::ClusterLaunchParams launch_params_combine{ + grid_combine, block_dims_combine, cluster_dims_combine, smem_size_combine, stream}; + cutlass::launch_kernel_on_cluster(launch_params_combine, kernel_combine, params); + CHECK_CUDA_KERNEL_LAUNCH(); + } } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; + constexpr static bool UseCluster = false; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.is_local, Is_local, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 3 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] { run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Is_local && !Is_causal, Seqlen_traits + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits >(params, stream); + // }); }); + }); }); + }); }); } template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { + + MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_local, Is_local, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Is_local && !Is_causal, Seqlen_traits - >(params, stream); - }); + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + // and not Is_causal, Is_split, or varseqlen + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits + >(params, stream); + }); }); + }); }); + }); }); } + + template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { + + MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { 
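// For reference, BOOL_SWITCH (from static_switch.h) is the idiom that turns each of
// these runtime flags into a constexpr template parameter by instantiating both
// branches; the definition below is paraphrased and should be checked against the
// header rather than treated as a quote:
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&] {                                           \
    if (COND) {                                   \
      constexpr static bool CONST_NAME = true;    \
      return __VA_ARGS__();                       \
    } else {                                      \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }                                             \
  }()
// The MMA_2WG_SWITCH / MMA_3WG_SWITCH and NUM_SPLITS_SWITCH macros introduced by this
// patch follow the same pattern, mapping seqlen_q (or num_splits) onto a small set of
// compile-time values such as kNumMmaWGs and kLogMaxSplits.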
BOOL_SWITCH(params.is_local, Is_local, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Is_local && !Is_causal, Seqlen_traits - >(params, stream); - }); + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + // and not Is_causal, Is_split, or varseqlen + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits + >(params, stream); + }); }); + }); }); + }); }); } template void run_mha_fwd_hdim64_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - constexpr static int kBlockM = 192; constexpr static int kBlockN = 128; - constexpr static int kNWarps = 4 + kBlockM/16; - constexpr static int kStages = 4; + constexpr static int kStages = 4; + // constexpr static bool UseCluster = false; + // constexpr static int kBlockM = 192; + // constexpr static int kNWarps = 4 + kBlockM/16; using Seqlen_traits = flash::FixedSeqLenTraits; - if(params.is_causal) { - run_flash_fwd, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); - } else { - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { - run_flash_fwd, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); + + MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 3, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits_fp8, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits + >(params, stream); + }); + }); }); - } - // BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal && - // !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - // run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - // }); - // }); - // }); + }); + }); } template void run_mha_fwd_hdim128_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; - constexpr static int kBlockM = 128; constexpr static int kBlockN = 256; - constexpr static int kNWarps = 4 + kBlockM/16; constexpr static int kStages = 2; + // constexpr static int kBlockM = 128; + // constexpr static int kNWarps = 4 + kBlockM/16; using Seqlen_traits = flash::FixedSeqLenTraits; - if(params.is_causal) { - run_flash_fwd, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); - } else { - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { - run_flash_fwd, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); + + MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, 
[&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits_fp8, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits + >(params, stream); + }); + }); }); - } - // BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal && - // !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - // run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - // }); - // }); - // }); + }); + }); } template void run_mha_fwd_hdim256_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; - constexpr static int kBlockM = 128; constexpr static int kBlockN = 128; - constexpr static int kNWarps = 4 + kBlockM/16; constexpr static int kStages = 2; + // constexpr static int kBlockM = 128; + // constexpr static int kNWarps = 4 + kBlockM/16; using Seqlen_traits = flash::FixedSeqLenTraits; - if(params.is_causal) { - run_flash_fwd, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); - } else { - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { - run_flash_fwd, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); + + MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits_fp8, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits + >(params, stream); + }); + }); }); - } - // BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal && - // !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - // run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - // }); - // }); - // }); + }); + }); +} + +/* +** GQA methods +*/ + +template +void run_mha_fwd_hdim64_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 3, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim128_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_2WG_SWITCH(kBlockH * 
params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim256_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim64_fp8_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + constexpr static int kBlockN = 128; + constexpr static int kStages = 4; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 3, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits_fp8, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim128_fp8_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + constexpr static int kBlockN = 256; + constexpr static int kStages = 2; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits_fp8, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim256_fp8_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + constexpr static int kBlockN = 128; + constexpr static int kStages = 2; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + 
BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits_fp8, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); } diff --git a/hopper/kernel_traits.h b/hopper/kernel_traits.h index 684833ab8..e9210731d 100644 --- a/hopper/kernel_traits.h +++ b/hopper/kernel_traits.h @@ -33,6 +33,42 @@ struct SharedStorageQKVO { }; }; +// Use if Oaccum is too large for SharedStorageQKVO +template +struct SharedStorageQKVOaccum { + cute::array_aligned> smem_q; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + int tile_count_semaphore; + }; +}; + +// SharedStorage struct with no smem for O +template +struct SharedStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + int tile_count_semaphore; + }; +}; + template struct SharedStorageQKVOVt { @@ -54,16 +90,67 @@ struct SharedStorageQKVOVt { int tile_count_semaphore; float softmax_scale_qk_log2; float descale_v; + bool seqlen_init_k; + }; +}; + +// Use if Oaccum is too large for SharedStorageQKVOVt +template +struct SharedStorageQKVOVtaccum { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + struct { + cute::array_aligned> smem_v; + cute::array_aligned> smem_v_out; + }; + cute::array_aligned> smem_o; + }; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + typename cutlass::PipelineAsync::SharedStorage pipeline_vt; + int tile_count_semaphore; + float softmax_scale_qk_log2; + float descale_v; + bool seqlen_init_k; + }; +}; + +template +struct SharedStorageQKVVt { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + cute::array_aligned> smem_v_out; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + typename cutlass::PipelineAsync::SharedStorage pipeline_vt; + int tile_count_semaphore; + float softmax_scale_qk_log2; + float descale_v; + bool seqlen_init_k; }; }; // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true template + int kClusterM_ = 1, typename elem_type=cutlass::half_t, bool Is_split_=false, int kBlockH_ = 1> struct Flash_fwd_kernel_traits { using Element = elem_type; using ElementAccum = float; - using OutputType = elem_type; + using FinalOutputType = elem_type; + using OutputType = std::conditional_t; using index_t = int64_t; // The number of threads. 
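The OutputType alias just above is the heart of the split-KV epilogue change. Restated as a self-contained sketch (the exact template arguments are assumed, not quoted from the header): when Is_split is enabled the forward kernel stores its partial outputs at accumulator precision, and only the combine pass converts to the dtype the caller sees, which avoids a lossy fp16/bf16 round-trip between the split and combine stages.

#include <type_traits>

template <typename elem_type, bool Is_split_>
struct OutputTypeSketch {
    using FinalOutputType = elem_type;   // dtype handed back to the caller (fp16 / bf16)
    // Split kernels write partial O (and LSE) in float so the combine kernel can
    // renormalize across splits without intermediate precision loss.
    using OutputType = std::conditional_t<Is_split_, float, elem_type>;
};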
@@ -72,14 +159,16 @@ struct Flash_fwd_kernel_traits { static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_; - static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); - static constexpr bool Is_WS = kNWarps_ >= 12; + static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); + static constexpr bool Is_WS = true; static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers"); static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; + static constexpr int kBlockH = kBlockH_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); + static_assert(kBlockM % kBlockH == 0); using TileShape_MNK = Shape, Int, Int>; static constexpr int kClusterM = kClusterM_; @@ -87,6 +176,9 @@ struct Flash_fwd_kernel_traits { static constexpr int kStages = kStages_; + static constexpr bool Is_split = Is_split_; + static constexpr bool No_smem_O = Is_split; + using AtomLayoutMNK = Layout, _1, _1>>; using TiledMma0 = decltype(cute::make_tiled_mma( std::conditional_t< @@ -104,6 +196,14 @@ struct Flash_fwd_kernel_traits { decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + // for gmem -> smem Q copy + using FactoringLayoutQ = Layout, Int, Int>, + Stride, _1, Int>>; + using TileShapeQCopy = std::conditional_t<(kBlockH > 1), + decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>; + using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1), + decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>; + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutK = @@ -122,15 +222,20 @@ struct Flash_fwd_kernel_traits { make_ordered_layout( make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), Step<_2, _1, _3>{}))); - + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + // for smem -> gmem O copy + using TileShapeOCopy = TileShapeQCopy; + using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1), + decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>; using SmemCopyAtomQ = Copy_Atom; - using SharedStorage = SharedStorageQKVO; + using SharedStorage = std::conditional_t, + SharedStorageQKV>; using MainloopPipeline = typename cutlass::PipelineTmaAsync; using MainloopPipelineNoTMA = typename cutlass::PipelineAsync; @@ -141,13 +246,19 @@ struct Flash_fwd_kernel_traits { // Traits struct for fp8 kernel with in-kernel transpose template + int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t, bool Is_split_ = false, int kBlockH_ = 1> struct Flash_fwd_kernel_traits_fp8 { using Element = elem_type; static_assert(cutlass::sizeof_bits_v == 8); using ElementAccum = float; - using OutputType = cutlass::half_t; - using index_t = int64_t; + using FinalOutputType = cutlass::bfloat16_t; + using OutputType = std::conditional_t; + using index_t = int64_t; + + static constexpr bool Is_split = Is_split_; + static constexpr bool No_smem_O = false; + // NOTE: not using smem for epilogue degrades perf substantially. 
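// How the GQA-packed Q tile is indexed, as a plain-index sketch (the real mapping is
// the FactoringLayoutQ composition above; the exact ordering here is an assumption):
// a kBlockM x kHeadDim tile covers kBlockM / kBlockH query positions for each of the
// kBlockH query heads that share one KV head, so K/V tiles are fetched once per
// KV-head group rather than once per query head.
//
//   int packed_tile_row(int m_local,      // query position in tile: 0 .. kBlockM/kBlockH - 1
//                       int hq_in_group,  // query head in its KV group: 0 .. kBlockH - 1
//                       int kBlockH) {
//       return m_local * kBlockH + hq_in_group;   // kBlockM rows in total
//   }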
+ // static constexpr bool No_smem_O = Is_split; // The number of threads. static constexpr int kNWarps = kNWarps_; @@ -155,14 +266,16 @@ struct Flash_fwd_kernel_traits_fp8 { static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_; - static_assert(kNWarps_ == 12 || kNWarps_ == 16); + static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); static constexpr bool Is_WS = true; static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers"); static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; + static constexpr int kBlockH = kBlockH_; static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); + static_assert(kBlockM % kBlockH == 0); using TileShape_MNK = Shape, Int, Int>; static constexpr int kClusterM = kClusterM_; @@ -171,6 +284,9 @@ struct Flash_fwd_kernel_traits_fp8 { static constexpr int kStages = kStages_; static_assert(kStages > 1); + // Use this to save enough smem when writing out in float precision. + static constexpr bool VO_union_all = Is_split && (kBlockM != 64) && (kHeadDim == 256); + using AtomLayoutMNK = Layout, _1, _1>>; using TiledMma0 = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), @@ -184,6 +300,14 @@ struct Flash_fwd_kernel_traits_fp8 { decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + // for gmem -> smem Q copy + using FactoringLayoutQ = Layout, Int, Int>, + Stride, _1, Int>>; + using TileShapeQCopy = std::conditional_t<(kBlockH > 1), + decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>; + using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1), + decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>; + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutK = @@ -226,6 +350,10 @@ struct Flash_fwd_kernel_traits_fp8 { using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + // for smem -> gmem O copy + using TileShapeOCopy = TileShapeQCopy; + using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1), + decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>; // used for rmem -> smem O copy in fp8 kernel to undo column permutation using ThreadLayoutrO = Layout, _4, _1>, @@ -240,8 +368,11 @@ struct Flash_fwd_kernel_traits_fp8 { using SmemCopyAtomQ = Copy_Atom; - using SharedStorage = SharedStorageQKVOVt; + using SharedStorage = std::conditional_t, + SharedStorageQKVOVtaccum>, + SharedStorageQKVVt>; using MainloopPipeline = typename cutlass::PipelineTmaAsync; using MainloopPipelineNoTMA = typename cutlass::PipelineAsync; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 111421580..cf8a8fa7c 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -79,7 +79,7 @@ struct SmemTransposeFp8_64x64 { } }; -template +template struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; @@ -87,12 +87,19 @@ struct CollectiveMainloopFwd { using ClusterShape = typename Ktraits::ClusterShape_MNK; static constexpr int kStages = 
Ktraits::kStages; - static constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr int kHeadDim = Ktraits::kHeadDim; + // static constexpr int kBlockM = Ktraits::kBlockM; + // static constexpr int kBlockN = Ktraits::kBlockN; + // static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr bool Is_split = Ktraits::Is_split; + static constexpr bool No_smem_O = Ktraits::No_smem_O; using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutQCopy = typename Ktraits::SmemLayoutQCopy; + using TileShapeQCopy = typename Ktraits::TileShapeQCopy; using SmemLayoutK = typename Ktraits::SmemLayoutK; using SmemLayoutV = typename Ktraits::SmemLayoutV; using SmemLayoutVt = typename Ktraits::SmemLayoutVt; @@ -101,11 +108,11 @@ struct CollectiveMainloopFwd { GmemTiledCopyQ{}, make_tensor( make_gmem_ptr(static_cast(nullptr)), - repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), - typename Seqlen_traits::StrideT{} + repeat_like(typename Seqlen_traits_Q::StrideT{}, int32_t(0)), + typename Seqlen_traits_Q::StrideT{} ), - SmemLayoutQ{}, - select<0, 2>(TileShape_MNK{}), + SmemLayoutQCopy{}, + TileShapeQCopy{}, _1{})); // no mcast for Q using TMA_K = decltype(make_tma_copy( @@ -142,14 +149,13 @@ struct CollectiveMainloopFwd { static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; - static constexpr bool UseSchedulerBarrier = - cutlass::sizeof_bits_v == 8 ? kHeadDim >= 128 - : kHeadDim <= 128; + static constexpr bool UseSchedulerBarrier = Ktraits::kNWarps >= 12 && + (cutlass::sizeof_bits_v == 8 ? 
kHeadDim >= 128 : kHeadDim <= 128); // Host side kernel arguments struct Arguments { Element const* ptr_Q; - typename Seqlen_traits::LayoutT layout_Q; + typename Seqlen_traits_Q::LayoutT layout_Q; Element const* ptr_K; typename Seqlen_traits::LayoutT layout_K; Element const* ptr_V; @@ -160,11 +166,14 @@ struct CollectiveMainloopFwd { float const* descale_v_ptr; int window_size_left; int window_size_right; + int const qhead_per_khead; + int const* cache_batch_idx; + int const num_splits; }; // Device side kernel params struct Params { - typename Seqlen_traits::LayoutT layout_Q; + typename Seqlen_traits_Q::LayoutT layout_Q; typename Seqlen_traits::LayoutT layout_K; typename Seqlen_traits::LayoutT layout_V; cutlass::FastDivmod qhead_per_khead_divmod; @@ -177,6 +186,8 @@ struct CollectiveMainloopFwd { float const* descale_v_ptr; int window_size_left; int window_size_right; + int const* cache_batch_idx; + cutlass::FastDivmod num_splits_divmod; }; @@ -186,8 +197,8 @@ struct CollectiveMainloopFwd { TMA_Q tma_load_Q = make_tma_copy( GmemTiledCopyQ{}, mQ, - SmemLayoutQ{}, - select<0, 2>(TileShape_MNK{}), + SmemLayoutQCopy{}, + TileShapeQCopy{}, _1{}); // no mcast for Q Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); TMA_K tma_load_K = make_tma_copy( @@ -204,11 +215,13 @@ struct CollectiveMainloopFwd { select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any return {args.layout_Q, args.layout_K, args.layout_V, - cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), + cutlass::FastDivmod(args.qhead_per_khead), tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2, args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr, - args.window_size_left, args.window_size_right}; + args.window_size_left, args.window_size_right, + args.cache_batch_idx, + cutlass::FastDivmod(args.num_splits)}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -220,41 +233,60 @@ struct CollectiveMainloopFwd { } CUTLASS_DEVICE - int get_n_block_max( - Params const& mainloop_params, int m_block, - const Seqlen_traits& seqlen_traits_q, - const Seqlen_traits& seqlen_traits_k + void get_n_block_min_max( + Params const& mainloop_params, + int m_block, + int n_split_idx, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int& n_block_min, + int& n_block_max ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); - int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? 
seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); - int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal || Is_local) { + // static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{})/Ktraits::kBlockH; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + n_block_max = cute::ceil_div(seqlen_k, kBlockN); + + if constexpr(Is_split) { + int const n_blocks_per_split + = mainloop_params.num_splits_divmod.divide(n_block_max + int(mainloop_params.num_splits_divmod) - 1); + n_block_min = n_split_idx * n_blocks_per_split; + n_block_max = std::min(n_block_max, (n_split_idx + 1) * n_blocks_per_split); + } + + if constexpr (Is_causal) { n_block_max = std::min( n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN)); + } else if constexpr (Is_local) { + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN)); + n_block_min = std::max( + n_block_min, + (m_block * kBlockM_div_H + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN); } - return n_block_max; } CUTLASS_DEVICE - int get_n_block_min( - Params const& mainloop_params, int m_block, - const Seqlen_traits& seqlen_traits_q, - const Seqlen_traits& seqlen_traits_k + void get_n_block_max( + Params const& mainloop_params, + int m_block, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int& n_block_max ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); - int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? 
seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); - if constexpr (!Is_local) { - return 0; - } else { - return std::max( - 0, - (m_block * kBlockM + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN - ); + // static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{})/Ktraits::kBlockH; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN)); } } @@ -269,13 +301,15 @@ struct CollectiveMainloopFwd { Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, typename Scheduler::WorkTileInfo& work_tile_info, - cute::tuple block_coord, + cute::tuple block_coord, int work_idx, - const Seqlen_traits& seqlen_traits_q, - const Seqlen_traits& seqlen_traits_k + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int n_block_min, + int n_block_max ) { - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); @@ -283,19 +317,30 @@ struct CollectiveMainloopFwd { Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); - auto [m_block, bidh, bidb] = block_coord; - int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? 
bidb : mainloop_params.cache_batch_idx[bidb]; + const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gQ = seqlen_traits_q.get_local_tile_tensor( - mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) + Tensor gQ = [&] { + // Need this inside lambda to capture structured binding + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + if constexpr(Seqlen_traits_Q::UseGQAPacking) { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh_kv, bidb) + (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod)); // (M/H, H, K) + } else { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block); // (M, K) + } + }(); Tensor gK = seqlen_traits_k.get_local_tile_tensor( - mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) Tensor gV = seqlen_traits_k.get_local_tile_tensor( - mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); @@ -314,8 +359,6 @@ struct CollectiveMainloopFwd { } } - const int n_block_min = get_n_block_min(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - const int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); int n_block = n_block_max - 1; int lane_predicate = cute::elect_one_sync(); @@ -337,8 +380,7 @@ struct CollectiveMainloopFwd { // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. 
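// --- Illustrative aside, not part of the patch: how split-KV carves up the K/V blocks. ---
// get_n_block_min_max() above hands each split a contiguous range of kBlockN-sized K/V tiles;
// every split produces a partial O/LSE that a combine step then reduces. A standalone sketch
// of the same arithmetic with assumed values (seqlen_k = 4224, kBlockN = 128, num_splits = 4):
inline void split_kv_block_range(int seqlen_k, int kBlockN, int num_splits, int n_split_idx,
                                 int& n_block_min, int& n_block_max) {
    int n_blocks = (seqlen_k + kBlockN - 1) / kBlockN;                 // cute::ceil_div(seqlen_k, kBlockN)
    int n_blocks_per_split = (n_blocks + num_splits - 1) / num_splits; // what num_splits_divmod computes
    n_block_min = n_split_idx * n_blocks_per_split;
    int hi = (n_split_idx + 1) * n_blocks_per_split;
    n_block_max = hi < n_blocks ? hi : n_blocks;
}
// With the assumed values: 33 blocks total, 9 per split, so the splits own [0,9), [9,18), [18,27), [27,33).
// Causal/local masking may clamp the range further, exactly as in the device code above.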
- shared_storage.barrier_O.wait((work_idx + 1) % 2); - + if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); } if (lane_predicate) { // CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 2 @@ -353,6 +395,7 @@ struct CollectiveMainloopFwd { ++smem_pipe_write_v; } } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); if (lane_predicate) { pipeline_v.producer_acquire(smem_pipe_write_v); @@ -361,6 +404,7 @@ struct CollectiveMainloopFwd { ++smem_pipe_write_v; } scheduler.broadcast_next_work(work_tile_info); + } template @@ -375,16 +419,18 @@ struct CollectiveMainloopFwd { Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, typename Scheduler::WorkTileInfo& work_tile_info, - cute::tuple block_coord, + cute::tuple block_coord, int work_idx, - const Seqlen_traits& seqlen_traits_q, - const Seqlen_traits& seqlen_traits_k + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int n_block_min, + int n_block_max ) { using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV; using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt; - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); @@ -408,19 +454,30 @@ struct CollectiveMainloopFwd { Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); - auto [m_block, bidh, bidb] = block_coord; - int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + auto [m_block, split_idx, bidh, bidb] = block_coord; + const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? 
bidb : mainloop_params.cache_batch_idx[bidb]; + const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gQ = seqlen_traits_q.get_local_tile_tensor( - mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) + Tensor gQ = [&] { + // Need this inside lambda to capture structured binding + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + if constexpr(Seqlen_traits_Q::UseGQAPacking) { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh_kv, bidb) + (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod)); // (M/H, H, K) + } else { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block); // (M, K) + } + }(); Tensor gK = seqlen_traits_k.get_local_tile_tensor( - mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) Tensor gV = seqlen_traits_k.get_local_tile_tensor( - mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); @@ -439,7 +496,6 @@ struct CollectiveMainloopFwd { } } - int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); int n_block = n_block_max - 1; int lane_predicate = cute::elect_one_sync(); @@ -454,142 +510,62 @@ struct CollectiveMainloopFwd { // for fp8, change from NumThreadsPerWarp to NumThreadsPerWarpGroup cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - if constexpr(Is_causal) { - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); + if constexpr(!Ktraits::VO_union_all) { pipeline_v.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); } - shared_storage.barrier_O.wait((work_idx + 1) % 2); - - CUTLASS_PRAGMA_UNROLL - for (int iter = 0; iter < kStages && n_block > 0; ++iter, --n_block) { - pipeline_v.consumer_wait(smem_pipe_read); - // pipeline_vt.producer_acquire(smem_pipe_write); - do_transpose_V(smem_pipe_read.index()); - pipeline_vt.producer_commit(smem_pipe_write); - pipeline_v.consumer_release(smem_pipe_read); - - ++smem_pipe_write; - ++smem_pipe_read; - - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index())); - pipeline_v.producer_acquire(smem_pipe_write); - 
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index())); - } - } - - #pragma unroll 2 - for (; n_block > 0; --n_block) { - pipeline_v.consumer_wait(smem_pipe_read); - pipeline_vt.producer_acquire(smem_pipe_write); - do_transpose_V(smem_pipe_read.index()); - pipeline_vt.producer_commit(smem_pipe_write); - pipeline_v.consumer_release(smem_pipe_read); - - ++smem_pipe_write; - ++smem_pipe_read; - - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index())); - pipeline_v.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index())); - } - } - - scheduler.prefetch_next_work(scheduler_params, work_tile_info); - scheduler.broadcast_next_work(work_tile_info); - - pipeline_v.consumer_wait(smem_pipe_read); - if (n_block_max > kStages) - pipeline_vt.producer_acquire(smem_pipe_write); - do_transpose_V(smem_pipe_read.index()); - pipeline_vt.producer_commit(smem_pipe_write); - pipeline_v.consumer_release(smem_pipe_read); + } + // With fp8 kernel, smem_o is in union with smem_v_out, + // except for split kernel + hdim 256, + // so could use NamedBarrier instead of ClusterBarrier. + // But, this doesn't appear to have any benefit. + if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); } - ++smem_pipe_write; - ++smem_pipe_read; - } else { + if constexpr(Ktraits::VO_union_all) { if (warp_idx_in_warpgroup == 0 && lane_predicate) { - shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); pipeline_v.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); + tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); } - // With fp8 kernel, smem_o is in union with smem_v_out, - // so could use NamedBarrier instead of ClusterBarrier. - // But, this doesn't appear to have any benefit. 
- shared_storage.barrier_O.wait((work_idx + 1) % 2); - + } + + #pragma unroll 2 + for (; n_block > n_block_min; --n_block) { pipeline_v.consumer_wait(smem_pipe_read); - // pipeline_vt.producer_acquire(smem_pipe_write); + pipeline_vt.producer_acquire(smem_pipe_write); do_transpose_V(smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); pipeline_v.consumer_release(smem_pipe_read); ++smem_pipe_write; ++smem_pipe_read; - --n_block; - - constexpr int extra_iterations = kStages - 1; - CUTLASS_PRAGMA_UNROLL - for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter) { - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tKgK(_, n_block), tKsK(_, smem_pipe_write.index())); - pipeline_v.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); - } - - pipeline_v.consumer_wait(smem_pipe_read); - // pipeline_vt.producer_acquire(smem_pipe_write); - do_transpose_V(smem_pipe_read.index()); - pipeline_vt.producer_commit(smem_pipe_write); - pipeline_v.consumer_release(smem_pipe_read); - - ++smem_pipe_write; - ++smem_pipe_read; - --n_block; - } + + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv), + tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index())); + pipeline_v.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), + tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index())); + } + } - // CUTLASS_PRAGMA_NO_UNROLL - #pragma unroll 2 - for (; n_block >= 0; --n_block) { - - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tKgK(_, n_block), tKsK(_, smem_pipe_write.index())); - pipeline_v.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); - } - - pipeline_v.consumer_wait(smem_pipe_read); - pipeline_vt.producer_acquire(smem_pipe_write); - do_transpose_V(smem_pipe_read.index()); - pipeline_vt.producer_commit(smem_pipe_write); - pipeline_v.consumer_release(smem_pipe_read); - - ++smem_pipe_write; - ++smem_pipe_read; - } - // scheduler.prefetch_next_work(scheduler_params, work_tile_info); - // scheduler.broadcast_next_work(work_tile_info); - } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + + pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); + do_transpose_V(smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); + pipeline_v.consumer_release(smem_pipe_read); + + ++smem_pipe_write; + ++smem_pipe_read; } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster @@ -635,13 +611,16 @@ struct CollectiveMainloopFwd { CUTLASS_DEVICE void warp_scheduler_barrier_arrive() { - if constexpr (!UseSchedulerBarrier) { return; } - static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * 
cutlass::NumThreadsPerWarpGroup); - if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); + if constexpr (!UseSchedulerBarrier) { + return; } else { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); + } else { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + } } } @@ -649,17 +628,19 @@ struct CollectiveMainloopFwd { mma_init() { // Tell producer (warp 0) that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - if constexpr (!UseSchedulerBarrier) { return; } - static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); - if (cutlass::canonical_warp_group_idx() > 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); - } - if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { - if (cutlass::canonical_warp_group_idx() > 2) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/); + if constexpr (!UseSchedulerBarrier) { + return; + } else { + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); + } + if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { + if (cutlass::canonical_warp_group_idx() > 2) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/); + } } } - } template @@ -671,19 +652,20 @@ struct CollectiveMainloopFwd { PipelineState& smem_pipe_read_v, FrgTensorO& tOrO, Softmax& softmax, - int n_block_count, int n_block_min, + int n_block_max, int thread_idx, int work_idx, int m_block, SharedStorage& shared_storage, - const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits_Q& seqlen_traits_q, const Seqlen_traits& seqlen_traits_k ) { static_assert(is_rmem::value, "O tensor 
must be rmem resident."); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH; Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); @@ -709,24 +691,26 @@ struct CollectiveMainloopFwd { tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; int const seqlen_q = seqlen_traits_q.actual_seq_len; int const seqlen_k = seqlen_traits_k.actual_seq_len; - int n_block = n_block_count - 1; + int n_block = n_block_max - 1; cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); warp_scheduler_barrier_arrive(); - - if (work_idx != 0) { - int lane_predicate = cute::elect_one_sync(); - if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { - tma_store_wait<0>(); - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.barrier_O.arrive(cta_id, lane_predicate); + if constexpr (!No_smem_O) { + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } } } } @@ -735,15 +719,16 @@ struct CollectiveMainloopFwd { ++smem_pipe_read_k; auto col_limit_right = [&](int row, int n_block) { - return std::min( - seqlen_k - n_block * kBlockN, - row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + mainloop_params.window_size_right - ); + int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H; + if constexpr(Is_local) + return col_limit_base + mainloop_params.window_size_right; + else + return col_limit_base; }; auto col_limit_left = [&](int row, int n_block) { return std::max( 0, - row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - mainloop_params.window_size_left + row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left ); }; { @@ -757,14 +742,15 @@ struct CollectiveMainloopFwd { // using std::min is faster than doing col >= limit0 or col >= limit1 // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the // right hand side can be negative and might be converted to a very large unsigned integer. 
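// --- Illustrative aside, not part of the patch: row indexing under GQA packing. ---
// With GQA packing, the M dimension of the tile interleaves kBlockH query heads per query
// position, so the masking code below recovers the query position as row / kBlockH and
// measures block offsets in units of kBlockM_div_H = kBlockM / kBlockH. A compile-time sketch
// with assumed tile sizes (kBlockM = 128, kBlockH = 4); the names are hypothetical:
namespace gqa_packing_example {
constexpr int kBlockM = 128;                          // rows of the S tile
constexpr int kBlockH = 4;                            // query heads packed per K/V head
constexpr int kBlockM_div_H = kBlockM / kBlockH;      // 32 distinct query positions per tile
constexpr int tile_row = 37;                          // a row of the packed tile
constexpr int query_pos = tile_row / kBlockH;         // 9  -> feeds the causal/local column limits
constexpr int head_in_group = tile_row % kBlockH;     // 1  -> does not affect masking
static_assert(query_pos == 9 && head_in_group == 1, "packed-row mapping");
}  // namespace gqa_packing_example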
- if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) { + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) { tSrS(i) = -INFINITY; - } else if constexpr (Is_local) { - if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) { + } else if constexpr(Is_local) { + if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) { tSrS(i) = -INFINITY; } } - } + } } } @@ -774,7 +760,7 @@ struct CollectiveMainloopFwd { Tensor scores_scale = make_fragment_like(softmax.row_max); clear(scores_scale); - constexpr int n_masking_steps = (!Is_causal) ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM_div_H, kBlockN) + 1; // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) { @@ -792,7 +778,8 @@ struct CollectiveMainloopFwd { Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1)) { + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1)) { tSrS(i) = -INFINITY; } } @@ -823,9 +810,10 @@ struct CollectiveMainloopFwd { Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; if ( - int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1) || - int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block - 1) + int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1) || + int(get<1>(tScS(i))) < col_limit_left(row, n_block - 1) ) { tSrS(i) = -INFINITY; } @@ -847,7 +835,7 @@ struct CollectiveMainloopFwd { softmax.rescale_o(tOrO, scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - cute::copy(softmax.template finalize(tSrS), scores_scale); + cute::copy(softmax.template finalize(tSrS), scores_scale); warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang ++smem_pipe_read_v; @@ -864,18 +852,21 @@ struct CollectiveMainloopFwd { PipelineState& smem_pipe_release, FrgTensorO& tOrO, Softmax& softmax, - int n_block_count, + int n_block_min, + int n_block_max, int thread_idx, int work_idx, int m_block, SharedStorage& shared_storage, - const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits_Q& seqlen_traits_q, const Seqlen_traits& seqlen_traits_k ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + // static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH; Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); @@ -898,10 +889,9 @@ struct CollectiveMainloopFwd { }; tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - // workaround for fp8 only perf regression pending change to seqlen traits class - int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? 
seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); - int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); - int n_block = n_block_count - 1; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + int n_block = n_block_max - 1; cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } @@ -911,34 +901,50 @@ struct CollectiveMainloopFwd { consumer_wait(pipeline_k, smem_pipe_read); warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - if (work_idx != 0) { - int lane_predicate = cute::elect_one_sync(); - if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { - tma_store_wait<0>(); - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.barrier_O.arrive(cta_id, lane_predicate); - } - } + if constexpr (!No_smem_O) { + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } + } + } } warpgroup_wait<0>(); warp_scheduler_barrier_arrive(); pipeline_k.consumer_release(smem_pipe_read); - auto col_limit_causal = [&](int row, int n_block) { - return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; + auto col_limit_right = [&](int row, int n_block) { + int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H; + if constexpr(Is_local) + return col_limit_base + mainloop_params.window_size_right; + else + return col_limit_base; + }; + auto col_limit_left = [&](int row, int n_block) { + return std::max( + 0, + row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left + ); }; { Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col + if constexpr (!Is_causal && !Is_local) { // Just masking based on col if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } } else { // mask based on both row and col - if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, - col_limit_causal(int(get<0>(tScS(i))), n_block))) { + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) { tSrS(i) = -INFINITY; + } else if constexpr(Is_local) { + if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) { + tSrS(i) = -INFINITY; + } } } } @@ -957,11 +963,11 @@ struct CollectiveMainloopFwd { ++smem_pipe_read; --n_block; - constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM, kBlockN); + constexpr int extra_iterations = !Is_causal ? 
kStages - 1 : cute::ceil_div(kBlockM_div_H, kBlockN); if constexpr(Is_causal) { - CUTLASS_PRAGMA_UNROLL - for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); warp_scheduler_barrier_sync(); @@ -971,7 +977,8 @@ struct CollectiveMainloopFwd { Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block)) { + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block)) { tSrS(i) = -INFINITY; } } @@ -991,12 +998,12 @@ struct CollectiveMainloopFwd { permute_regs_A_to_C(tOrP); flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); } + if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); } ++smem_pipe_read; } - } else { + } else if constexpr(!Is_local) { CUTLASS_PRAGMA_UNROLL - for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) { + for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); if constexpr(Delay_V_release) { @@ -1009,9 +1016,9 @@ struct CollectiveMainloopFwd { if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); } else { consumer_wait(pipeline_vt, smem_pipe_read); } - cute::copy(softmax.template max(tSrS), scores_scale); + cute::copy(softmax.template max(tSrS), scores_scale); softmax.rescale_o(tOrO, scores_scale); - softmax.template online_softmax(tSrS); + softmax.template online_softmax(tSrS); Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); permute_regs_A_to_C(tOrP); @@ -1026,17 +1033,33 @@ struct CollectiveMainloopFwd { if constexpr(Delay_V_release) { warp_scheduler_barrier_sync(); CUTLASS_PRAGMA_NO_UNROLL - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + + if constexpr(Is_local) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if ( + int(get<1>(tScS(i))) >= col_limit_right(row, n_block) || + int(get<1>(tScS(i))) < col_limit_left(row, n_block) + ) { + tSrS(i) = -INFINITY; + } + } + } + warp_scheduler_barrier_arrive(); pipeline_k.consumer_release(smem_pipe_read); pipeline_vt.consumer_release(smem_pipe_release); - cute::copy(softmax.template max(tSrS), scores_scale); + cute::copy(softmax.template max(tSrS), scores_scale); softmax.rescale_o(tOrO, scores_scale); - softmax.template online_softmax(tSrS); + softmax.template online_softmax(tSrS); Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); permute_regs_A_to_C(tOrP); @@ -1052,17 +1075,33 @@ struct CollectiveMainloopFwd 
{ } else { if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); } CUTLASS_PRAGMA_NO_UNROLL - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); } flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + + if constexpr(Is_local) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if ( + int(get<1>(tScS(i))) >= col_limit_right(row, n_block) || + int(get<1>(tScS(i))) < col_limit_left(row, n_block) + ) { + tSrS(i) = -INFINITY; + } + } + } + warp_scheduler_barrier_arrive(); pipeline_k.consumer_release(smem_pipe_read); - cute::copy(softmax.template max(tSrS), scores_scale); + cute::copy(softmax.template max(tSrS), scores_scale); softmax.rescale_o(tOrO, scores_scale); - softmax.template online_softmax(tSrS); + softmax.template online_softmax(tSrS); Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); permute_regs_A_to_C(tOrP); @@ -1075,7 +1114,7 @@ struct CollectiveMainloopFwd { if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); } } cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - cute::copy(softmax.template finalize(tSrS, shared_storage.descale_v), scores_scale); + cute::copy(softmax.template finalize(tSrS, shared_storage.descale_v), scores_scale); softmax.rescale_o(tOrO, scores_scale); return; } diff --git a/hopper/seq_len.h b/hopper/seq_len.h index 76c4d08a3..63c18679f 100644 --- a/hopper/seq_len.h +++ b/hopper/seq_len.h @@ -4,6 +4,9 @@ #pragma once +#include +#include + #include #include @@ -11,8 +14,11 @@ namespace flash { static constexpr int kMaxTileSize = 128; -template class SeqLenTraits { +template class SeqLenTraits { public: + static_assert(!(UseVarSeqLen_ && UseGQAPacking_), + "Variable sequence length with GQA parallelization not implemented yet."); + // Total number of queries / keys. Unpadded. int sum_s = 0; // seq len offsets. @@ -23,17 +29,26 @@ template class SeqLenTraits { int actual_seq_len = -1; // Whether this is for fixed-seq-len or var-seq-len. 
- static constexpr bool kUseVarSeqLen = UseVarSeqLen; + static constexpr bool UseVarSeqLen = UseVarSeqLen_; + static constexpr bool UseGQAPacking = UseGQAPacking_; using ShapeT = std::conditional_t< UseVarSeqLen, - cute::Shape, - cute::Shape + cute::Shape, + std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + > >; using StrideT = std::conditional_t< UseVarSeqLen, cute::Shape, - cute::Shape + std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + > >; using LayoutT = cute::Layout; @@ -49,33 +64,102 @@ template class SeqLenTraits { >; using LayoutLseT = cute::Layout; + // Not used for varseqlen + using ShapeOAccumT = std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + >; + using StrideOAccumT = std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + >; + using LayoutOAccumT = cute::Layout; + + using ShapeLseAccumT = cute::Shape; + using StrideLseAccumT = cute::Shape; + using LayoutLseAccumT = cute::Layout; + CUTLASS_HOST SeqLenTraits() {} CUTLASS_HOST SeqLenTraits( int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr): sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {} + CUTLASS_DEVICE void init(int bidb) { + // TODO: add leftpad, seqlen_new for kv cache support + if (seq_used) { + actual_seq_len = seq_used[bidb]; + } + } + + CUTLASS_DEVICE void init_no_guard(int bidb) { + actual_seq_len = seq_used[bidb]; + } + // Returns the layout of a tensor in MKHB format in global memory. // padded: only useful for var-seq-len for dq_accum and softmax_d. CUTLASS_HOST_DEVICE auto get_gmem_layout( int m, int k, int h, int b, int64_t m_stride, int64_t h_stride, int64_t b_stride, bool padded = false) const { - static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen."); + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); + // static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking."); return make_layout(make_shape(m, k, h, b), make_stride(m_stride, cute::_1{}, h_stride, b_stride)); } // Returns the layout of a tensor in MKHB format in global memory. // padded: only useful for var-seq-len for dq_accum and softmax_d. + // Overload that separates h into h_k and h/h_k. + CUTLASS_HOST_DEVICE auto get_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + bool padded = false) const { + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); + static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking."); + return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b), + make_stride(m_stride, cute::_1{}, h_stride, b_stride)); + } + + // Returns the layout of a tensor in MKHBT format in global memory, + // where T is number of splits. + CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout( + int m, int k, int h, int b, int num_splits, + int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride, + bool padded = false) const { + return make_layout(make_shape(m, k, h, b, num_splits), + make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride)); + } + + // Returns the layout of a tensor in MKHBT format in global memory, + // where T is number of splits. + // Overload that separates h into h_k and h/h_k. 
+ CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits, + int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride, + bool padded = false) const { + return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b, num_splits), + make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride)); + } + + // Returns the layout of lse tensor in BHM format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. CUTLASS_HOST_DEVICE auto get_lse_gmem_layout( int m, int h, int b, bool padded = false) const { - static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen."); + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); return make_layout(make_shape(b, h, m), make_stride(int64_t(h * m), int64_t(m), cute::_1())); } - CUTLASS_DEVICE void init(int bidb) {} + // Returns the layout of lse tensor in TBHM format in global memory, + // where T is number of splits. + CUTLASS_HOST_DEVICE auto get_lseaccum_gmem_layout( + int m, int h, int b, int num_splits, bool padded = false) const { + return make_layout(make_shape(num_splits, b, h, m), + make_stride(int64_t(b * h * m), int64_t(h * m), int64_t(m), cute::_1())); + } template CUTLASS_DEVICE auto get_local_tile_tensor( @@ -86,18 +170,57 @@ template class SeqLenTraits { return g_tensor; } - template + template CUTLASS_DEVICE auto get_lse_local_tile_tensor( const MTensor &m_tensor, const Shape &tile_shape, - int bidh, int bidb, bool padded = false) const { - auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_)); - return g_tensor; + int bidh, int bidb, int n_split_idx, bool padded = false) const { + // m_tensor has shape (B, H, M) or (splits, B, H, M) + // Expect tile shape (bM) + // Returns g_tensor of shape = (bM, ceil_div(M,bM)) + if constexpr(!Is_split) { + auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_)); + return g_tensor; + } else { + auto g_tensor = local_tile(m_tensor(n_split_idx, bidb, bidh, _), tile_shape, make_coord(_)); + return g_tensor; + } } + + template + CUTLASS_DEVICE auto get_o_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, int split_idx, bool padded = false) const { + // static_assert(!UseVarSeqLen, "Don't use get_o_local_tile_tensor with VarSeqLen."); + // m_tensor has shape (M, K, H, B) or (M, K, H, B, splits) + // Expect tile shape (bM, K) + // Returns g_tensor of shape = (bM, K, ceil_div(M,bM)) + if constexpr(!Is_split) { + auto g_tensor = local_tile( + m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{})); + return g_tensor; + } else { + auto g_tensor = local_tile( + m_tensor(_, _, bidh, bidb, split_idx), tile_shape, make_coord(_, _0{})); + return g_tensor; + } + } + }; -using FixedSeqLenTraits = SeqLenTraits; +using FixedSeqLenTraits = SeqLenTraits; +using VarSeqLenTraits = SeqLenTraits; +using FixedGQASeqLenTraits = SeqLenTraits; + +template <> +CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) { + actual_seq_len = + seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]); +} -using VarSeqLenTraits = SeqLenTraits; +template <> +CUTLASS_DEVICE void FixedGQASeqLenTraits::init(int bidb) { + // no op +} // Returns the static layout of a var-seq-len tensor in global memory based on // max_seq_len and max_batch_size. 
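A rough usage sketch, not part of the patch, of how these traits resolve the per-batch sequence length before the kernel sizes its K/V block range. The constructor signature matches the one shown above; the include path, pointer, and sizes below are hypothetical.

#include "seq_len.h"   // assumed include for flash::VarSeqLenTraits

// Host side: the traits carry the metadata; the values here are made up.
flash::VarSeqLenTraits make_k_traits(int* cu_seqlens_k /*assumed pointer*/) {
    return flash::VarSeqLenTraits(/*sum_s=*/8192, /*max_seq_len=*/4096,
                                  /*cu_seq_len=*/cu_seqlens_k, /*seq_used=*/nullptr);
}
// Device side, per CTA (sketch of what init(bidb) does before n_block_min/max is computed):
//   VarSeqLenTraits::init(bidb):      actual_seq_len = seq_used ? seq_used[bidb]
//                                                               : cu_seq_len[bidb+1] - cu_seq_len[bidb];
//   FixedSeqLenTraits::init(bidb):    overrides actual_seq_len only when seq_used is given
//                                     (presumably how per-request cache lengths come in);
//   FixedGQASeqLenTraits::init(bidb): no-op, actual_seq_len stays at the constructor's max_seq_len.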
@@ -113,6 +236,16 @@ CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout( make_stride(m_stride, cute::_1{}, h_stride)); } +template <> +CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + bool padded) const { + return make_layout( + make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h_k * h_h_k_ratio), + make_stride(m_stride, cute::_1{}, h_stride)); +} + // padded: only useful for var-seq-len for dq_accum and softmax_d. // When padded is True, use B_M + kMaxTileSize * B as the total B_M. template <> @@ -123,12 +256,6 @@ CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout( make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1())); } -template <> -CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) { - actual_seq_len = - seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]); -} - template <> template CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor( @@ -148,11 +275,33 @@ CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor( return g_tensor; } +// TODO: restructure to not duplicate code template <> -template +template +CUTLASS_DEVICE auto VarSeqLenTraits::get_o_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, int n_split_idx, bool padded) const { + static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits"); + auto g_offset = local_tile( + m_tensor(_, _, bidh), + cute::make_shape(1, get<1>(tile_shape)), + make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{})); + auto g_sequence = make_tensor( + g_offset.data(), + make_layout( + cute::make_shape(actual_seq_len, get<1>(tile_shape)), + g_offset.stride() + )); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +template <> +template CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor( const MTensor &m_tensor, const Shape &tile_shape, - int bidh, int bidb, bool padded) const { + int bidh, int bidb, int n_split_idx, bool padded) const { + static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits"); auto g_offset = local_tile( m_tensor(bidh, _), cute::make_shape(_1{}), make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0))); @@ -163,6 +312,60 @@ CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor( return g_tensor; } +// Returns layout of QO tensor in (M,H/HK,K,HK,B) format in global memory. +template <> +CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride, bool padded) const { + return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b), + make_stride(m_stride, h_stride, cute::_1{}, + h_stride * h_h_k_ratio, b_stride)); +} + +// Returns layout of Oaccum tensor in (M,H/HK,K,HK,B,T) format in global memory. 
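// --- Illustrative aside, not part of the patch: the GQA-packed QO view is a pure reshape. ---
// get_gmem_layout() above only factors the head mode H into (H/H_K, H_K); every stride is kept,
// so the packed (M, H/HK, K, HK, B) view addresses exactly the same memory as the usual
// (M, K, H, B) tensor and GQA parallelization needs no data movement. A compile-time check
// with hypothetical strides (32 query heads, 4 K/V heads, head_dim 128):
constexpr long long offset_standard(long long i, long long d, long long h, long long b,
                                    long long m_stride, long long h_stride, long long b_stride) {
    return i * m_stride + d + h * h_stride + b * b_stride;              // (M, K, H, B) addressing
}
constexpr long long offset_gqa_packed(long long i, long long h_in_group, long long d, long long h_kv,
                                      long long b, long long ratio,
                                      long long m_stride, long long h_stride, long long b_stride) {
    return i * m_stride + h_in_group * h_stride + d
         + h_kv * h_stride * ratio + b * b_stride;                      // (M, H/HK, K, HK, B) addressing
}
// Query head 13 with ratio 8 is (h_kv = 1, h_in_group = 5); both views hit the same element.
static_assert(offset_standard(/*i=*/7, /*d=*/3, /*h=*/13, /*b=*/0,
                              /*m_stride=*/4096, /*h_stride=*/128, /*b_stride=*/1 << 20)
           == offset_gqa_packed(7, 5, 3, 1, 0, /*ratio=*/8, 4096, 128, 1 << 20),
              "GQA-packed layout is a reshape of the same memory");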
+template <> +CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_oaccum_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits, + int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride, + bool padded) const { + return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b, num_splits), + make_stride(m_stride, h_stride, cute::_1{}, + h_stride * h_h_k_ratio, b_stride, + split_stride)); +} + +template <> +template +CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh_kv, int bidb, bool padded) const { + // m_tensor has shape (M, H/H_K, K, H_K, B) + // Expect tile_shape (bM/bH, bH, K) + // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M,bM/bH), ceil_div(H/H_K,bH)) + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{})); + return g_tensor; +} + +template <> +template +CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_o_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh_kv, int bidb, int split_idx, bool padded) const { + // m_tensor has shape (M, H/H_K, K, H_K, B) or (M, H/H_K, K, H_K, B, splits) + // Expect tile_shape (bM/bH, bH, K) + // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M,bM/bH), ceil_div(H/H_K,bH)) + if constexpr(!Is_split) { + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{})); + return g_tensor; + } else { + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb, split_idx), tile_shape, make_coord(_, _, _0{})); + return g_tensor; + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace flash diff --git a/hopper/setup.py b/hopper/setup.py index 029455a3e..f9f3cfd25 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -125,7 +125,52 @@ def append_nvcc_threads(nvcc_extra_args): "flash_bwd_hdim128_bf16_sm90.cu", "flash_fwd_hdim64_e4m3_sm90.cu", "flash_fwd_hdim128_e4m3_sm90.cu", - "flash_fwd_hdim256_e4m3_sm90.cu" + "flash_fwd_hdim256_e4m3_sm90.cu", + "flash_fwd_hdim64_fp16_gqa2_sm90.cu", + "flash_fwd_hdim64_fp16_gqa4_sm90.cu", + "flash_fwd_hdim64_fp16_gqa8_sm90.cu", + "flash_fwd_hdim64_fp16_gqa16_sm90.cu", + "flash_fwd_hdim64_fp16_gqa32_sm90.cu", + "flash_fwd_hdim128_fp16_gqa2_sm90.cu", + "flash_fwd_hdim128_fp16_gqa4_sm90.cu", + "flash_fwd_hdim128_fp16_gqa8_sm90.cu", + "flash_fwd_hdim128_fp16_gqa16_sm90.cu", + "flash_fwd_hdim128_fp16_gqa32_sm90.cu", + "flash_fwd_hdim256_fp16_gqa2_sm90.cu", + "flash_fwd_hdim256_fp16_gqa4_sm90.cu", + "flash_fwd_hdim256_fp16_gqa8_sm90.cu", + "flash_fwd_hdim256_fp16_gqa16_sm90.cu", + "flash_fwd_hdim256_fp16_gqa32_sm90.cu", + "flash_fwd_hdim64_bf16_gqa2_sm90.cu", + "flash_fwd_hdim64_bf16_gqa4_sm90.cu", + "flash_fwd_hdim64_bf16_gqa8_sm90.cu", + "flash_fwd_hdim64_bf16_gqa16_sm90.cu", + "flash_fwd_hdim64_bf16_gqa32_sm90.cu", + "flash_fwd_hdim128_bf16_gqa2_sm90.cu", + "flash_fwd_hdim128_bf16_gqa4_sm90.cu", + "flash_fwd_hdim128_bf16_gqa8_sm90.cu", + "flash_fwd_hdim128_bf16_gqa16_sm90.cu", + "flash_fwd_hdim128_bf16_gqa32_sm90.cu", + "flash_fwd_hdim256_bf16_gqa2_sm90.cu", + "flash_fwd_hdim256_bf16_gqa4_sm90.cu", + "flash_fwd_hdim256_bf16_gqa8_sm90.cu", + "flash_fwd_hdim256_bf16_gqa16_sm90.cu", + "flash_fwd_hdim256_bf16_gqa32_sm90.cu", + "flash_fwd_hdim64_e4m3_gqa2_sm90.cu", + "flash_fwd_hdim64_e4m3_gqa4_sm90.cu", + "flash_fwd_hdim64_e4m3_gqa8_sm90.cu", + "flash_fwd_hdim64_e4m3_gqa16_sm90.cu", + "flash_fwd_hdim64_e4m3_gqa32_sm90.cu", + 
"flash_fwd_hdim128_e4m3_gqa2_sm90.cu", + "flash_fwd_hdim128_e4m3_gqa4_sm90.cu", + "flash_fwd_hdim128_e4m3_gqa8_sm90.cu", + "flash_fwd_hdim128_e4m3_gqa16_sm90.cu", + "flash_fwd_hdim128_e4m3_gqa32_sm90.cu", + "flash_fwd_hdim256_e4m3_gqa2_sm90.cu", + "flash_fwd_hdim256_e4m3_gqa4_sm90.cu", + "flash_fwd_hdim256_e4m3_gqa8_sm90.cu", + "flash_fwd_hdim256_e4m3_gqa16_sm90.cu", + "flash_fwd_hdim256_e4m3_gqa32_sm90.cu", ] nvcc_flags = [ "-O3", diff --git a/hopper/static_switch.h b/hopper/static_switch.h index d9ec62224..57dcb8c5f 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -27,41 +27,30 @@ } \ }() -#define PREC_SWITCH(PRECTYPE, ...) \ +#define PREC_SWITCH(PRECTYPE, NAME, ...) \ [&] { \ - if (PRECTYPE == 1) { \ - using kPrecType = cutlass::half_t; \ - constexpr static bool kSoftFp16 = false; \ - constexpr static bool kHybrid = false; \ + if (PRECTYPE == 3) { \ + using NAME = cutlass::float_e4m3_t; \ return __VA_ARGS__(); \ } else if (PRECTYPE == 2) { \ - using kPrecType = cutlass::float_e4m3_t; \ - constexpr static bool kSoftFp16 = false; \ - constexpr static bool kHybrid = false; \ + using NAME = cutlass::bfloat16_t; \ return __VA_ARGS__(); \ - } else if (PRECTYPE == 3) { \ - using kPrecType = cutlass::float_e4m3_t; \ - constexpr static bool kSoftFp16 = false; \ - constexpr static bool kHybrid = true; \ - return __VA_ARGS__(); \ - } else if (PRECTYPE == 4) { \ - using kPrecType = cutlass::float_e4m3_t; \ - constexpr static bool kSoftFp16 = true; \ - constexpr static bool kHybrid = false; \ + } else { \ + using NAME = cutlass::half_t; \ return __VA_ARGS__(); \ } \ }() -#define HEADDIM_SWITCH(HEADDIM, ...) \ +#define HEADDIM_SWITCH(HEADDIM, CONST_NAME, ...) \ [&] { \ if (HEADDIM == 64) { \ - constexpr static int kHeadSize = 64; \ + constexpr static int CONST_NAME = 64; \ return __VA_ARGS__(); \ } else if (HEADDIM == 128) { \ - constexpr static int kHeadSize = 128; \ + constexpr static int CONST_NAME = 128; \ return __VA_ARGS__(); \ - } else if (HEADDIM == 256) { \ - constexpr static int kHeadSize = 256; \ + } else { \ + constexpr static int CONST_NAME = 256; \ return __VA_ARGS__(); \ } \ }() @@ -76,4 +65,94 @@ using NAME = flash::FixedSeqLenTraits; \ return __VA_ARGS__(); \ } \ - }() + }() + +#define SEQLEN_SWITCH_FWD(VAR_SEQ_LEN_Q, SEQ_USED_K, NAME_Q, NAME_K, ...) \ + [&] { \ + bool useVarSeqLenQ = VAR_SEQ_LEN_Q; \ + bool useSeqUsedK = SEQ_USED_K; \ + if (useVarSeqLenQ) { \ + using NAME_Q = flash::VarSeqLenTraits; \ + using NAME_K = flash::VarSeqLenTraits; \ + return __VA_ARGS__(); \ + } else if (useSeqUsedK) { \ + using NAME_Q = flash::FixedSeqLenTraits; \ + using NAME_K = flash::FixedSeqLenTraitsDynamic; \ + return __VA_ARGS__(); \ + } else { \ + using NAME_Q = flash::FixedSeqLenTraits; \ + using NAME_K = flash::FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + }() + +#define QUERYHEAD_SWITCH(QUERYHEADS, CONST_NAME, ...) \ + [&] { \ + if (QUERYHEADS <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (QUERYHEADS <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (QUERYHEADS <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (QUERYHEADS <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MMA_3WG_SWITCH(QLEN, CONST_NAME, ...) 
\ + [&] { \ + if (QLEN <= 64) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (QLEN <= 128) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 3; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MMA_2WG_SWITCH(QLEN, CONST_NAME, ...) \ + [&] { \ + if (QLEN <= 64) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } \ + }() + +#define NUM_SPLITS_SWITCH(NUM_SPLITS, LOG_MAX_SPLITS, ...) \ + [&] { \ + if (NUM_SPLITS <= 2) { \ + constexpr static int LOG_MAX_SPLITS = 1; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 4) { \ + constexpr static int LOG_MAX_SPLITS = 2; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 8) { \ + constexpr static int LOG_MAX_SPLITS = 3; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 16) { \ + constexpr static int LOG_MAX_SPLITS = 4; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 32) { \ + constexpr static int LOG_MAX_SPLITS = 5; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int LOG_MAX_SPLITS = 6; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int LOG_MAX_SPLITS = 7; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_attn_kvcache.py b/hopper/test_attn_kvcache.py new file mode 100644 index 000000000..726d44ce3 --- /dev/null +++ b/hopper/test_attn_kvcache.py @@ -0,0 +1,486 @@ +import pytest +from einops import rearrange, repeat +import torch +import flash_attn +import flash_attn_interface +import itertools +import math +import time + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, + key_leftpad=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: 
(int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + key_leftpad=key_leftpad, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("num_requests", [1, 4]) +@pytest.mark.parametrize("query_seqlen", [1, 8, 120]) +@pytest.mark.parametrize("context_seqlen", [1024, 3131, 4224]) +@pytest.mark.parametrize("headdim", [64, 128, 256]) +@pytest.mark.parametrize("gqa_parallel", [False, True]) +@pytest.mark.parametrize( + "nheads_kv, gqa_ratio", + [ + (1, 1), + (2, 5), + (3, 3), + (1, 32), + (5, 7), + (8, 1), + (1, 16), + (12, 4), + (8, 2), + ], +) +def test_flash_attn_kvcache_nosplit(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel): + device = "cuda" + num_caches = num_requests + cache_seqlen = context_seqlen + nheads_q = nheads_kv * gqa_ratio + + k_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + v_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", 
dtype=torch.bfloat16 + ) + q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) + # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] + cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") + torch.cuda.synchronize() + + out_ref, _ = attention_ref( + q, + k_cache, + v_cache, + causal=causal, + ) + + out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + # cache_batch_idx=cache_idxs, + causal=causal, + num_splits=1, + return_softmax_lse=True, + gqa_parallel=gqa_parallel + ) + + + torch.cuda.synchronize() + assert ((out_ref - out_fa3).abs().max().item() <= 4e-3) + assert ((out_ref - out_fa3).abs().mean().item() <= 2e-4) + + +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("num_requests", [1, 3]) +@pytest.mark.parametrize("query_seqlen", [1, 8, 120]) +@pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555]) +@pytest.mark.parametrize("headdim", [64, 128, 256]) +@pytest.mark.parametrize("gqa_parallel", [True, False]) +@pytest.mark.parametrize( + "nheads_kv, gqa_ratio", + [ + (1, 1), + (2, 5), + (3, 3), + (1, 32), + (5, 7), + (8, 1), + (1, 16), + (12, 4), + (8, 2), + ], +) +def test_flash_attn_kvcache_nosplit_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel): + device = "cuda" + num_caches = num_requests + cache_seqlen = context_seqlen + nheads_q = nheads_kv * gqa_ratio + + k_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + v_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) + q = q.to(torch.float8_e4m3fn) + k_cache = k_cache.to(torch.float8_e4m3fn) + v_cache = v_cache.to(torch.float8_e4m3fn) + # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] + cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") + torch.cuda.synchronize() + + out_ref, _ = attention_ref( + q, + k_cache, + v_cache, + causal=causal, + ) + + descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda') + descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda') + descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda') + out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + # cache_batch_idx=cache_idxs, + causal=causal, + num_splits=1, + return_softmax_lse=True, + gqa_parallel=gqa_parallel, + descale_q=descale_q, descale_k=descale_k, descale_v=descale_v + ) + + + torch.cuda.synchronize() + assert ((out_ref - out_fa3).abs().max().item() <= 4e-2) + assert ((out_ref - out_fa3).abs().mean().item() <= 2e-3) + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_heuristic_only", [True]) +# @pytest.mark.parametrize("use_heuristic_only", [False]) +@pytest.mark.parametrize("causal", [True, False]) +# @pytest.mark.parametrize("num_requests", [1, 4, 16]) +@pytest.mark.parametrize("num_requests", [1, 3]) +# @pytest.mark.parametrize("query_seqlen", [1, 16, 32, 128]) +@pytest.mark.parametrize("query_seqlen", [1, 8, 25]) +# @pytest.mark.parametrize("context_seqlen", [4096, 16384, 65536]) 
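+# Checks the split-KV decode path against an unsplit (num_splits=1, gqa_parallel=False)
+# reference: the loop in the test body sweeps num_splits (0 = heuristic) with and without
+# the GQA-parallel scheduler and compares outputs and softmax LSE.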
+@pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555]) +@pytest.mark.parametrize("headdim", [64, 128, 256]) +@pytest.mark.parametrize("cache_seqlen_rand", [True, False]) +@pytest.mark.parametrize("gqa_parallel", [True, False]) +@pytest.mark.parametrize( + "nheads_kv, gqa_ratio", + [ + (1, 1), + (4, 1), + (2, 2), + (3, 3), + (4, 4), + (2, 5), + (3, 9), + (1, 16), + (1, 32), + ], +) +def test_flash_attn_kvcache_output(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype): + device = "cuda" + num_caches = 16 + if context_seqlen <= 65536: + cache_seqlen = 65536 + else: + cache_seqlen = context_seqlen + nheads_q = nheads_kv * gqa_ratio + if use_heuristic_only: + max_splits = 1 + else: + max_splits = 128 + + k_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + v_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) + + q = q.to(dtype) + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] + cache_seqlens = torch.randint(1, context_seqlen-1, (num_requests,), dtype=torch.int32).to(device) if cache_seqlen_rand else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") + torch.cuda.synchronize() + + out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + num_splits=1, + return_softmax_lse=True, + gqa_parallel=False + ) + + # i=0 case is with num splits heuristic + for i in range(0, max_splits+1): + out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + num_splits=i, + return_softmax_lse=True, + gqa_parallel=gqa_parallel, + max_seqlen_k_hint=context_seqlen + ) + + torch.cuda.synchronize() + print ('output-ref', i, out_ref) + print ('output-fa3',i, out_fa3) + print ('output-max-diff', i, context_seqlen, (out_ref - out_fa3).abs().max().item()) + print ('output-mean-diff',i, context_seqlen, (out_ref - out_fa3).abs().mean().item()) + print ('lse-max-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().max().item()) + print ('lse-mean-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().mean().item()) + + if cache_seqlen_rand: + assert ((out_ref - out_fa3).abs().max().item() <= 1e-2) + assert ((out_ref - out_fa3).abs().mean().item() <= 1e-3) + else: + assert ((out_ref - out_fa3).abs().max().item() <= 2e-3) + assert ((out_ref - out_fa3).abs().mean().item() <= 1e-4) + lse_max_ref = lse_ref.abs().max().item() + lse_mean_ref = lse_ref.abs().mean().item() + lse_max_fa3 = lse_fa3.abs().max().item() + lse_mean_fa3 = lse_fa3.abs().mean().item() + lse_max_diff = (lse_ref - lse_fa3).abs().max().item() + lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item() + assert ((lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3) + assert ((lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4) + + + +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("use_heuristic_only", [True]) +# @pytest.mark.parametrize("use_heuristic_only", [False]) 
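+# FP8 variant of the split-KV output test: q, k_cache, and v_cache are quantized to
+# torch.float8_e4m3fn with unit descale factors, so the tolerances below are looser
+# than in the bf16 test above.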
+@pytest.mark.parametrize("causal", [True, False]) +# @pytest.mark.parametrize("num_requests", [1, 4, 16]) +@pytest.mark.parametrize("num_requests", [1, 3]) +# @pytest.mark.parametrize("query_seqlen", [1, 16, 32, 128]) +@pytest.mark.parametrize("query_seqlen", [1, 8, 25]) +# @pytest.mark.parametrize("context_seqlen", [4096, 16384, 65536]) +@pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555]) +@pytest.mark.parametrize("headdim", [64, 128, 256]) +@pytest.mark.parametrize("cache_seqlen_rand", [True, False]) +@pytest.mark.parametrize("gqa_parallel", [True, False]) +@pytest.mark.parametrize( + "nheads_kv, gqa_ratio", + [ + (1, 1), + (4, 1), + (2, 2), + (3, 3), + (4, 4), + (2, 5), + (3, 9), + (1, 16), + (1, 32), + ], +) +def test_flash_attn_kvcache_output_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype): + device = "cuda" + num_caches = 16 + if context_seqlen <= 65536: + cache_seqlen = 65536 + else: + cache_seqlen = context_seqlen + nheads_q = nheads_kv * gqa_ratio + if use_heuristic_only: + max_splits = 1 + else: + max_splits = 128 + + k_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + v_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 + ) + q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) + + q = q.to(dtype) + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] + cache_seqlens = torch.randint(1, context_seqlen-1, (num_requests,), dtype=torch.int32).to(device) if cache_seqlen_rand else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") + torch.cuda.synchronize() + + + descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda') + descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda') + descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda') + + out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + num_splits=1, + return_softmax_lse=True, + gqa_parallel=False, + descale_q=descale_q, descale_k=descale_k, descale_v=descale_v + ) + + # i=0 case is with num splits heuristic + for i in range(0, max_splits+1): + out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_idxs, + causal=causal, + num_splits=i, + return_softmax_lse=True, + gqa_parallel=gqa_parallel, + max_seqlen_k_hint=context_seqlen, + descale_q=descale_q, descale_k=descale_k, descale_v=descale_v + ) + + torch.cuda.synchronize() + print ('output-ref', i, out_ref) + print ('output-fa3',i, out_fa3) + print ('output-max-diff', i, context_seqlen, (out_ref - out_fa3).abs().max().item()) + print ('output-mean-diff',i, context_seqlen, (out_ref - out_fa3).abs().mean().item()) + print ('lse-max-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().max().item()) + print ('lse-mean-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().mean().item()) + + if cache_seqlen_rand: + assert ((out_ref - out_fa3).abs().max().item() <= 1e-1) + assert ((out_ref - out_fa3).abs().mean().item() <= 1e-2) + else: + assert ((out_ref - out_fa3).abs().max().item() <= 2e-2) + assert ((out_ref - 
out_fa3).abs().mean().item() <= 2e-3) + lse_max_ref = lse_ref.abs().max().item() + lse_mean_ref = lse_ref.abs().mean().item() + lse_max_fa3 = lse_fa3.abs().max().item() + lse_mean_fa3 = lse_fa3.abs().mean().item() + lse_max_diff = (lse_ref - lse_fa3).abs().max().item() + lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item() + assert ((lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3) + assert ((lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4) + + +if __name__ == "__main__": + main() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index dc8b35baf..bc3f6e5a7 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -23,24 +23,153 @@ def print_diffs(out, out_ref): print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}") +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("gqa_parallel", [False, True]) +@pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize("descale", [1.0]) +@pytest.mark.parametrize("descale", [1.0, 2.0, 3.0]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +def test_flash_attn_output_fp8( + seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale, gqa_parallel +): + device = "cuda" + dtype_init = torch.bfloat16 + print(dtype) + print('causal',causal) + print('local',local) + print('gqa_parallel',gqa_parallel) + # set seed + torch.random.manual_seed(42) + # batch_size = 40 + # nheads = 16 + batch_size = 4 + nheads = 6 + nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # nheads_kv = 1 + # batch_size = 9 + # nheads = 6 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True) + + q = q.to(dtype) + k = k.to(dtype) + v = v.to(dtype) + + softmax_scale = q.shape[-1] ** (-0.5) + descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda') + descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda') + descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda') + + out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic, gqa_parallel=gqa_parallel, + descale_q=descale_q, descale_k=descale_k, descale_v=descale_v) + + q = q.to(dtype_init) + k = k.to(dtype_init) + v = v.to(dtype_init) + + descale_q = descale_q.to(dtype_init) + descale_k = descale_k.to(dtype_init) + descale_v = descale_v.to(dtype_init) + q = q * descale_q + k = k * descale_k + v = v * descale_v + + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) 
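+    # out_pt is the same reference computed in low precision with reordered ops; its
+    # error against the fp32 reference (out_ref) sets the tolerance used for the fp8
+    # kernel output below.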
+ + # qk = torch.einsum('bshd,bthd->bhst', q, k).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() + + # assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 1e-2 + atol = 4 * (out_pt - out_ref).abs().max().item() + 1e-2 + torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=atol, check_dtype=False) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("gqa_parallel", [False, True]) +# @pytest.mark.parametrize("gqa_parallel", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize("d", [64, 96, 128]) -# @pytest.mark.parametrize("d", [256]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("d", [64, 128, 256]) @pytest.mark.parametrize("descale", [1.0]) # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0]) @@ -48,7 +177,6 @@ def print_diffs(out, out_ref): "seqlen_q,seqlen_k", [ (1, 1), - # (257, 1), (64, 128), (128, 128), (256, 256), @@ -64,26 +192,30 @@ def print_diffs(out, out_ref): (1023, 1024), (1024, 1023), (4096, 4096), + (4224, 4224), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale + seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale, gqa_parallel ): device = "cuda" if(dtype == torch.float8_e4m3fn): - dtype_init = torch.float16 + dtype_init = torch.bfloat16 else: dtype_init = dtype print(dtype) + print('causal',causal) + print('local',local) + print('gqa_parallel',gqa_parallel) # set seed - torch.random.manual_seed(0) + torch.random.manual_seed(42) # batch_size = 40 # nheads = 16 batch_size = 4 nheads = 6 nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - # nheads_kv = 2 + # nheads_kv = 1 # batch_size = 9 # nheads = 6 
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) @@ -99,12 +231,12 @@ def test_flash_attn_output( descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda') descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda') descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda') + if(dtype != torch.float8_e4m3fn): - out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic) + out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic, gqa_parallel=gqa_parallel) else: - out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward( - q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v - ) + out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic, gqa_parallel=gqa_parallel, + descale_q=descale_q, descale_k=descale_k, descale_v=descale_v) q = q.to(dtype_init) k = k.to(dtype_init) @@ -189,7 +321,7 @@ def test_flash_attn_output( assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5 else: # just test correctness of fp8 kernel w/o further quantization techniques - assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 2e-2 if d <= 128 and dtype != torch.float8_e4m3fn: assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5 diff --git a/hopper/test_kvcache.py b/hopper/test_kvcache.py new file mode 100644 index 000000000..7764bd5d5 --- /dev/null +++ b/hopper/test_kvcache.py @@ -0,0 +1,234 @@ +import torch +#from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache +import flash_attn_interface as fa3 +import flash_attn as fa2 +import torch.utils.benchmark as benchmark +import time + +import argparse +import math + +parser = argparse.ArgumentParser(description='Process some integers.') +parser.add_argument('--causal', action='store_true') +parser.add_argument('--splits', type=int, default=1) +parser.add_argument('--repeats', type=int, default=10) +parser.add_argument('--validate', action='store_true') +parser.add_argument('--gqa', action='store_true') + +args = parser.parse_args() + +def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, '- Forward pass') + t = benchmark.Timer( + stmt='fn(**kwinputs)', + globals={'fn': fn, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(desc, m) + return t, m + +def benchmark_fa_kv(fn, repeats=10, *args, **kwargs): + # warmup + for _ in range(5): + fn(*args, **kwargs) + niters = repeats + torch.cuda.synchronize() + start = time.time() + for _ in range(niters): + fn(*args, **kwargs) + torch.cuda.synchronize() + end = time.time() + return (end - start) / niters + +def main(): + # *SAMPLE CONFIG* + # Model arch params: + nheads_q = 64 + nheads_kv = 8 + headdim = 128 + #dtype = torch.bfloat16 + dtype = torch.float16 + + # Cache settings: + num_caches = 8 + cache_seqlen = 1024 * 16 + + # Batching settings + ntokens = 1024 + max_queries_per_batch = 4 + small_request_ntokens = 16 + + # Input settings + query_seqlens = [900, 12, 1] + num_queries = len(query_seqlens) + # Need to add empty queries to fill out 
`max_queries_per_batch` + num_padding_queries = max_queries_per_batch - num_queries + context_seqlens = [4096, 5120*2, 6145*2] + #context_seqlens = [4096, 5120*2, 6152*2] + + # Validation + assert sum(query_seqlens) <= ntokens + assert all(s < small_request_ntokens for s in query_seqlens[1:]) + assert num_queries <= max_queries_per_batch + assert all(s < cache_seqlen for s in context_seqlens) + + torch.manual_seed(5434) + + # Allocate some tensors + k_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype + ) + v_cache = torch.randn( + (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype + ) + + q_buf_large = torch.randn( + (1, ntokens, nheads_q, headdim), device="cuda", dtype=dtype + ) + cache_seqlen_large = torch.tensor( + [context_seqlens[0]], dtype=torch.int32, device="cuda" + ) + cache_idx_large = torch.tensor([1], dtype=torch.int32, device="cuda") + + q_buf_small = torch.randn( + (max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim), + device="cuda", + dtype=dtype, + ) + cache_seqlens_small = torch.tensor( + context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device="cuda" + ) + cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[ + : max_queries_per_batch - 1 + ] + + if args.validate: + # Call flash attn + # First for the single full-sized query + out0, lse0 = fa3.flash_attn_with_kvcache( + q=q_buf_large, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlen_large, + cache_batch_idx=cache_idx_large, + causal=bool(args.causal), + num_splits=args.splits, + return_softmax_lse=True, + #num_splits=1 + ) + + # Second for n-1 small queries + out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache( + q=q_buf_small, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens_small, + cache_batch_idx=cache_idxs_small, + causal=bool(args.causal), + num_splits=1, + gqa_decoding=bool(args.gqa), + return_softmax_lse=True, + ) + + # Second for n-1 small queries + out1, lse1 = fa3.flash_attn_with_kvcache( + q=q_buf_small, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens_small, + cache_batch_idx=cache_idxs_small, + causal=bool(args.causal), + num_splits=args.splits, + gqa_decoding=bool(args.gqa), + return_softmax_lse=True, + ) + + # Call flash attn + # First for the single full-sized query + out2 = fa2.flash_attn_with_kvcache( + q=q_buf_large, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlen_large, + cache_batch_idx=cache_idx_large, + causal=bool(args.causal), + num_splits=args.splits, + ) + + print ('big') + print ('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small) + print ('diff-mean', (out0 - out2).abs().mean().item()) + + + # Second for n-1 small queries + out3, lse_fa2 = fa2.flash_attn_with_kvcache( + q=q_buf_small, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens_small, + cache_batch_idx=cache_idxs_small, + causal=bool(args.causal), + num_splits=args.splits, + return_softmax_lse=True, + #num_splits=1 + ) + + print ('small') #, out1) + print ('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape) + print ('lse-dif-max', (lse1 - lse_fa2).abs().max().item()) + print ('diff-max', (out1 - out3).abs().max().item()) + print ('diff-mean', (out1 - out3).abs().mean().item()) + + + print ('fa3', args.repeats) + time_fa3_big = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats, + q=q_buf_large, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlen_large, + 
cache_batch_idx=cache_idx_large, + causal=bool(args.causal), + num_splits=args.splits, + ) + + time_fa3_small = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats, + q=q_buf_small, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens_small, + cache_batch_idx=cache_idxs_small, + causal=bool(args.causal), + num_splits=args.splits, + ) + + print ('fa2 ') + + time_fa2_big = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats, + q=q_buf_large, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlen_large, + cache_batch_idx=cache_idx_large, + causal=bool(args.causal), + num_splits=args.splits + ) + + time_fa2_small = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats, + q=q_buf_small, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens_small, + cache_batch_idx=cache_idxs_small, + causal=bool(args.causal), + num_splits=args.splits + ) + + print ('big (split, fa3, fa2, ratio):', args.splits, time_fa3_big * 1000000, time_fa2_big * 1000000, time_fa3_big / time_fa2_big) + print ('small (split, fa3, fa2, ratio):', args.splits, time_fa3_small * 1000000, time_fa2_small * 1000000, time_fa3_small / time_fa2_small) + +if __name__ == "__main__": + main() diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 2fbb417e4..5cc136f72 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -19,7 +19,7 @@ struct SingleTileScheduler { // Host side kernel arguments struct Arguments { - int const num_blocks_m, num_head, num_batch; + int const num_blocks_m, num_splits, num_head, num_batch; int* const tile_count_semaphore = nullptr; }; @@ -49,9 +49,9 @@ struct SingleTileScheduler { } CUTLASS_DEVICE - cute::tuple + cute::tuple get_block_coord(Params const& params) const { - return {M_idx, H_idx, B_idx}; + return {M_idx, 1, H_idx, B_idx}; } }; @@ -88,26 +88,31 @@ struct SingleTileScheduler { /////////////////////////////////////////////////////////////////////////////// +template class StaticPersistentTileScheduler { public: // Host side kernel arguments struct Arguments { - int const num_blocks_m, num_head, num_batch; + int const num_blocks_m, num_splits, num_head, num_batch; int* const tile_count_semaphore = nullptr; }; // Device side kernel params struct Params { - int total_blocks; - cutlass::FastDivmod m_block_divmod, head_divmod; + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod; }; static Params to_underlying_arguments(Arguments const& args) { - return {args.num_blocks_m * args.num_head * args.num_batch, - cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)}; + // return {args.num_blocks_m * args.num_head * args.num_batch, + // cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)}; + return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), + cutlass::FastDivmod(args.num_splits), + cutlass::FastDivmod(args.num_head)}; } static dim3 @@ -125,11 +130,19 @@ class StaticPersistentTileScheduler { } CUTLASS_DEVICE - cute::tuple + cute::tuple get_block_coord(Params const& params) const { - int m_block, bidh, bidb; - bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); - return {m_block, bidh, bidb}; + int m_block, split_idx, bidh, bidb; + if constexpr(!Is_split) { + bidb = params.head_divmod.divmod(bidh, + params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, 1, bidh, bidb}; + } else { + bidb = 
params.head_divmod.divmod(bidh, + params.split_divmod.divmod(split_idx, + params.m_block_divmod.divmod(m_block, tile_idx))); + return {m_block, split_idx, bidh, bidb}; + } } }; @@ -164,7 +177,9 @@ class StaticPersistentTileScheduler { }; -template +template class DynamicPersistentTileScheduler { protected: @@ -174,21 +189,26 @@ class DynamicPersistentTileScheduler { // Host side kernel arguments struct Arguments { - int const num_blocks_m, num_head, num_batch; + int const num_blocks_m, num_splits, num_head, num_batch; int* const tile_count_semaphore; }; // Device side kernel params struct Params { - int const total_blocks; - cutlass::FastDivmod const m_block_divmod, head_divmod; + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod; int* const tile_count_semaphore; }; static Params to_underlying_arguments(Arguments const& args) { - return {args.num_blocks_m * args.num_head * args.num_batch, - cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head), + // return {args.num_blocks_m * args.num_head * args.num_batch, + // cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head), + // args.tile_count_semaphore}; + return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), + cutlass::FastDivmod(args.num_splits), + cutlass::FastDivmod(args.num_head), args.tile_count_semaphore}; } @@ -207,11 +227,19 @@ class DynamicPersistentTileScheduler { } CUTLASS_DEVICE - cute::tuple + cute::tuple get_block_coord(Params const& params) const { - int m_block, bidh, bidb; - bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); - return {m_block, bidh, bidb}; + int m_block, split_idx, bidh, bidb; + if constexpr(!Is_split) { + bidb = params.head_divmod.divmod(bidh, + params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, 1, bidh, bidb}; + } else { + bidb = params.head_divmod.divmod(bidh, + params.split_divmod.divmod(split_idx, + params.m_block_divmod.divmod(m_block, tile_idx))); + return {m_block, split_idx, bidh, bidb}; + } } }; diff --git a/hopper/utils.h b/hopper/utils.h index aaf0712ad..8232c7fc4 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -261,16 +261,16 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor __forceinline__ __device__ void write_tma( ElemO* O, const TMACopyO& tma_store_O, const LayoutO& layout_O, const TileShapeO& tile_shape_O, - const SMemO& sO, int m_block, int bidh, int bidb, + const SMemO& sO, int m_block, int bidh, int bidb, int n_split_idx, const SeqLenTraits& seqlen_traits_o, int write_warp_idx) { Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape()); - Tensor gO = seqlen_traits_o.get_local_tile_tensor( - mO, tile_shape_O, bidh, bidb + Tensor gO = seqlen_traits_o.get_o_local_tile_tensor( + mO, tile_shape_O, bidh, bidb, n_split_idx )(_, _, m_block); // (M, K) auto block_tma_O = tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) @@ -286,6 +286,94 @@ __forceinline__ __device__ void write_tma( // tma_store_wait<0>(); } +// Epilogue that copies RMEM -> GMEM directly for GQA enabled. 
+// Reports as uncoalesced stores by the profiler +template +__forceinline__ __device__ void write_rmem_to_gmem( + TensorO &tOrO, OutputType *O, const LayoutO& layout_O, TileShapeO tile_shape_O, + int m_block, int h_block, int bidh, int bidh_kv, int bidb, int n_split_idx, + TiledMma& tiled_mma, const SeqLenTraits& seqlen_traits_o, int thread_idx) { + static_assert(is_same_v, "rmem dtype must be float"); + Tensor mO = make_tensor(make_gmem_ptr(O), layout_O); + Tensor gO = [&] { + if constexpr(Use_gqa_layout) { + return seqlen_traits_o.get_o_local_tile_tensor( + mO, tile_shape_O, bidh_kv, bidb, n_split_idx + )(_, _, _, m_block, h_block); // (bM/bH, bH, K) + } else { + return seqlen_traits_o.get_o_local_tile_tensor( + mO, tile_shape_O, bidh, bidb, n_split_idx + )(_, _, m_block); // (bM, bK) + } + }(); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto tile_shape_mnk = cute::tile_shape(tiled_mma); + Tensor cO = cute::make_identity_tensor(select<0, 1>(tile_shape_mnk)); + Tensor tOcO = thread_mma.partition_C(cO); + // tOcO has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor tOcO_row = tOcO(make_coord(_0{}, _, _0{}), _, _0{}); + // reshape from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + const int m_bound = seqlen_traits_o.actual_seq_len - m_block * size<0>(gO); + // hardcoded col_idx to circumvent reg spilling with counting tensor + const int col_start_idx = !Column_permute_fp8 ? 2 * (thread_idx % 4) : 4 * (thread_idx % 4); + + if constexpr (Use_gqa_layout) { + static constexpr int kBlockH = size<1>(gO); + const int h_bound = shape<1>(layout_O) - h_block * kBlockH; + #pragma unroll + for(int nrow = 0; nrow < size<0>(tOrO_rowcol); ++nrow) { + const int row = int(get<0>(tOcO_row(nrow))); + const int h_local = row % kBlockH; + const int m_local = row / kBlockH; + if(h_local < h_bound && m_local < m_bound) { + if constexpr(!Column_permute_fp8) { + Tensor tOrO_nrow_float2 = recast(tOrO_rowcol(nrow, _)); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol)/2; ++ncol) { + *reinterpret_cast(&(gO(m_local, h_local, col_start_idx + 8 * ncol))) = + tOrO_nrow_float2(ncol); + } + } else { + Tensor tOrO_nrow = tOrO_rowcol(nrow, _); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol); ncol += 4) { + gO(m_local, h_local, col_start_idx + 4 * ncol) = tOrO_nrow(ncol); + gO(m_local, h_local, col_start_idx + 4 * ncol + 2) = tOrO_nrow(ncol + 1); + gO(m_local, h_local, col_start_idx + 4 * ncol + 1) = tOrO_nrow(ncol + 2); + gO(m_local, h_local, col_start_idx + 4 * ncol + 3) = tOrO_nrow(ncol + 3); + } + } + } + } + } else { + #pragma unroll + for(int nrow = 0; nrow < size<0>(tOrO_rowcol); ++nrow) { + const int row = int(get<0>(tOcO_row(nrow))); + if(row < m_bound) { + if constexpr(!Column_permute_fp8) { + Tensor tOrO_nrow_float2 = recast(tOrO_rowcol(nrow, _)); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol)/2; ++ncol) { + *reinterpret_cast(&(gO(row, col_start_idx + 8 * ncol))) = + tOrO_nrow_float2(ncol); + } + } else { + Tensor tOrO_nrow = tOrO_rowcol(nrow, _); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol); ncol += 4) { + gO(row, col_start_idx + 4 * ncol) = tOrO_nrow(ncol); + gO(row, col_start_idx + 4 * ncol + 2) = tOrO_nrow(ncol + 1); + gO(row, col_start_idx + 4 * ncol + 1) = tOrO_nrow(ncol + 2); + gO(row, col_start_idx + 4 * ncol + 3) = tOrO_nrow(ncol + 3); + } + } + } + } + } +} + template 
__forceinline__ __device__ void write_tiled( @@ -333,17 +421,24 @@ __forceinline__ __device__ void write_tiled( } } -template + typename TileShapeO, typename SMemO, typename SeqLenTraits, class TensorO, typename TiledMma> __forceinline__ __device__ void write_O( ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O, const LayoutO& layout_O, const TileShapeO& tile_shape_O, - const SMemO& sO, int m_block, int bidh, int bidb, - const SeqLenTraits& seqlen_traits_o, int write_warp_idx) { - if constexpr (IsTMACopy) { - write_tma(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o, write_warp_idx); + const SMemO& sO, int m_block, int bidh, int bidb, int n_split_idx, + const SeqLenTraits& seqlen_traits_o, int write_warp_idx, TiledMma & tiledMma1, TensorO & tOrO) { + + if constexpr (IsRegToGmem) { + static_assert(Is_split, "use write_rmem_to_gmem with split kv kernel only"); + write_rmem_to_gmem(tOrO, O, layout_O, tile_shape_O, m_block, bidh, bidb, n_split_idx, + tiledMma1, seqlen_traits_o, threadIdx.x - NumCopyThreads); + } else if constexpr (IsTMACopy) { + write_tma(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, + n_split_idx, seqlen_traits_o, write_warp_idx); } else { + static_assert(!Is_split, "Don't use write_tiled with split kv kernel"); write_tiled(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o); } }
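A minimal usage sketch of the decode path exercised by the tests in this patch, assuming the flash_attn_interface module built from this hopper/ directory; shapes and values are illustrative, and every keyword argument shown appears in hopper/test_attn_kvcache.py above (num_splits=0 selects the split heuristic, num_splits=1 disables split-KV, gqa_parallel enables the GQA-parallelized scheduler, max_seqlen_k_hint passes the maximum KV length):

import torch
import flash_attn_interface

# KV cache layout used by the tests: (num_caches, cache_seqlen, nheads_kv, headdim);
# queries are (batch, seqlen_q, nheads_q, headdim) with nheads_q a multiple of nheads_kv.
device, dtype = "cuda", torch.bfloat16
k_cache = torch.randn(4, 8192, 8, 128, device=device, dtype=dtype)
v_cache = torch.randn(4, 8192, 8, 128, device=device, dtype=dtype)
q = torch.randn(4, 1, 64, 128, device=device, dtype=dtype)  # single decode token per request
cache_seqlens = torch.full((4,), 4096, dtype=torch.int32, device=device)

out, lse = flash_attn_interface.flash_attn_with_kvcache(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    num_splits=0,             # 0 = choose the number of KV splits heuristically
    gqa_parallel=True,        # parallelize over the query-head group of each KV head
    max_seqlen_k_hint=4096,   # maximum KV length the scheduler should plan for
    return_softmax_lse=True,
)

For FP8, the tests additionally cast q, k_cache, and v_cache to torch.float8_e4m3fn and pass fp32 descale_q, descale_k, and descale_v tensors of shape (1,).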