Skip to content

Commit

Permalink
FA3 kvcache + split kv + gqa parallelization (#1236)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayhshah authored Oct 15, 2024
1 parent bedf877 commit a5a7527
Show file tree
Hide file tree
Showing 65 changed files with 4,375 additions and 815 deletions.
10 changes: 6 additions & 4 deletions hopper/benchmark_flash_attention_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def time_fwd(func, *args, **kwargs):
# dim = 256
dropout_p = 0.0

methods = (["Pytorch", "Flash3", "cuDNN"]
methods = (["Pytorch", "Flash3"]
+ (["cuDNN"] if cudnn is not None else [])
# + (["Triton"] if attention_triton is not None else [])
# + (["xformers.c"] if xops is not None else [])
# + (["xformers.f"] if xops is not None else [])
Expand All @@ -247,10 +248,10 @@ def time_fwd(func, *args, **kwargs):
torch.cuda.empty_cache()
config = (causal, headdim, batch_size, seqlen)
nheads = dim // headdim
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)]

qkv = torch.stack([q, k, v], dim=2)
qkv = qkv.to(torch.float16)
qkv = qkv.to(torch.bfloat16)
f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
time_f[config, "Pytorch"] = f
res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
Expand Down Expand Up @@ -289,7 +290,8 @@ def time_fwd(func, *args, **kwargs):
k,
v,
softmax_scale,
causal=causal,
causal=causal,
window_size=(-1,-1),
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
Expand Down
325 changes: 325 additions & 0 deletions hopper/benchmark_split_kv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
import torch
import flash_attn
import flash_attn_interface
import itertools
import time
import math

import torch.utils.benchmark as benchmark

def round_up_to_power_of_2(x):
if x <= 1:
return 1
return 1 << (x - 1).bit_length()

def timeit(fn, *args, **kwargs):
torch.cuda.synchronize()

# Warmup
for _ in range(5):
fn(*args, **kwargs)

# Benchmark using PyTorch Timer
t = benchmark.Timer(
stmt='fn(*args, **kwargs)',
globals={'fn': fn, 'args': args, 'kwargs': kwargs}
)

# Measure execution time
measurement = t.timeit(20) # Runs the function 20 times
# measurement = t.blocked_autorange(min_run_time=1)
avg_time = measurement.mean # Average time in seconds

return avg_time

def main():
num_sms = torch.cuda.get_device_properties(
torch.cuda.current_device()
).multi_processor_count

max_splits = 129
check_all_splits = False

causal = True
# causal = False
# dtype=torch.float16
dtype=torch.bfloat16

torch.manual_seed(42)

model_configs = [
# ("Gemma-2-2B", 8, 4, 256),
# ("Gemma-2-9B", 16, 8, 256),
# ("Gemma-2-27B", 32, 16, 128),
# ("Qwen-2.5-0.5B", 14, 2, 64),
# ("Qwen-2.5-1.5B", 12, 2, 128),
# ("Qwen-2.5-7B", 28, 4, 128),
# ("Llama-3.1-8B", 32, 8, 128),
("Llama-3.1-70B", 64, 8, 128),
# ("Llama-3.1-405B", 128, 8, 128),
# ("Llama-3.2-1B", 32, 8, 64),
# ("Llama-3.2-3B", 24, 8, 128),
# ("Nemotron-4-15B", 48, 8, 128),
]

all_batch_configs = []

all_batch_configs.extend(itertools.product(
# [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen
[4096, 16384, 65536], # context_seqlen
# [131072], # context_seqlen
# [i for i in range(1, (num_sms) + 1)], # num_requests
[1, 4, 8, 16], # num_requests
# [1], # num_requests
[1, 4, 8, 16], # query_seqlen
# [1], # query_seqlen
))

num_caches = max(reqs for _, reqs, _ in all_batch_configs)
cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)

for model_name, nheads_q, nheads_kv, headdim in model_configs:
k_cache = torch.randn(
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
)
v_cache = torch.randn(
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
)
print(f"***{model_name}***")
print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}")

if check_all_splits is False:
print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")

