-
I've just done a bit of straightforward porting of FlexAttention over here: https://github.com/zinccat/flaxattention
-
PyTorch 2.5 was recently released, with one of the headline features being a prototype of FlexAttention ("The Flexibility of PyTorch with the Performance of FlashAttention"), which seems to be an interesting combination of (i) deep implementation wizardry and (ii) UX improvements, the latter being the main selling point for end-user adoption. What is the state of flexible, performance-optimized attention implementations in the JAX ML ecosystem?
Perhaps it's possible to achieve similar performance in a cross-platform, even more flexible way with Pallas kernels (the "implementation wizardry"), but as an end-user interested in attention variants (there are dozens of us! dozens!) I admit it would be nice to have a unified API that already exists, similar to what FlexAttention is promising for the PyTorch community.
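For concreteness, here is a minimal sketch in plain `jax.numpy` of the score_mod-style interface FlexAttention exposes: the user writes a small function that edits each pre-softmax attention score, and the attention routine itself stays generic. The names here (`flex_attention`, `score_mod`, `causal_with_relative_bias`) are illustrative, not an existing JAX or flaxattention API, and this naive version materializes the full score matrix, so it has none of the fused-kernel performance a Pallas/Triton implementation would provide.

```python
import jax
import jax.numpy as jnp
from functools import partial

def flex_attention(q, k, v, score_mod):
    # q, k, v: [heads, seq_len, head_dim]; naive reference semantics only.
    h, s, d = q.shape
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(d)
    # Expose (head, q_idx, kv_idx) so the user function can modify each score.
    head_idx = jnp.arange(h)[:, None, None]
    q_idx = jnp.arange(s)[None, :, None]
    kv_idx = jnp.arange(s)[None, None, :]
    scores = score_mod(scores, head_idx, q_idx, kv_idx)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,hkd->hqd", weights, v)

def causal_with_relative_bias(score, head, q_idx, kv_idx):
    # Example variant: ALiBi-style relative bias plus a causal mask.
    score = score - 0.1 * (head + 1) * jnp.abs(q_idx - kv_idx)
    return jnp.where(q_idx >= kv_idx, score, -jnp.inf)

q = k = v = jnp.ones((4, 128, 64))
out = jax.jit(partial(flex_attention, score_mod=causal_with_relative_bias))(q, k, v)
```

The appeal of the FlexAttention approach is that the user-facing code stays this simple while the compiler lowers the `score_mod` into a fused FlashAttention-style kernel instead of materializing the full score matrix as above.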
Related: jax.nn.dot_product_attention (discussion in #21371 recognizes that attention is a sufficiently fundamental operation in modern ML practice to be addressed in core JAX, not delegated to downstream ML frameworks), #18121, #18314
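For reference, `jax.nn.dot_product_attention` can already be called roughly like this (a sketch; keyword names and available backends may differ across JAX versions, so check the docs for your release):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Shapes follow (batch, seq_len, num_heads, head_dim).
q = jax.random.normal(key, (2, 128, 4, 64), dtype=jnp.bfloat16)
k = jax.random.normal(key, (2, 128, 4, 64), dtype=jnp.bfloat16)
v = jax.random.normal(key, (2, 128, 4, 64), dtype=jnp.bfloat16)

# implementation=None lets JAX pick a backend; "cudnn" requests the fused
# (FlashAttention-style) kernel on supported GPUs, "xla" the reference path.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation=None)
print(out.shape)  # (2, 128, 4, 64)
```

This covers fixed options like bias, mask, and causal attention, but not arbitrary per-score modifications, which is exactly the gap a FlexAttention-style API would fill.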