Hi, I found that the latest flash3 only supports head sizes of 64 or 128. Are you planning to include more?

Torch's SDPA doesn't require V to have the same head dimension as the other inputs; its docs even give them distinct dimensions, E and Ev, because by the time V is multiplied in, the head dimension has already been contracted away and only an L x L attention matrix remains.
In [23]: qk=torch.randn(4, 4, 4, 8).bfloat16().cuda()
In [24]: v=torch.randn(4, 4, 4, 16).bfloat16().cuda()
In [25]: F.scaled_dot_product_attention(qk, qk, v).shape
Out[25]: torch.Size([4, 4, 4, 16])
The same holds for xformers: their docs use K and Kv. However, flash attention 2 (2.4.2) requires the head dimensions to match (as documented, all tensors must have the same headdim per head; the error message uses a different name than the documentation).

Can this be relaxed to allow a different head_size for V, or does the implementation depend on the head dimensions matching?

Maintainer's reply:

While it's theoretically possible, we don't plan to do that. The reason is that we're already templating on the head dimension (32, 64, 96, 128, 160, 192, 224, 256). If V could have a different head dimension we'd need to increase the number of templates by 8x, and compilation time would increase by 8x.
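To make the combinatorics in that reply concrete, here is a small illustrative snippet (plain Python, not code from the flash-attention repository; the names HEAD_DIMS, same_dim and mixed_dim are invented for the illustration):

HEAD_DIMS = [32, 64, 96, 128, 160, 192, 224, 256]

# One kernel template per supported head dimension today.
same_dim = [(d, d) for d in HEAD_DIMS]
# If V could have its own head dimension, every (qk_headdim, v_headdim)
# pair would need its own instantiation.
mixed_dim = [(dq, dv) for dq in HEAD_DIMS for dv in HEAD_DIMS]

print(len(same_dim), len(mixed_dim))  # 8 64 -> 8x more templates to compile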
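Not discussed in this thread, but as a hedged aside: when a kernel insists on matching head dimensions, a smaller V head dimension can often be emulated by zero-padding V up to the Q/K head dimension and slicing the output, since the padded columns only ever accumulate zeros. A minimal sketch follows; it uses PyTorch's scaled_dot_product_attention as a stand-in for such a kernel, and the helper name attention_with_narrow_v is made up for the example:

import torch
import torch.nn.functional as F

def attention_with_narrow_v(q, k, v, attn_fn=F.scaled_dot_product_attention):
    # q, k: (..., L, d); v: (..., L, dv) with dv <= d.
    # attn_fn stands in for an attention kernel that requires matching head dims.
    d, dv = q.shape[-1], v.shape[-1]
    v_padded = F.pad(v, (0, d - dv))   # zero-pad the head (last) dimension
    out = attn_fn(q, k, v_padded)      # output has head dimension d
    return out[..., :dv]               # the padded columns are exactly zero

q = torch.randn(4, 4, 4, 8)
k = torch.randn(4, 4, 4, 8)
v = torch.randn(4, 4, 4, 6)
print(attention_with_narrow_v(q, k, v).shape)  # torch.Size([4, 4, 4, 6])

Of course this pays the compute and memory cost of the full head dimension, which is exactly what native support for a smaller Ev would avoid.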