Support MLA (Multi-Head Latent Attention) in DeepSeek-v2 #237
Comments
Hello @yzh119, I know about this. It involves merging the latent proj_in and proj_out into the Q/KV proj_in/out. Maybe I will try to take a look this weekend.
Btw, it is Multi-Head Latent Attention.
Thank you @jon-chuang!
Any updates on DeepSeek-V2?
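One way to read the "merging" idea above is as weight absorption: instead of decompressing each cached latent vector into a full key before the dot product, the up-projection is folded into the query once. The sketch below is only an illustration of that identity in plain Python. The names `w_uk`, `c_kv`, and the toy dimensions are assumptions made for the example and are not taken from flashinfer or the DeepSeek-V2 code.

```python
import random


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    """Return the transpose of matrix m."""
    return [list(col) for col in zip(*m)]


def dot(u, v):
    """Inner product of two vectors."""
    return sum(a * b for a, b in zip(u, v))


random.seed(0)
head_dim, latent_dim = 4, 3  # toy sizes, chosen only for illustration

# Hypothetical up-projection from the latent space to the per-head key space.
w_uk = [[random.uniform(-1, 1) for _ in range(latent_dim)] for _ in range(head_dim)]
q = [random.uniform(-1, 1) for _ in range(head_dim)]        # one query head
c_kv = [random.uniform(-1, 1) for _ in range(latent_dim)]   # one cached latent vector

# Naive path: decompress the key, then take the dot product with the query.
k = matvec(w_uk, c_kv)
score_naive = dot(q, k)

# Absorbed path: fold w_uk into the query, then attend over the latent directly.
q_latent = matvec(transpose(w_uk), q)
score_absorbed = dot(q_latent, c_kv)

print(score_naive, score_absorbed)
```

Both paths give the same score up to floating-point rounding, which is why the cache can in principle hold only the latent vectors. The same argument would apply to the value/output side, which this sketch does not cover.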
This PR implements the JIT compilation (#170) of flashinfer. After this PR, flashinfer will compile kernels just-in-time for different input data types and shapes and cache the kernels on disk, instead of pre-compiling a set of kernels in the wheel.

# Motivation

The pip wheel size is exploding as we add support for more data types, more head dimensions, more attention variants, and more kernel implementations. Pre-compiling everything is not sustainable and impedes development speed.

This PR refactors the codebase to use torch's [JIT Compiling Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions) feature instead of pre-compiling kernels in the wheel.

## Attention Variants

We learned from [FlexAttention](https://pytorch.org/blog/flexattention/) and describe every attention variant as a template class. Each instance of the struct can carry some closure variables defined in local memory or shared memory. Below are two examples (logits soft cap and ALiBi attention; the programming interface is tentative and will be updated as we improve the programmability of the JIT template):

```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx,
                                    uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits,
                                               uint32_t batch_idx, uint32_t qo_idx,
                                               uint32_t kv_idx, uint32_t qo_head_idx,
                                               uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx,
                                     uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits,
                                               uint32_t batch_idx, uint32_t qo_idx,
                                               uint32_t kv_idx, uint32_t qo_head_idx,
                                               uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};
```

Users can customize their own `ParamsT` class and variant classes to define their own attention variants. We hope this refactor will make the codebase more concise and extensible.

# Roadmap

After this PR, we will add support for:

1. PyPI wheels #153
2. fp8 tensor cores attention: #502
3. different head dimensions: #142 #454 #455
4. flashattention3 #369
5. multi-head latent attention #237
6. Generate ParamsT and Attention variants description from python dsl

The development of these features has been blocked by the wheel size limit (a binary size >= 2GB triggers some linking issues). I hope this PR will make development easier in the future.
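As a reading aid for the two variant structs in the PR description, here is a scalar Python sketch of what their `QueryTransform` and `LogitsTransform` steps compute once combined. It mirrors only the arithmetic in the quoted snippet. The function names are hypothetical, and the remark about the `log2e` factor (that the kernel presumably exponentiates in base 2) is an assumption, not something the PR text states.

```python
import math

LOG2E = math.log2(math.e)


def soft_cap_logit(qk_dot, sm_scale, cap):
    """Soft-capped logit for a raw q.k dot product.

    QueryTransform scales q by sm_scale / cap, so the dot product arrives
    already scaled. LogitsTransform then applies cap * log2e * tanh(...).
    """
    scaled = qk_dot * sm_scale / cap
    return cap * LOG2E * math.tanh(scaled)


def alibi_logit(logit, slope, qo_idx, kv_idx):
    """ALiBi bias: add a per-head slope times the signed token distance."""
    return logit + slope * float(kv_idx - qo_idx)


# Assumption: with the log2e factor folded in, exp2 of the transformed logit
# matches the natural exponential of the usual cap * tanh(x / cap) form.
x = soft_cap_logit(3.0, 0.5, 30.0)
print(2.0 ** x, math.exp(30.0 * math.tanh(3.0 * 0.5 / 30.0)))
```

The soft cap keeps every logit within `cap` in magnitude (times `log2e` in this base-2 form), while the ALiBi term penalizes keys that lie further behind the query position when the slope is positive.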
Is MLA supported now? If it is supported, could you point out how to use it?
Hi, #551 is the first step to support MLA. MLA prefill still needs some time.
Hi @liangzelang, I'm afraid you still can't use the MHA prefill kernel to support MLA prefill, even if you manually do the projection and concatenation from compressed_kv and k_pe to produce the KV data, because there is still one slight difference: RoPE is only applied to the 64-dim portion of the whole 192-dim head.
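To make the "RoPE on only 64 of 192 dims" point concrete, here is a small pure-Python sketch that rotates just one slice of a head vector and leaves the rest untouched. The layout (128 non-rotary dims followed by 64 rotary dims), the interleaved pairing of elements, and the base of 10000 are assumptions made for illustration. The real kernels may order or pair the dimensions differently.

```python
import math

NOPE_DIM = 128  # assumed: dims that receive no positional rotation
ROPE_DIM = 64   # dims that receive RoPE, per the comment above


def apply_partial_rope(vec, pos, base=10000.0):
    """Rotate only the trailing ROPE_DIM entries of a 192-dim head vector."""
    assert len(vec) == NOPE_DIM + ROPE_DIM
    out = list(vec[:NOPE_DIM])  # non-rotary part is copied unchanged
    rope_part = vec[NOPE_DIM:]
    for i in range(0, ROPE_DIM, 2):
        # Each consecutive pair (i, i + 1) is rotated by a position-dependent angle.
        theta = pos * base ** (-i / ROPE_DIM)
        c, s = math.cos(theta), math.sin(theta)
        x, y = rope_part[i], rope_part[i + 1]
        out.append(x * c - y * s)
        out.append(x * s + y * c)
    return out


v = [float(j % 7) for j in range(NOPE_DIM + ROPE_DIM)]
rotated = apply_partial_rope(v, pos=5)
print(rotated[:NOPE_DIM] == v[:NOPE_DIM])  # the non-rotary slice is untouched
```

A standard MHA prefill kernel with fused RoPE would rotate the full head dimension, which is the mismatch the comment describes.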
MLA (Multi-Head Latent Attention) was proposed in DeepSeek-v2 for efficient inference.