for context_seqlen, num_requests, query_seqlen in all_batch_configs:
bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)
bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)
blockH = round_up_to_power_of_2(nheads_q//nheads_kv)
blockM = 128 # true for hdim 128 causal and hdim 64
blockM_div_H = blockM//blockH
num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H)

q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype)
cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
cache_seqlens = torch.tensor(
[context_seqlen] * num_requests, dtype=torch.int32, device="cuda"
)

fa2_time_heuristic = timeit(
flash_attn.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
) * 1000. * 1000.
# fastest_splitk_time = float("inf")
# fastest_splitk = 0
# for i in range(1, max_splits):
# t = timeit(
# flash_attn.flash_attn_with_kvcache,
# q=q,
# k_cache=k_cache,
# v_cache=v_cache,
# cache_seqlens=cache_seqlens,
# cache_batch_idx=cache_idxs,
# causal=causal,
# num_splits=i,
# ) * 1000. * 1000.
# if t < fastest_splitk_time:
# fastest_splitk_time = t
# fastest_splitk = i

fa3_time_one_split = timeit(
flash_attn_interface.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=False,
num_splits=1,
) * 1000. * 1000.

fa3_time_gqa_heuristic = timeit(
flash_attn_interface.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=True,
num_splits=0,
max_seqlen_k_hint=context_seqlen
) * 1000. * 1000.

if check_all_splits:

fa3_fastest_num_splits = 0
fa3_fastest_splitk_time = float("inf")

for num_splits in range(1, max_splits):
t = timeit(
flash_attn_interface.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=False,
num_splits=num_splits
) * 1000. * 1000.

out0 = flash_attn_interface.flash_attn_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=False,
num_splits=num_splits
)

out1 = flash_attn_interface.flash_attn_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=False,
num_splits=1
)

max_diff = (out0 - out1).abs().max().item()
mean_diff = (out0 - out1).abs().mean().item()
# print (f"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}")
# print (f"splits {num_splits}, time {t:.2f}")

if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")

if t < fa3_fastest_splitk_time:
fa3_fastest_splitk_time = t
fa3_fastest_num_splits = num_splits

fa3_fastest_num_splits_gqa = 0
fa3_fastest_splitk_time_gqa = float("inf")
for num_splits in range(1, max_splits):

t = timeit(
flash_attn_interface.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=True,
num_splits=num_splits
) * 1000. * 1000.

out0 = flash_attn_interface.flash_attn_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=True,
num_splits=num_splits
)

out1 = flash_attn_interface.flash_attn_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=True,
num_splits=1
)

max_diff = (out0 - out1).abs().max().item()
mean_diff = (out0 - out1).abs().mean().item()
# print (f"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}")
# print (f"gqa splits {num_splits}, time {t:.2f}")

if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")

if t < fa3_fastest_splitk_time_gqa:
fa3_fastest_splitk_time_gqa = t
fa3_fastest_num_splits_gqa = num_splits

efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms
heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa
# remeasure to smooth anomalies
if heuristic_ratio > 1.1:

fa3_time_gqa_heuristic = timeit(
flash_attn_interface.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=True,
# num_splits=num_splits_select,
# num_splits=1,
num_splits=0,
max_seqlen_k_hint=context_seqlen
) * 1000. * 1000.

fa3_fastest_splitk_time_gqa = timeit(
flash_attn_interface.flash_attn_with_kvcache,
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_idxs,
causal=causal,
gqa_parallel=True,
num_splits=fa3_fastest_num_splits_gqa
) * 1000. * 1000.

if check_all_splits is True:
print(
f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, "
f"FA2:{fa2_time_heuristic:.2f}, "
# f"FA2 MANUAL:{fastest_splitk_time:.2f}, "
# f"FA2 NUM SPLITS:{fastest_splitk}, "
# f"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, "
# f"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, "
# f"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, "
f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, "
f"FA3:{fa3_time_gqa_heuristic:.2f}, "
# f"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, "
# f"FA2 NUM SPLITS:{fastest_splitk}, "
# f"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, "
f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, "
# f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
f"EFF:{efficiency:.2f}, "
f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
)

if check_all_splits is False:
print(
f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}"
f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}"
f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}"
f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}"
)



if __name__ == "__main__":
main()
Loading

0 comments on commit a5a7527

Please sign in to comment.