
Could you provide some FA examples to illustrate the improvement in FA2? #43

RyeYuan opened this issue Nov 19, 2024 · 3 comments

RyeYuan commented Nov 19, 2024

Actually, we built an FA2 unit test that runs the original FA forward op and the sageAttention op. However, across various sizes we did not observe the performance improvement shown in the figure you provided.
Here is the main code we used to verify the performance:

def benchmark_forward(
    fn, *inputs, repeats=100, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    return t, m

import flash_attn
from flash_attn.flash_attn_interface import flash_attn_func
from sageattention import sageattn

q = warp_tensor(torch.randn(batch_size, qheads,  seq_len, head_size, device=device, dtype=torch.float16, requires_grad=True))
k = warp_tensor(torch.randn(batch_size, kvheads, seq_len, head_size, device=device, dtype=torch.float16, requires_grad=True))
v = warp_tensor(torch.randn(batch_size, kvheads, seq_len, head_size, device=device, dtype=torch.float16, requires_grad=True))

t = benchmark_forward(flash_attn_func, q, k, v, 0.0, causal=causal, window_size=window_size, repeats=repeats, verbose=False)[1].times[0]
t_sage = benchmark_forward(sageattn, q, k, v, tensor_layout="NHD", is_causal=causal, smooth_k=False)[1].times[0]
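
(Side note: for comparison with the TFLOPS figures in the repository, the measured times above can be converted into throughput. A minimal sketch, reusing batch_size, qheads, seq_len, head_size and causal from the snippet; the helper name to_tflops is hypothetical:)

# Hypothetical helper: convert a measured forward time (seconds) into TFLOPS,
# using the standard attention FLOP count (two matmuls, 2 FLOPs per multiply-add).
def to_tflops(time_s, batch, heads, seq_len, head_dim, causal):
    flops = 4 * batch * heads * head_dim * seq_len * seq_len / (2 if causal else 1)
    return flops / time_s * 1e-12

print(f"flash_attn:    {to_tflops(t, batch_size, qheads, seq_len, head_size, causal):.1f} TFLOPS")
print(f"sageattention: {to_tflops(t_sage, batch_size, qheads, seq_len, head_size, causal):.1f} TFLOPS")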

Our result:
[screenshot of benchmark results]

P.S.
Our NVIDIA environment:
GPU: A800
python: 3.10
torch: 2.3.0
flash-attn: 2.4.2
triton: 2.3.0
SageAttention: branch 2.0.0 (latest)

Thanks a lot!

jason-huang03 (Member) commented Nov 20, 2024

Hello, as described in the README, the performance is measured without counting the quantization and smoothing overhead; only the time of the attention kernel itself is measured. We also admit that at head_dim 64 the kernel is currently slow on A100/A800. We provide the code to benchmark the SageAttention kernel as follows:

import torch
from flash_attn.utils.benchmark import benchmark_forward

import sageattention._qattn as qattn

import argparse

parser = argparse.ArgumentParser(description='Benchmark QK Int8 PV FP16 CUDA kernel')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--num_heads', type=int, default=32, help='Number of heads')
parser.add_argument('--head_dim', type=int, default=128, help='Head dimension')
args = parser.parse_args()

head = args.num_heads
batch = args.batch_size
headdim = args.head_dim

print(f"CUDA QK Int8 PV FP16")
print(f"batch: {batch}, head: {head}, headdim: {headdim}")

WARP_Q = 32
WARP_K = 64

is_causal = False
_is_causal = 1 if is_causal else 0
print(f"is_causal: {is_causal}")
for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
    # two matmuls (QK^T and PV) at 2 FLOPs per multiply-add; halved under causal masking
    flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1)
    kernel = qattn.qk_int8_sv_f16_accum_f16_attn_per_warp

    q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()
    k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()

    q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float).cuda()
    k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float).cuda()
    v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    sm_scale = 1 / (headdim ** 0.5)
    # warm up the kernel before timing
    for i in range(5):
        kernel(q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0)
    torch.cuda.synchronize()
    _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0, repeats=100, verbose=False, desc='CUDA')
    print(f'{seq_len} flops:{flops/time.mean*1e-12}')  # throughput in TFLOPS

is_causal = True
_is_causal = 1 if is_causal else 0
print(f"is_causal: {is_causal}")
for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
    flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1)
    kernel = qattn.qk_int8_sv_f16_accum_f16_attn_per_warp

    q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()
    k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()

    q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float).cuda()
    k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float).cuda()
    v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    sm_scale = 1 / (headdim ** 0.5)
    # warm up the kernel before timing
    for i in range(5):
        kernel(q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0)
    torch.cuda.synchronize()
    _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0, repeats=100, verbose=False, desc='CUDA')
    print(f'{seq_len} flops:{flops/time.mean*1e-12}')  # throughput in TFLOPS
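
(For contrast, here is a minimal sketch — assuming the Python-level sageattn API used in the first snippet of this thread — that times the full sageattn call, i.e. including the quantization and smoothing overhead that the kernel-only benchmark above excludes; shapes are illustrative:)

import torch
from flash_attn.utils.benchmark import benchmark_forward
from sageattention import sageattn

batch, head, headdim, seq_len = 4, 32, 128, 4096
q = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda")

# warm up, then measure the end-to-end call (quantization + kernel), NHD layout, non-causal
for _ in range(5):
    sageattn(q, k, v, tensor_layout="NHD", is_causal=False)
torch.cuda.synchronize()
_, m = benchmark_forward(sageattn, q, k, v, tensor_layout="NHD", is_causal=False, repeats=100, verbose=False)

flops = 4 * head * batch * headdim * seq_len * seq_len
print(f"end-to-end: {m.mean * 1e3:.3f} ms, {flops / m.mean * 1e-12:.1f} TFLOPS")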

Please let me know your new measurement result!

@jason-huang03 (Member)

Also our result is measured on A800 PCIE.


laomao0 commented Nov 28, 2024

Also our result is measured on A800 PCIE.

I ran it on an A100:

CUDA QK Int8 PV FP16
batch: 4, head: 32, headdim: 128
is_causal: False
1024 flops:234.11632614022466
2048 flops:252.01598652480024
4096 flops:257.6368488648967
8192 flops:261.31631823053766
16384 flops:262.3792827611407
32768 flops:262.4355549880885
is_causal: True
1024 flops:175.29882153415872
2048 flops:214.3322904328756
4096 flops:234.26005810663187
8192 flops:247.81052348404322
16384 flops:255.51569468555905
32768 flops:259.50637619737614
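
(As a quick sanity check — a short sketch, assuming the FLOP count from the benchmark script above — the reported throughput can be converted back into an implied kernel time, e.g. for the non-causal seq_len = 32768 entry:)

# Back out the kernel time implied by the reported throughput (values above are in TFLOPS).
batch, head, headdim, seq_len = 4, 32, 128, 32768
flops = 4 * head * batch * headdim * seq_len * seq_len   # non-causal
time_s = flops * 1e-12 / 262.4355549880885               # reported TFLOPS at seq_len 32768
print(f"implied kernel time: {time_s * 1e3:.0f} ms")     # roughly 268 ms per call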
