How do the three versions compare in performance? #4
Hi, how do the flash-attention implementations in the three versions of this repo compare in performance?
Comments
@Amanda-Barbara These are learning versions, written without performance in mind; once you understand them, just go straight to the official implementation. The Triton and CUTLASS versions can come close to the official implementation for specific shapes, because for simplicity the CUTLASS version here hard-codes the tile sizes, while the official version selects the optimal tile size based on the problem size. The CUDA version has no optimizations at all; it exists purely to get familiar with the flash-attention flow.
May I ask what factors generally need to be considered when choosing the optimal tile size?
@vfdff Hard to say precisely; it seems to depend on the input size, hardware compute capability, compiler version, driver version, shared-memory size, the shape of the computation, and so on. It feels like a tradeoff between the resources you request and the resources you actually use. You can look at examples others have enumerated (the code below comes from FlagAttention):

import torch

def get_config(M, D):
    # Choose tile sizes and launch parameters based on the GPU architecture,
    # the sequence length M, and the head dimension D.
    if torch.cuda.get_device_capability() == (8, 0):  # A100 (Ampere)
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            if M <= 1024:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
            else:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
    elif torch.cuda.get_device_capability() == (8, 6):  # e.g. RTX 30-series
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
    else:
        BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
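
For context, here is a minimal sketch of how such a config tuple is typically consumed when launching a Triton kernel. The kernel name _attn_fwd_kernel, its argument list, and the launch_forward helper are hypothetical placeholders rather than FlagAttention's actual API; only triton.cdiv and the num_warps / num_stages launch keywords are standard Triton features.

import triton

def launch_forward(q, k, v, o):
    # Hypothetical launch helper: q/k/v/o are (batch, heads, seqlen, headdim) tensors.
    B, H, M, D = q.shape
    BLOCK_M, BLOCK_N, num_stages, num_warps = get_config(M, D)

    # One Triton program per (query tile, batch*head pair); each program
    # processes BLOCK_M rows of Q against BLOCK_N-column tiles of K/V.
    grid = (triton.cdiv(M, BLOCK_M), B * H)
    _attn_fwd_kernel[grid](
        q, k, v, o,
        M, D,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,  # compile-time tile sizes
        num_stages=num_stages,             # software-pipelining depth
        num_warps=num_warps,               # warps per program
    )

Because BLOCK_M and BLOCK_N are compile-time constants, Triton specializes and caches one compiled kernel per configuration, so selecting them from the runtime shape costs little after the first call.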