
cuda: 1.2x faster dequantization kernel #2809

Open
wants to merge 4 commits into base: master

Conversation

li-plus
Contributor

@li-plus li-plus commented Aug 26, 2023

Optimize CUDA dequantization with memory coalescing, achieving a 1.2x speedup. For now, I have only implemented the faster kernel for q4_0. If this PR gets accepted, I'll implement the rest.
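
To illustrate the idea (a rough sketch only, not the PR's actual kernel), assuming the standard ggml block_q4_0 layout of an fp16 scale followed by 16 packed bytes: each thread handles one packed byte, so consecutive threads in a warp read consecutive bytes and the global loads coalesce.

    #include <cuda_fp16.h>

    #define QK4_0 32

    typedef struct {
        half    d;              // fp16 scale
        uint8_t qs[QK4_0 / 2];  // 16 bytes, two 4-bit quants per byte
    } block_q4_0;

    // one thread per packed byte: thread i reads byte i, so loads coalesce within a warp
    static __global__ void dequantize_q4_0_coalesced(const block_q4_0 * x, float * y, int nb) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= nb * (QK4_0 / 2)) {
            return;
        }
        const int ib = i / (QK4_0 / 2);  // which block
        const int iq = i % (QK4_0 / 2);  // which byte within the block

        const float   d = __half2float(x[ib].d);
        const uint8_t q = x[ib].qs[iq];

        // low nibble -> element iq, high nibble -> element iq + 16
        y[ib*QK4_0 + iq]            = ((int)(q & 0xF) - 8) * d;
        y[ib*QK4_0 + iq + QK4_0/2]  = ((int)(q >>  4) - 8) * d;
    }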

I used nvprof to profile the kernels on a V100-SXM2 GPU with 900 GB/s memory bandwidth.

nvprof --print-gpu-trace ./main -m ./models/7B/ggml-model-q4_0.gguf -p "Hello" -n 8 -ngl 32 -nommq

Before this PR, dequantizing an 11008x4096 q4_0 weight matrix takes 332.10us:

...
2.34055s  332.10us         (176128 1 1)       (256 1 1)        16        0B        0B         -           -           -           -  Tesla V100-SXM2         1        13  void dequantize_block<int=32, int=2, __operator_&__(_INTERNAL_3560778b_12_ggml_cuda_cu_2317e5a6::dequantize_q4_0(void const *, int, int, float2&))>(void const *, float*, int) [7307]
2.34088s  238.40us           (1376 1 1)       (128 1 1)        40  3.0000KB        0B         -           -           -           -  Tesla V100-SXM2         1        13  void gemmSN_TN_kernel<float, int=128, int=16, int=2, int=4, int=2, int=2, bool=1, cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<float const , cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float const >, float>) [7311]
...

With this PR, the same operation takes only 277.18us (a 1.2x speedup).

...
5.48006s  277.18us          (44032 1 1)       (256 1 1)        16        0B        0B         -           -           -           -  Tesla V100-SXM2         1        13  dequantize_block_q4_0(void const *, float*, int) [7307]
5.48034s  239.17us           (1376 1 1)       (128 1 1)        40  3.0000KB        0B         -           -           -           -  Tesla V100-SXM2         1        13  void gemmSN_TN_kernel<float, int=128, int=16, int=2, int=4, int=2, int=2, bool=1, cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<float const , cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float const >, float>) [7311]
...

The utilized memory bandwidth reaches (11008x4096x(0.5+4) B) / 277.18us = 682 GB/s (76% of the 900 GB/s peak), compared to the previous 569 GB/s (63% of peak).

@YellowRoseCx
Contributor

Have you tested to see if this breaks CUDA compatibility for AMD cards using the recently merged ROCm pull request?

@KerfuffleV2
Collaborator

Have you tested to see if this breaks CUDA compatibility for AMD cards using the recently merged ROCm pull request?

I tried it out:

ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon RX 6600, compute capability 10.3

Seems fine, identical results with a specific seed. Tested on a Q4_0 7B LLaMA1 model.

No difference in speed that I can see, which isn't too surprising since like 90% of time is spent in matrix multiplication.

ggml-cuda.cu Outdated

dfloat2 dv0;
dv0.x = (int)(qs.x & 0xf) - 8;
dv0.y = (int)(qs.y & 0xf) - 8;
Contributor

@Engininja2 Engininja2 Aug 26, 2023


HIP/ROCm treats the x and y variables of a half2 as shorts, so I think this would work better, and then the same change for dv1 just below this.

    #ifdef GGML_CUDA_F16
    // dfloat2 is half2 here: construct it with __halves2half2 instead of assigning .x/.y
    dv0 = __halves2half2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
    #else
    // dfloat2 is float2 here: direct member assignment works on both CUDA and HIP
    dv0.x = (int)(qs.x & 0xf) - 8;
    dv0.y = (int)(qs.y & 0xf) - 8;
    #endif

edit: replaced make_half2 with __halves2half2, which has been part of the CUDA API for longer

Contributor Author


Fixed. Introduced a make_dfloat2 macro to create the proper dfloat2 (half2 or float2).
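
For reference, a minimal sketch of what such a macro could look like (hypothetical; the exact definition in the PR may differ):

    #ifdef GGML_CUDA_F16
    typedef half2  dfloat2;                          // half-precision path
    #define make_dfloat2(x, y) __halves2half2(x, y)
    #else
    typedef float2 dfloat2;                          // single-precision path
    #define make_dfloat2(x, y) make_float2(x, y)
    #endif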

Collaborator

@JohannesGaessler JohannesGaessler left a comment


What is the motivation behind this change? In which situation does the dequantization performance actually make a difference?

@li-plus
Contributor Author

li-plus commented Aug 27, 2023

What is the motivation behind this change? In which situation does the dequantization performance actually make a difference?

@JohannesGaessler I'm optimizing chatglm.cpp and found that dequantization kernels cost ~50% of the total time of context computation. I have only tested short prompts, though. Here is the nsys profile.

[screenshot: nsys profile]

@JohannesGaessler
Collaborator

I would suggest you try a longer prompt, the llama-bench binary, or the perplexity binary. For long prompts I would be very surprised if it made a noticeable difference.

@JohannesGaessler
Collaborator

When I tested it using the llama-bench binary, 13.9% of the kernel time was spent in dequantize_block on master. But this is only with -mmq 0 in conjunction with cuBLAS. The new mul_mat_q kernels, which are both faster and use less VRAM, do not invoke the corresponding dequantization kernel in the first place.

@jammm
Contributor

jammm commented Oct 8, 2023

I would suggest you try a longer prompt, the llama-bench binary, or the perplexity binary. For long prompts I would be very surprised if it made a noticeable difference.

Noob question: what is the prompt length and content of the llama-bench program? Shouldn't the performance depend on what the prompt is asking for? E.g., "write a 50 page script" would be more tg-heavy than "read this <50 page script> and summarize", right?

@slaren
Member

slaren commented Oct 8, 2023

llama-bench only measures the time to process a prompt, not the response. The length of the prompt can be configured with the -p parameter. The contents of the prompt are just bos tokens, but the performance should be the same regardless of the contents.
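
For example, a run along these lines measures prompt processing at two lengths (using only flags that already appear in this thread; adjust the model path to your setup):

./llama-bench -m ./models/7B/ggml-model-q4_0.gguf -p 128,512 -n 0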

@jammm
Contributor

jammm commented Oct 9, 2023

llama-bench only measures the time to process a prompt, not the response. The length of the prompt can be configured with the -p parameter. The contents of the prompt are just bos tokens, but the performance should be the same regardless of the contents.

I see. Thanks! So the tg test simply calculates the number of bos tokens based on the default number of tokens to be generated? Does this imply the same speed when running main with the same prompt length?

@slaren
Member

slaren commented Oct 9, 2023

The tg test simulates text generation by generating a sequence of the length given with -n, one token at a time. The token fed to the model is always bos, as in the pp tests. Essentially, it would be the same as if you ran main without a prompt and banned every token except bos, so that the only token that could be sampled is bos. The pp test should be equivalent to evaluating a prompt of the same length with main.

@jammm
Contributor

jammm commented Oct 9, 2023

The tg test simulates text generation by generating a sequence of the length given with -n, one token at a time. The token fed to the model is always bos, as in the pp tests. Essentially, it would be the same as if you ran main without a prompt and banned every token except bos, so that the only token that could be sampled is bos. The pp test should be equivalent to evaluating a prompt of the same length with main.

Thanks! That makes it a lot clearer now.

@ggerganov
Member

With the recent performance improvement of -mmq 0 (#3412), this change looks more relevant.
Should we merge this and put some effort into optimizing the rest of the dequantization kernels?

@slaren
Member

slaren commented Oct 10, 2023

I don't think it is worth putting any effort into this; we need to implement matrix multiplication kernels that can use tensor cores, ideally with integer operations.

@ggerganov
Member

OK. For future reference, adding a data point for low-batch decoding on V100 (#3479)

make -j && ./bin/llama-bench -m /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-f16.gguf -mmq 0 -n 0 -t 1 -p 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512
| model | backend | ngl | threads | mmq | test | master t/s | PR t/s | speedup |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 1 | 45.40 ± 8.25 | 45.40 ± 8.11 | 1.000 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 2 | 24.78 ± 0.12 | 32.29 ± 0.18 | 1.303 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 3 | 36.65 ± 0.37 | 47.83 ± 0.25 | 1.305 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 4 | 49.36 ± 0.07 | 62.73 ± 1.64 | 1.271 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 5 | 60.40 ± 0.07 | 77.27 ± 0.83 | 1.279 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 6 | 72.84 ± 0.08 | 93.77 ± 0.18 | 1.287 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 7 | 82.90 ± 0.14 | 106.24 ± 0.11 | 1.282 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 8 | 96.52 ± 0.57 | 122.71 ± 3.46 | 1.271 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 9 | 106.96 ± 0.13 | 136.79 ± 0.09 | 1.279 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 10 | 119.88 ± 0.04 | 153.53 ± 0.65 | 1.281 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 11 | 129.90 ± 0.19 | 166.20 ± 0.25 | 1.279 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 12 | 143.29 ± 0.56 | 183.78 ± 0.24 | 1.283 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 13 | 151.01 ± 0.29 | 191.97 ± 0.17 | 1.271 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 14 | 165.57 ± 0.04 | 209.25 ± 3.43 | 1.264 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 15 | 175.41 ± 0.16 | 222.84 ± 0.66 | 1.270 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 16 | 190.85 ± 0.11 | 244.04 ± 0.24 | 1.279 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 32 | 337.25 ± 6.56 | 417.63 ± 10.76 | 1.238 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 64 | 575.15 ± 0.36 | 686.56 ± 0.43 | 1.194 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 128 | 1096.75 ± 38.20 | 1306.71 ± 31.41 | 1.191 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 256 | 1803.38 ± 1.90 | 2064.30 ± 6.43 | 1.145 |
| llama 7B Q4_0 | CUDA | 99 | 1 | 0 | pp 512 | 2329.00 ± 4.80 | 2538.12 ± 2.10 | 1.090 |

@jammm
Contributor

jammm commented Oct 10, 2023

I don't think it is worth to put any effort into this, we need to implement matrix multiplication kernels that can use tensor cores, ideally with integer operations.

Shouldn't that be handled by cublas/hipblas already? They should use tensor cores/WMMA.

@slaren
Member

slaren commented Oct 10, 2023

Shouldn't that be handled by cublas/hipblas already? They should use tensor cores/WMMA.

Yes, but writing our own kernels would allow us to do this without having to dequantize the entire matrix to main memory first, and to use INT8 instead of FP16.

@jammm
Contributor

jammm commented Oct 10, 2023

Shouldn't that be handled by cublas/hipblas already? They should use tensor cores/WMMA.

Yes, but writing our own kernels would allow us to do this without having to dequantize the entire matrix to main memory first, and to use INT8 instead of FP16.

I see. Makes sense. For those kernels, I would highly recommend using mma.h (for CUDA/NVIDIA) and rocWMMA (https://github.com/ROCmSoftwarePlatform/rocWMMA) for AMD. This will make sure you have a single codebase for using tensor cores for both vendors.

This has the additional benefit that it'll also run seamlessly on MI GPUs' matrix cores (though you'd have to support wave64 mode, which shouldn't be too tricky).

The main caveat for RDNA3 support would be that it's restricted to 16x16x16 GEMM, whereas NVIDIA may support other GEMM sizes. Keeping all matmuls at 16x16 will make the code portable; you can then specialize afterward.
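
For reference, a minimal single-warp sketch of that portable 16x16x16 pattern using CUDA's mma.h (rocWMMA exposes an equivalent interface); this is only an illustration, not code from this PR:

    #include <cuda_fp16.h>
    #include <mma.h>
    using namespace nvcuda;

    // one warp computes a single 16x16 output tile: C = A (16x16) * B (16x16)
    __global__ void wmma_16x16x16_tile(const half * A, const half * B, float * C) {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

        wmma::fill_fragment(c_frag, 0.0f);
        wmma::load_matrix_sync(a_frag, A, 16);  // leading dimension 16
        wmma::load_matrix_sync(b_frag, B, 16);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
    }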

@mofosyne added the performance (Speed related topics) and Review Complexity : High (Generally requires in-depth knowledge of LLMs or GPUs) labels on May 25, 2024