
How do the three versions compare in performance? #4

Open
Amanda-Barbara opened this issue Jun 17, 2024 · 3 comments

Comments

@Amanda-Barbara

Hi, how do the flash-attention implementations in the three versions compare in performance?

@66RING
Owner

66RING commented Jun 17, 2024

@Amanda-Barbara These are learning versions, so performance was not a goal; once you understand them, just go read the official implementation. The Triton and CUTLASS versions can approach the official implementation for certain shapes, because for simplicity the CUTLASS version here hardcodes the tile sizes, while the official implementation picks the optimal tile sizes based on the problem size. The CUDA version has no optimization at all; it exists purely to get familiar with the flash-attention workflow.

@vfdff

vfdff commented Oct 30, 2024

May I ask what factors typically need to be considered when choosing the optimal tile size?

@66RING
Owner

66RING commented Oct 30, 2024

> May I ask what factors typically need to be considered when choosing the optimal tile size?

@vfdff It's hard to say in general; it seems to depend on the input size, the hardware's compute capability, the compiler and driver versions, the shared-memory (smem) size, the shape of the computation, and so on. It feels like a tradeoff between the resources you request and how well you use them. You can look at configurations others have enumerated (the code below is from FlagAttention):

import torch


def get_config(M, D):
    # Pick (BLOCK_M, BLOCK_N, num_stages, num_warps) for the Triton kernel
    # based on the GPU architecture, the head dimension D, and the
    # sequence length M.
    if torch.cuda.get_device_capability() == (8, 0):  # A100 (sm80)
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            if M <= 1024:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
            else:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
    elif torch.cuda.get_device_capability() == (8, 6):  # consumer Ampere, e.g. RTX 3090 (sm86)
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
    else:
        # Conservative fallback for other architectures.
        BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
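
For reference, a minimal usage sketch of the snippet above (requires a CUDA-capable device; the `seq_len` and `head_dim` values here are illustrative, not taken from this issue):

# Illustrative only: pick a tile configuration for a 2048-token sequence with
# head dimension 128; the returned values would typically be passed to the
# Triton kernel launch as compile-time constants.
seq_len, head_dim = 2048, 128
BLOCK_M, BLOCK_N, num_stages, num_warps = get_config(seq_len, head_dim)
print(BLOCK_M, BLOCK_N, num_stages, num_warps)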
