llama : understand why GPU results are different for different batch sizes #3014
Comments
That's a bit of a stab in the dark, but from what I can see, the Makefile invokes nvcc with `--use_fast_math`. Does removing the fast-math flag and adding `-fmad=false` to nvcc help?
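For illustration, here is a minimal standalone sketch (plain C++, not llama.cpp code) of the rounding difference that FMA contraction, which `-fmad` controls in nvcc, introduces: `a*b + c` is rounded once when fused, but twice when the product is computed separately. The constants are chosen (hypothetically) so the two roundings visibly disagree.

```cpp
// Minimal sketch of FMA vs separate multiply-add; the values are chosen so
// the intermediate product rounds away information that the FMA preserves.
// Compile with e.g. `g++ -std=c++17 -O2 -ffp-contract=off fma_demo.cpp` so
// the compiler does not silently contract the "separate" version into an FMA.
#include <cmath>
#include <cstdio>

int main() {
    const float a = 1.0f + 0x1.0p-12f;     // 1 + 2^-12
    const float b = a;
    const float c = -(1.0f + 0x1.0p-11f);  // -(1 + 2^-11)

    const float p        = a * b;              // exact: 1 + 2^-11 + 2^-24; rounds to 1 + 2^-11
    const float separate = p + c;              // == 0.0f after the double rounding
    const float fused    = std::fmaf(a, b, c); // single rounding: == 2^-24

    printf("separate = %g\n", separate);  // 0
    printf("fused    = %g\n", fused);     // ~5.96046e-08
    return 0;
}
```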
I can confirm that the results for the CPU are identical. Compiling with …
If I run KQ and KQV on the CPU and use `mul_mat_q` for all other matrix multiplications, I get identical results. So my interpretation of these findings is that cuBLAS produces slightly different results depending on shape, and on master cuBLAS is always used for the f16 KV cache, whose shape depends on the batch size. After #2969 the results should be identical (unless some other tensor also depends on the batch size).
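To make the shape dependence concrete, here is a rough sketch (simplified, per attention head, with assumed dimensions) of how the KQ matmul shapes change with the batch size: each batch size exercises a different sequence of GEMM shapes, and a shape-dependent backend such as cuBLAS may pick a different kernel, and hence a different summation order, for each.

```cpp
// Rough sketch (not llama.cpp code): the KQ product for a batch multiplies
// K (n_past + n_batch rows) against Q (n_batch columns), so the GEMM shapes
// seen by the backend depend on how the prompt is split into batches.
#include <cstdio>

int main() {
    const int n_tokens = 512;  // total prompt length (assumed for illustration)

    for (int n_batch : {512, 256, 128}) {
        printf("n_batch = %d:\n", n_batch);
        for (int n_past = 0; n_past < n_tokens; n_past += n_batch) {
            // per-head KQ result shape for this step
            printf("  KQ: [%d x %d]\n", n_past + n_batch, n_batch);
        }
    }
    return 0;
}
```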
That's good to know. Overall, I don't think we can make the computation 100% identical across batch sizes, due to numerical rounding effects in various places of the computation, but we at least have to know where the sources of these variations are. It looks like all of them are in the attention layer.
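The non-associativity point can be demonstrated in isolation. Below is a toy C++ illustration (not the actual cuBLAS kernels): the same array is summed with different partition sizes, as a shape-dependent kernel might do, and the float totals differ in the last digits.

```cpp
// Toy demonstration that float addition is not associative: the same values
// accumulated with different chunk sizes give slightly different totals.
#include <cstdio>
#include <vector>
#include <random>

// Sum `v` in chunks of `chunk` elements, then sum the partial results,
// mimicking a kernel that partitions a reduction differently per shape.
static float chunked_sum(const std::vector<float> & v, size_t chunk) {
    float total = 0.0f;
    for (size_t i = 0; i < v.size(); i += chunk) {
        float partial = 0.0f;
        for (size_t j = i; j < i + chunk && j < v.size(); ++j) {
            partial += v[j];
        }
        total += partial;
    }
    return total;
}

int main() {
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

    std::vector<float> v(4096);
    for (auto & x : v) x = dist(rng);

    printf("chunk    1: %.9g\n", chunked_sum(v, 1));
    printf("chunk   32: %.9g\n", chunked_sum(v, 32));
    printf("chunk  512: %.9g\n", chunked_sum(v, 512));
    return 0;
}
```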
Another source of variation, I think, is the … Will leave this issue open for a while in case we get some more ideas, but I think I have a better understanding now, and this does not look like a serious problem.
With my current version of #2969, I am now getting identical results regardless of batch size:
Probably a related answer to a similar question: huggingface/transformers#25420 (comment)
I did the following experiment:
Run `perplexity` with the same input, but changing the batch size via the `-b` parameter. Here are the results for the first few iterations on different backends:
The CPU results are invariant to the batch size, which is expected.
However, there are some differences when running on the GPU, more pronounced with CUDA than with Metal.
We should try to understand the root cause of this behavior.
Some more discussion in: #3006 (comment)