
llama : understand why GPU results are different for different batch sizes #3014

Closed
ggerganov opened this issue Sep 4, 2023 · 6 comments
Labels
question Further information is requested

Comments

@ggerganov
Owner

I did the following experiment:

Run perplexity with the same input, but changing the batch size via the -b parameter.
Here are the results for the first few iterations on different backends:

# Q4_0 7B
# batch sizes: 16, 32, 64, 128, 256, 512

# CPU (M2, LLAMA_ACCELERATE=OFF):

[1]4.3233,[2]4.8256,[3]5.4456,[4]6.0456,[5]6.1772,[6]6.0762  # SIMD is off for n_batch = 16 (ggml_vec_dot_f16)
[1]4.3214,[2]4.8286,[3]5.4463,[4]6.0497,[5]6.1802,[6]6.0800
[1]4.3214,[2]4.8286,[3]5.4463,[4]6.0497,[5]6.1802,[6]6.0800
[1]4.3214,[2]4.8286,[3]5.4463,[4]6.0497,[5]6.1802,[6]6.0800
[1]4.3214,[2]4.8286,[3]5.4463,[4]6.0497,[5]6.1802,[6]6.0800
[1]4.3214,[2]4.8286,[3]5.4463,[4]6.0497,[5]6.1802,[6]6.0800

# Metal:

[1]4.3263,[2]4.8290,[3]5.4475,[4]6.0514,[5]6.1813,[6]6.0808,[7]6.2560,[8]6.3670,[9]6.7256,[10]6.9356
[1]4.3263,[2]4.8291,[3]5.4476,[4]6.0515,[5]6.1814,[6]6.0809,[7]6.2560,[8]6.3670,[9]6.7256,[10]6.9356
[1]4.3261,[2]4.8290,[3]5.4475,[4]6.0514,[5]6.1813,[6]6.0808,[7]6.2560,[8]6.3669,[9]6.7256,[10]6.9356
[1]4.3263,[2]4.8291,[3]5.4476,[4]6.0515,[5]6.1814,[6]6.0809,[7]6.2561,[8]6.3670,[9]6.7256,[10]6.9356
[1]4.3263,[2]4.8290,[3]5.4476,[4]6.0515,[5]6.1814,[6]6.0809,[7]6.2560,[8]6.3670,[9]6.7256,[10]6.9356
[1]4.3264,[2]4.8291,[3]5.4476,[4]6.0515,[5]6.1814,[6]6.0809,[7]6.2561,[8]6.3670,[9]6.7256,[10]6.9356

# CUDA:

[1]4.3283,[2]4.8268,[3]5.4451,[4]6.0526,[5]6.1871,[6]6.0874,[7]6.2609,[8]6.3685,[9]6.7238
[1]4.3329,[2]4.8348,[3]5.4534,[4]6.0545,[5]6.1855,[6]6.0867,[7]6.2617,[8]6.3744,[9]6.7305
[1]4.3303,[2]4.8109,[3]5.4355,[4]6.0431,[5]6.1755,[6]6.0727,[7]6.2414,[8]6.3526,[9]6.7111
[1]4.3264,[2]4.8292,[3]5.4521,[4]6.0559,[5]6.1865,[6]6.0894,[7]6.2580,[8]6.3652,[9]6.7194
[1]4.3666,[2]4.8513,[3]5.4581,[4]6.0586,[5]6.1911,[6]6.0899,[7]6.2577,[8]6.3674,[9]6.7188
[1]4.3307,[2]4.8364,[3]5.4609,[4]6.0671,[5]6.1965,[6]6.0940,[7]6.2651,[8]6.3749,[9]6.7282

The CPU results are invariant to the batch size, which is expected.
However, there are some differences when running on the GPU, more pronounced with CUDA than with Metal.

We should try to understand what the root cause of this behavior is.
Some more discussion in: #3006 (comment)

@ggerganov ggerganov added help wanted Extra attention is needed high priority Very important issue labels Sep 4, 2023
@ggerganov ggerganov moved this to Todo in ggml : roadmap Sep 4, 2023
@staviq
Contributor

staviq commented Sep 4, 2023

That's a bit of a stab in the dark, but from what I can see the Makefile invokes nvcc with --use_fast_math.

Does removing the fast-math flag and adding -fmad=false to nvcc help?
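
For context on why that flag can matter: --use_fast_math implies fmad=true, i.e. nvcc may contract a*b + c into a single fused multiply-add, which rounds once instead of twice. A minimal standalone C++ sketch of the effect (illustration only, not llama.cpp code):

// Illustration only: with FMA contraction, a*b + c is rounded once; without it,
// the product and the sum are rounded separately, so results can differ in the
// last bits.
#include <cmath>
#include <cstdio>

int main() {
    const float a = 1.0f + 1e-7f;
    const float b = 1.0f - 1e-7f;
    const float c = -1.0f;

    volatile float prod = a * b;               // volatile blocks contraction here
    const float unfused = prod + c;            // product and sum rounded separately
    const float fused   = std::fmaf(a, b, c);  // one rounding of the exact a*b + c

    std::printf("unfused = %.9g\n", (double) unfused); // prints 0
    std::printf("fused   = %.9g\n", (double) fused);   // prints about -1.42e-14
    return 0;
}

Whether this particular contraction is what causes the batch-size differences is a separate question; the point is only that fast math changes rounding, so it is worth ruling out.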

@JohannesGaessler
Collaborator

I can confirm that the results on the CPU are identical. Compiling with LLAMA_CUBLAS and running perplexity with 0 GPU layers still changes the results, so the matrix multiplications must be what changes them. The results change with both mul_mat_q and cuBLAS as the matrix multiplication kernels. I patched ggml_cuda_can_mul_mat to not run the matrix multiplications for KQ and KQV on the GPU, and the results still changed. At that point the only CUDA kernels invoked were the dequantization kernels for q4_0 and q6_K as well as the cuBLAS kernels.

@JohannesGaessler
Collaborator

If I run KQ and KQV on the CPU and use mul_mat_q for all other matrix multiplications, I get identical results. So my interpretation of these findings is that cuBLAS produces slightly different results depending on shape, and on master cuBLAS is always used for the f16 KV cache, whose shape depends on the batch size. After #2969 the results should be identical (unless some other tensor also depends on the batch size).
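
To make the "different results depending on shape" point concrete: a GEMM that tiles or splits its reductions differently for different problem shapes adds up the same products in a different order, and float addition is not associative. A small standalone C++ illustration of the effect (not cuBLAS code):

// Illustration only: the same values accumulated serially vs. in two partial
// sums (as a differently tiled kernel might do) typically differ in the last bits.
#include <cstdio>
#include <random>
#include <vector>

int main() {
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
    std::vector<float> x(4096);
    for (float & v : x) v = dist(rng);

    float serial = 0.0f;                        // one running accumulator
    for (float v : x) serial += v;

    float lo = 0.0f, hi = 0.0f;                 // two partial sums, then combined
    for (size_t i = 0; i < x.size()/2; ++i) lo += x[i];
    for (size_t i = x.size()/2; i < x.size(); ++i) hi += x[i];
    const float split = lo + hi;

    std::printf("serial = %.9g\nsplit  = %.9g\ndiff = %.3g\n",
                (double) serial, (double) split, (double) (serial - split));
    return 0;
}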

@ggerganov
Owner Author

> So my interpretation of these findings is that cuBLAS produces slightly different results depending on shape

That's good to know. Overall, I don't think we can make the computation 100% identical across batch sizes due to numerical rounding effects in various places of the computation, but we at least have to know where the sources of these variations are. It looks like all of them are in the attention layer.

> unless some other tensor also depends on batch size

Another source of variation, I think, is KQ_soft_max. There is a sum over a number of elements that depends on n_batch. On the CPU we accumulate into double, and this probably helps to reduce the variability.
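
A standalone sketch of that accumulator-width point (illustration only, not the ggml implementation): normalizing a softmax row with a float sum versus a double sum can change the result in the last bits, and the float rounding error depends on how many terms are summed, i.e. on the row length.

// Illustration only: softmax denominator accumulated in float vs. double.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n = 512;                          // stand-in for a row length that varies with n_batch
    std::vector<float> logits(n);
    for (int i = 0; i < n; ++i) logits[i] = std::sin(0.1f * i);  // arbitrary values

    float  sum_f = 0.0f;
    double sum_d = 0.0;
    for (int i = 0; i < n; ++i) {
        const float e = std::exp(logits[i]);
        sum_f += e;                             // float accumulator
        sum_d += e;                             // double accumulator
    }

    const float p_f = std::exp(logits[0]) / sum_f;
    const float p_d = (float) (std::exp(logits[0]) / sum_d);
    std::printf("p[0] with float sum:  %.9g\n", (double) p_f);  // the two results typically
    std::printf("p[0] with double sum: %.9g\n", (double) p_d);  // differ in the last bits
    return 0;
}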

Will leave this issue open for a while in case we get some more ideas, but I think I have a better understanding now, and this does not look like a serious problem.

@ggerganov ggerganov added question Further information is requested and removed help wanted Extra attention is needed high priority Very important issue labels Sep 5, 2023
@JohannesGaessler
Collaborator

With my current version of #2969 I am now getting identical results regardless of batch size:

perplexity: calculating perplexity over 10 chunks, batch_size=512
perplexity: 0.25 seconds per pass - ETA 0.03 minutes
[1]4.2963,[2]4.7982,[3]5.4232,[4]6.0360,[5]6.1716,[6]6.0714,[7]6.2438,[8]6.3542,[9]6.7121,[10]6.9221,
Final estimate: PPL = 6.9220988942 +/- 0.3210281027

perplexity: calculating perplexity over 10 chunks, batch_size=256
perplexity: 0.25 seconds per pass - ETA 0.03 minutes
[1]4.2963,[2]4.7982,[3]5.4232,[4]6.0360,[5]6.1716,[6]6.0714,[7]6.2438,[8]6.3542,[9]6.7121,[10]6.9221,
Final estimate: PPL = 6.9220988942 +/- 0.3210281027

perplexity: calculating perplexity over 10 chunks, batch_size=128
perplexity: 0.28 seconds per pass - ETA 0.03 minutes
[1]4.2963,[2]4.7982,[3]5.4232,[4]6.0360,[5]6.1716,[6]6.0714,[7]6.2438,[8]6.3542,[9]6.7121,[10]6.9221,
Final estimate: PPL = 6.9220988942 +/- 0.3210281027

perplexity: calculating perplexity over 10 chunks, batch_size=64
perplexity: 0.44 seconds per pass - ETA 0.07 minutes
[1]4.2963,[2]4.7982,[3]5.4232,[4]6.0360,[5]6.1716,[6]6.0714,[7]6.2438,[8]6.3542,[9]6.7121,[10]6.9221,
Final estimate: PPL = 6.9220988942 +/- 0.3210281027

I think KQ_soft_max does not matter due to the triangular masking: with the causal mask, row i of KQ only sums over the first i positions, so the number of terms in the softmax sum depends on the token position rather than on the batch size.

@gante

gante commented Nov 20, 2023

Probably related answer to a similar question: huggingface/transformers#25420 (comment)
