Could you provide some FA examples to illustrate the improvement over FA2? #43
Comments
Hello, as described in the README, the performance is measured without counting the quantization and smoothing overhead; only the attention kernel itself is timed. We also acknowledge that at head_dim 64 the kernel is currently slow on A100/A800. The code we use to benchmark the SageAttention kernel is as follows:

import torch
from flash_attn.utils.benchmark import benchmark_forward
import sageattention._qattn as qattn
import argparse

parser = argparse.ArgumentParser(description='Benchmark QK Int8 PV FP16 CUDA')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--num_heads', type=int, default=32, help='Number of heads')
parser.add_argument('--head_dim', type=int, default=128, help='Head dimension')
args = parser.parse_args()

head = args.num_heads
batch = args.batch_size
headdim = args.head_dim

print("CUDA QK Int8 PV FP16")
print(f"batch: {batch}, head: {head}, headdim: {headdim}")

# Per-warp quantization granularity: one scale per WARP_Q query tokens and
# one scale per WARP_K key tokens (see the q_scale / k_scale shapes below).
WARP_Q = 32
WARP_K = 64

is_causal = False
_is_causal = 1 if is_causal else 0
print(f"is_causal: {is_causal}")
for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
    # Two matmuls (QK^T and PV) of 2*b*h*n^2*d FLOPs each; causal masking halves the work.
    flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1)
    kernel = qattn.qk_int8_sv_f16_accum_f16_attn_per_warp
    q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()
    k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()
    q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float).cuda()
    k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float).cuda()
    v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    sm_scale = 1 / (headdim ** 0.5)
    # Warm up, then time only the kernel.
    for i in range(5):
        kernel(q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0)
    torch.cuda.synchronize()
    _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0,
                                repeats=100, verbose=False, desc='CUDA')
    print(f'{seq_len} TFLOPS: {flops / time.mean * 1e-12}')

is_causal = True
_is_causal = 1 if is_causal else 0
print(f"is_causal: {is_causal}")
for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
    flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1)
    kernel = qattn.qk_int8_sv_f16_accum_f16_attn_per_warp
    q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()
    k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8).cuda()
    q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float).cuda()
    k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float).cuda()
    v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16).cuda()
    sm_scale = 1 / (headdim ** 0.5)
    for i in range(5):
        kernel(q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0)
    torch.cuda.synchronize()
    _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 0, _is_causal, sm_scale, 0,
                                repeats=100, verbose=False, desc='CUDA')
    print(f'{seq_len} TFLOPS: {flops / time.mean * 1e-12}')

Please let me know your new measurement result!
Also, our result was measured on an A800 PCIe.
I ran it on an A100: CUDA QK Int8 PV FP16
Actually, we had built an FA2 unit test that runs both the original FA forward op and the SageAttention op. However, across various sizes we did not observe the performance improvement shown in the figure you provided.
Here is our main code to verify the performance:
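(In outline, such a head-to-head test can be written roughly as the sketch below; this is a reconstruction assuming the public flash_attn.flash_attn_func and sageattention.sageattn entry points, not the exact script used here, and the sageattn keyword names may differ between releases.)

import torch
from flash_attn import flash_attn_func
from flash_attn.utils.benchmark import benchmark_forward
from sageattention import sageattn

batch, head, headdim, is_causal = 4, 32, 128, False

for seq_len in (1024, 2048, 4096, 8192, 16384):
    # Identical fp16 inputs, (batch, seq_len, head, head_dim) layout, for both ops.
    q = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device='cuda')
    k = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device='cuda')
    v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device='cuda')
    # Warm up both ops before timing.
    for _ in range(5):
        flash_attn_func(q, k, v, causal=is_causal)
        sageattn(q, k, v, tensor_layout='NHD', is_causal=is_causal)
    torch.cuda.synchronize()
    _, t_fa = benchmark_forward(flash_attn_func, q, k, v, causal=is_causal,
                                repeats=100, verbose=False, desc='FA2 fwd')
    _, t_sage = benchmark_forward(sageattn, q, k, v, tensor_layout='NHD', is_causal=is_causal,
                                  repeats=100, verbose=False, desc='sageattn')
    print(f'{seq_len}  FA2: {t_fa.mean * 1e3:.3f} ms  sageattn: {t_sage.mean * 1e3:.3f} ms  '
          f'speedup: {t_fa.mean / t_sage.mean:.2f}x')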
Our result:
P.S.
Our NVIDIA environment:
GPU: A800
Python: 3.10
torch: 2.3.0
flash-attn: 2.4.2
triton: 2.3.0
SageAttention: branch 2.0.0, latest
Thanks a lot!