Quantized Mistral: Prompt processing slower than llama.cpp #153
[Screenshots: Llama.cpp; Mistral.rs]
Llama.cpp dequantizes first, then runs the matmul. We're doing the dequantization and matmul directly, in one step. This issue is useful: ggerganov/llama.cpp#3776, where they enable the current approach.
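To make the two strategies concrete, here is a minimal sketch using a toy block format (4 `i8` values plus one `f32` scale per block). This is an illustrative assumption, not GGUF's real Q4_K/Q8_0 layout, and all function names are hypothetical. Path 1 materializes a dequantized row for a dense GEMM. Path 2 consumes the quantized blocks directly, in the spirit of MMQ (real MMQ kernels also quantize the activations and use integer dot products, which is omitted here).

```rust
// Toy block quantization for illustration only: each block holds BLOCK i8
// values plus one f32 scale. Real GGUF formats (Q4_K, Q8_0, ...) differ.
const BLOCK: usize = 4;

struct QBlock {
    scale: f32,
    qs: [i8; BLOCK],
}

/// Quantize one weight row; its length is assumed to be a multiple of BLOCK.
fn quantize_row(row: &[f32]) -> Vec<QBlock> {
    row.chunks(BLOCK)
        .map(|chunk| {
            // Symmetric absmax scaling into the i8 range [-127, 127].
            let amax = chunk.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
            let scale = if amax == 0.0 { 1.0 } else { amax / 127.0 };
            let mut qs = [0i8; BLOCK];
            for (q, &x) in qs.iter_mut().zip(chunk) {
                *q = (x / scale).round() as i8;
            }
            QBlock { scale, qs }
        })
        .collect()
}

/// Path 1 (dequantize first): expand the quantized row back to f32 so a
/// regular dense GEMM can consume it.
fn dequantize_row(blocks: &[QBlock]) -> Vec<f32> {
    blocks
        .iter()
        .flat_map(|b| b.qs.iter().map(move |&q| q as f32 * b.scale))
        .collect()
}

/// Plain dense dot product, used after dequantization.
fn dense_dot(w: &[f32], x: &[f32]) -> f32 {
    w.iter().zip(x).map(|(a, b)| a * b).sum()
}

/// Path 2 (fused): read the quantized blocks directly and apply each block's
/// scale once, without materializing a dequantized row.
fn fused_dot(blocks: &[QBlock], x: &[f32]) -> f32 {
    blocks
        .iter()
        .zip(x.chunks(BLOCK))
        .map(|(b, xs)| {
            let partial: f32 = b.qs.iter().zip(xs).map(|(&q, &v)| q as f32 * v).sum();
            partial * b.scale
        })
        .sum()
}
```

Both paths compute the same value up to rounding. The difference the thread is chasing is purely about which GPU kernels end up faster for a given batch size.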
@lucasavila00, do you think we should also dequantize to F16 for large batch sizes? To my understanding, this is beneficial because the BLAS implementation of the matrix-matrix product becomes faster than our MMQ kernel as the batch size increases.
@EricLBuehler I'd like to test it... I tried running the candle example with a candle version from before they added the MMQ kernels, and performance was roughly the same. I also tried manually dequantizing the QMatMuls of the attention layer and saw no improvement. If you have a different approach, I'd be glad to test it.
https://github.com/huggingface/candle-cublaslt I think we need to dequantize and use these cuBLASLt kernels? I'll try it.
@lucasavila00, that sounds great. Please let me know the results!
@EricLBuehler candle already uses cublaslt, see:
[Screenshots: MR #230 forcing dequantization then matmul; master]
@lucasavila00, that is very interesting. How did you force the dequantization?
With the lt_mul function of the MR (https://github.com/EricLBuehler/mistral.rs/pull/230/files#diff-da1e6f56f0e565985ccaa246f41d45f33271525bb3ae0d3a776cb282ce797676R27). I forced it for the attention weights and the MLP only.
@lucasavila00, does
@EricLBuehler when I run llama.cpp and mistral.rs in interactive mode, I get close results: https://gist.github.com/lucasavila00/0155f94fbf13e988384af53af8841b0f
So I guess our
Ah, never mind the above. Llama.cpp was sampling at 700 tok/s on the CPU because I forgot the ngl param: https://gist.github.com/lucasavila00/646b6f6cb9757d1329dc7296b5f16e3e
So llama.cpp is indeed 3x faster, and both benchmarks measure correctly.
@lucasavila00, I wonder if it is the
@EricLBuehler that seems to be the case. I can't find where the Turing kernels come from, though. I assume they are from an NVIDIA library, but I can't figure out why llama.cpp uses a different version from candle/cudarc 🤔
The version differs depending on heuristics. Using this for matmuls, I can trigger the Turing kernels, but it spends too long on the f32->f16 conversions 🤔
Llama.cpp can dequantize directly to f16; candle cannot. Maybe it's worth raising an issue for direct f16 dequantization?
@lucasavila00, I have raised an issue.
PR #238 has the latest iteration of the code. It uses dequant+matmul only for prompts, and does the matmul in f16. It also has comparisons of runs between
I think the remaining difference is now due to different kernels? Even though the kernel names are almost the same, the ones used by candle seem to be slower. I'm trying to figure out why they don't use the exact same kernels. The distribution of kernels between llama.cpp and mistral.rs is almost the same, and the overall time difference matches the discrepancy between those two kernels.
If I am not mistaken, our completion performance should also improve by 60% (like prompt performance) because of the new F16 dequant support?
For batch sizes > 8, yes. For batch sizes <= 8, I think we'll want to keep using MMQ (that's what llama.cpp does). The cuBLAS MR still has these as TODOs, though: https://github.com/EricLBuehler/mistral.rs/pull/238/files#diff-da1e6f56f0e565985ccaa246f41d45f33271525bb3ae0d3a776cb282ce797676R20-R22
Ah, ok. I'm interested in how our performance compares to llama.cpp in that situation.
That MR currently uses cuBLAS for prompts and MMQ for completion. It should be something like cuBLAS for prompts if seq_len > 32, otherwise MMQ. Those are the llama.cpp heuristics, if I understood them correctly.
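The dispatch rule described above can be sketched as a single threshold check. The enum and function names are hypothetical, and the threshold is left as a parameter because the thread mentions both 8 (batch size) and 32 (seq_len) as cut-offs; this is not the actual mistral.rs or llama.cpp code.

```rust
/// Hypothetical names for the two kernel paths discussed in this thread.
#[derive(Debug, PartialEq)]
enum QMatMulPath {
    /// Fused quantized matmul (MMQ): suited to decoding and small batches.
    Mmq,
    /// Dequantize the weights to f16, then run a dense cuBLAS GEMM.
    DequantF16Gemm,
}

/// `rows` is the number of activation rows hitting the layer (for a prompt,
/// roughly the number of tokens). `threshold` is a tunable assumption.
fn choose_qmatmul_path(rows: usize, threshold: usize) -> QMatMulPath {
    if rows > threshold {
        QMatMulPath::DequantF16Gemm
    } else {
        QMatMulPath::Mmq
    }
}
```

With a threshold of 32, single-token decoding steps stay on MMQ, while a 512-token prompt takes the dequantize-to-f16 + GEMM path.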
Ah, I'm not even benchmarking prompts with batch sizes > 1, because I'm assuming we'll move forward with #234.
Yes, I just need to finish the testing and then I'll merge #234. I am looking forward to Candle adding support for calling hgemm, but if that takes a while I can add it.
I think we're not measuring exactly the same timings as llama.cpp. Prompt timings include a memory transfer and the sampling. After huggingface/candle#2139 (comment), if I look at just the NVIDIA profile of a warmed-up run, llama.cpp takes ~350ms and mistral.rs takes ~400ms. That puts llama.cpp at ~1500 t/s and mistral.rs at ~1300 t/s.
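The throughput figures follow from tokens divided by elapsed time. The prompt length is not stated here; the sketch below assumes a 512-token prompt, which is consistent with the quoted numbers but is an assumption.

```rust
/// Prompt-processing throughput (tokens per second) from a measured
/// duration in milliseconds.
fn tokens_per_second(prompt_tokens: usize, elapsed_ms: f64) -> f64 {
    prompt_tokens as f64 / (elapsed_ms / 1000.0)
}
```

For an assumed 512-token prompt, `tokens_per_second(512, 350.0)` is about 1463 t/s and `tokens_per_second(512, 400.0)` is 1280 t/s, in line with the rounded ~1500 and ~1300 figures.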
@lucasavila00 yes, that is possible. Are they timing the memory transfer and sampling?
No, they're just synchronizing. I wonder why mistral.rs has these 35ms of DtoH transfer. It happens only at prompt time, so it can't be the logits transfer to the CPU...
Maybe it is our cloning in and out of the cache? I don't think that incurs any DtoH, though. Can you disable the prefix cacher to make sure it isn't doing anything?
I'm using mistralrs-bench, which passes the config to disable it. Maybe it is still doing something?
Ah ok. No, all functions are essentially gated by
The only major DtoH I can think of is during sampling...
I'm not counting the DtoH, so we're at 380ms regardless of it.
But why does it last 30ms just for the prompt and not for completion? 🤔
I think the rest of the time difference is due to slow attention mask application. I'm trying to gather evidence and look into improving it here or upstream in candle.
Ah, the HtoD copy when making the attention mask may be to blame. Perhaps we could pre-generate a bunch (up to 512 tokens) and cache them?
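A minimal sketch of the caching idea, assuming a plain CPU `Vec<f32>` stands in for the device tensor (in mistral.rs the cached mask would live on the GPU, which is what would avoid the repeated HtoD copy). `MaskCache` and its methods are hypothetical. Masks are built lazily on first use rather than all 512 up front, since eagerly storing every n×n mask up to 512 would cost a lot of memory.

```rust
use std::collections::HashMap;
use std::rc::Rc;

/// Row-major n x n causal mask: 0.0 where the key position is <= the query
/// position, -inf where it lies in the future.
fn causal_mask(n: usize) -> Vec<f32> {
    let mut m = vec![0.0f32; n * n];
    for i in 0..n {
        for j in (i + 1)..n {
            m[i * n + j] = f32::NEG_INFINITY;
        }
    }
    m
}

/// Hypothetical cache: each mask is built on first use and then shared, so
/// repeated prompts of the same length skip construction.
struct MaskCache {
    max_len: usize,
    masks: HashMap<usize, Rc<Vec<f32>>>,
}

impl MaskCache {
    fn new(max_len: usize) -> Self {
        MaskCache { max_len, masks: HashMap::new() }
    }

    /// Returns None for lengths beyond `max_len` (e.g. 512); the caller
    /// would build those masks on the fly instead.
    fn get(&mut self, n: usize) -> Option<Rc<Vec<f32>>> {
        if n > self.max_len {
            return None;
        }
        let mask = self.masks.entry(n).or_insert_with(|| Rc::new(causal_mask(n)));
        Some(Rc::clone(mask))
    }
}
```

Repeated calls with the same length return the same shared allocation, so the construction cost is paid once per distinct prompt length.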
Doesn't look like it. It looks like it's a slow kernel in It comes from this part of the code (specifically,
In the picture, I highlighted the second attention layer: there's no HtoD at the bottom like in the first layer, yet it's still the slowest kernel of the attention mechanism.
It could also be that llama.cpp does not use an attention mask in the benchmark; I can't find timings for an attention mask in its profile. Since we spend ~10% of the time on the attention mask, that is precisely the 35ms out of 385ms that differs from llama.cpp.
Ok. Do you think we can find a way to disable this elegantly?
I disabled the attention mask for mistralrs-bench here: 97c0324. It uses the builders, so it did not require a public API change. I also re-ran the profiles with my GPU at the same temperatures, and with the latest commit I see:
From the profiles, I'm pretty sure llama.cpp is indeed not doing masking. There's no big difference now, but I can see improvements we could make:
Surprisingly, the affine division above is one of the most expensive operations of the attention mechanism. I wonder if candle can optimize that... Also, it looks like llama.cpp copies data and converts dtypes faster than candle.
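Assuming the "affine division" is the 1/sqrt(head_dim) scaling of the attention scores (the thread does not say so explicitly), one option besides fusing it into another kernel is to fold the scale into Q before the matmul. This is a swapped-in technique for illustration, not what mistral.rs or candle does. It touches seq × d values instead of seq × seq, which is fewer whenever the prompt is longer than the head dimension.

```rust
/// Naive row-major matmul: (m x k) * (k x n) -> (m x n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    out
}

/// Unfused: compute Q K^T, then divide every score by sqrt(d).
/// The extra elementwise pass touches seq * seq values.
fn scores_then_scale(q: &[f32], k_t: &[f32], seq: usize, d: usize) -> Vec<f32> {
    let s = (d as f32).sqrt();
    matmul(q, k_t, seq, d, seq).into_iter().map(|x| x / s).collect()
}

/// Folded: multiply Q by 1/sqrt(d) first (seq * d values), then run the
/// matmul. Same result up to floating-point rounding.
fn scale_then_scores(q: &[f32], k_t: &[f32], seq: usize, d: usize) -> Vec<f32> {
    let inv = 1.0 / (d as f32).sqrt();
    let q_scaled: Vec<f32> = q.iter().map(|x| x * inv).collect();
    matmul(&q_scaled, k_t, seq, d, seq)
}
```

On a GPU the bigger win would come from removing the separate kernel launch altogether (for example, applying the scale inside the GEMM epilogue or the softmax), but the algebra is the same.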
I think we can close this now, after the merge of #238, but please feel free to reopen. Thank you for your help, I really appreciate it. I will be looking into fusing the affine division.
Since generation speed almost matches llama.cpp after #152, I think it's worth trying to optimize prompt processing now.