Attention computation for very long context dependencies #3

Open
caijixueIT opened this issue Jun 12, 2024 · 1 comment

caijixueIT commented Jun 12, 2024

In the prefill stage I call the official flash-attention v2 implementation to compute attention with
QKV shape [batch, seq_len, num_heads, head_dim] = [1, 128*1024, 128, 128]. The computation is quite slow. Are there any optimization ideas for long-context (>128K) attention? Any advice would be appreciated.
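
For reference, a minimal sketch of the setup described above, assuming the `flash_attn_func` interface from the flash-attn v2 package (the questioner's actual call site is not shown):

```python
import torch
from flash_attn import flash_attn_func

batch, seq_len, num_heads, head_dim = 1, 128 * 1024, 128, 128

# flash-attn v2 expects [batch, seq_len, num_heads, head_dim] in fp16/bf16 on GPU.
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal attention over the full 128K prompt: the prefill pass does O(seq_len^2) work.
out = flash_attn_func(q, k, v, causal=True)
```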

66RING (Owner) commented Jun 12, 2024

@caijixueIT
This is the classic compute-bound prefill problem; off the top of my head there are only a few approaches I can recall:

  • One is FlashDecoding++. The main idea is to look for further points that can be parallelized, so they propose an asynchronous softmax that lets the softmax itself run in parallel for a speedup.
  • Another is FlashDecoding (some people call it FlashAttention-3). The main idea is to add parallelism along the sequence-length dimension: with a small batch and a long sequence the available parallelism may not be fully utilized, so an extra split dimension is added and the softmax is parallelized across the splits. This has already been integrated into flash-attention (which is why some call it FlashAttention-3) and seems to be related to run_mha_fwd_splitkv_dispatch. You can look up the Python interface in flash-attention that corresponds to run_mha_fwd_splitkv_dispatch and swap it into the modeling code, or use xformers.ops.memory_efficient_attention (see the first sketch after this list).
  • The last one is less of a self-attention method: linear attention and its variants. The main idea is to replace the softmax so that (QK)V can be rewritten as Q(KV), which makes the computational cost linear rather than quadratic in seqlen (see the second sketch after this list).

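A minimal sketch of the xformers path mentioned in the second bullet, assuming the shape from the question and a causal prefill pass (where to wire this into the modeling code is up to you):

```python
import torch
import xformers.ops as xops

batch, seq_len, num_heads, head_dim = 1, 128 * 1024, 128, 128

# xformers expects [batch, seq_len, num_heads, head_dim] in fp16/bf16 on GPU.
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# LowerTriangularMask() gives the causal masking used during prefill;
# xformers dispatches to a memory-efficient kernel under the hood.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
```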
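And a minimal, non-causal sketch of the linear-attention idea from the last bullet, using the common elu(x)+1 feature map as a stand-in for softmax (a causal version would additionally need a prefix sum over the sequence; the function here is illustrative, not from any particular library):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """q, k, v: [batch, seq_len, num_heads, head_dim]."""
    # Feature map phi(x) = elu(x) + 1 replaces the softmax kernel.
    q, k = F.elu(q) + 1, F.elu(k) + 1
    # Compute K^T V first: [batch, heads, head_dim, head_dim] -- O(n * d^2) work.
    kv = torch.einsum("bnhd,bnhe->bhde", k, v)
    # Normalizer: phi(Q) . sum_n phi(K_n), standing in for the softmax denominator.
    z = 1.0 / (torch.einsum("bnhd,bhd->bnh", q, k.sum(dim=1)) + eps)
    # Q(KV) instead of (QK)V: cost is linear, not quadratic, in seq_len.
    return torch.einsum("bnhd,bhde,bnh->bnhe", q, kv, z)
```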