-
Notifications
You must be signed in to change notification settings - Fork 578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support MLA for DeepSeek-V2 with Triton - step 1 #905
Conversation
@ispobock Nice work! |
Logits diffHF:
sglang MLA:
Benchmark
main branch + triton kernel:
main branch + flashinfer kernel:
MLA (this PR):
reproduce:
Evaluationmmlu average accuracy:
gsm8k accuracy:
reproduce:
|
Hi @ispobock Very impressive throughput improvement, may you test the eval, such as MMLU and gsm8k, and also do a regression test on Llama 3 8B Instruct with/without FlashInfer? Thanks. |
Hi @ispobock You may add |
After completing the benchmark and evaluation, we might add a switch to the MLA feature, keeping FlashInfer as the default for now on DeepSeek V2. Currently, even the main branch struggles with running DeepSeek V2 on H100s due to issues with Triton's implementation. #913 We could look into adding weight fusion support in another PR. Special thanks to @ispobock's contribution and @grimoire for previously implementing the MLA version of DeepSeek V2 in LMDeploy PyTorch Engine, which has been incredibly helpful and inspiring. https://github.com/InternLM/lmdeploy/pull/1621/files The initial MLA implementation on Triton significantly outperforms MHA on Triton. We're considering incorporating MLA Attention into FlashInfer moving forward and would appreciate if @ispobock could explore this possibility. Looking forward to an update from @yzh119. Do you have any suggestions? Thanks. @merrymercy @Ying1123 @hnyls2002 @yzh119 |
Benchmark result is updated. |
A100 80G
|
After this commit 94b1578, when we want to use MLA, we should add |
A100 80G x8
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall LGTM cc @merrymercy @Ying1123 @hnyls2002 @yzh119
LGTM. Very nice work!! 🎉 |
hold on |
@ispobock We may add a regression testing with Llama 3. |
There is an issue when disabling flashinfer, let me fix it. |
ok |
Evaluated the average accuracy on mmlu by Llama-3-8B:
DeepSeek-V2-Lite:
|
TP cases # flashinfer
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-lite --disable-radix-cache --trust-remote-code --tp 2
# triton
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-lite --disable-radix-cache --trust-remote-code --tp 2 --disable-flashinfer
# mla
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-lite --disable-radix-cache --trust-remote-code --tp 2 --enable-mla python3 -m sglang.bench_serving --backend sglang
python3 benchmark/mmlu/bench_sglang.py --nsub 10
|
Conclusion:
|
Great work! I might have some bandwidth to work on flashinfer's MLA next week. |
Motivation
MLA implementation.
Modification