
CUDA memory pool with async memory allocation/deallocation #3903

Merged: 3 commits into ggerganov:master from young-developer:cuda-memory-pool on Nov 2, 2023

Conversation

young-developer commented Nov 2, 2023

The custom memory pool implementation was changed to use the CUDA device memory pool.
If the device doesn't support memory pools for some reason, it falls back to the old implementation.

P.S. In CUDA_CHECK and CUBLAS_CHECK I renamed id to dev_id because of the warning "Declaration shadows a local variable".
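
For context, a minimal sketch of what allocation through the CUDA device memory pool looks like (illustrative only: the helper names and the release-threshold choice are assumptions, not the exact code in this PR):

```cpp
// Minimal sketch of pool-backed async allocation via the device's default memory pool.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Called once per device: keep freed blocks cached in the pool instead of
// returning them to the driver immediately.
static void pool_init(int device) {
    cudaMemPool_t pool;
    cudaDeviceGetDefaultMemPool(&pool, device);
    uint64_t threshold = UINT64_MAX; // cache everything (illustrative choice)
    cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
}

static void * pool_malloc_async(size_t size, cudaStream_t stream) {
    void * ptr = nullptr;
    // The allocation is ordered with respect to the stream, so no device-wide
    // synchronization is required.
    cudaError_t err = cudaMallocAsync(&ptr, size, stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaMallocAsync failed: %s\n", cudaGetErrorString(err));
        return nullptr;
    }
    return ptr;
}

static void pool_free_async(void * ptr, cudaStream_t stream) {
    // The block returns to the pool and can be reused by later allocations
    // on the same stream without synchronization.
    cudaFreeAsync(ptr, stream);
}
```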

young-developer changed the title from "CUDA memory pool with async memory allocation deallocation" to "CUDA memory pool with async memory allocation/deallocation" on Nov 2, 2023
slaren commented Nov 2, 2023

This is a bit slower than the current pool, for me with 7B it reduces TG performance by ~2%, but it may still be worth doing for the memory savings.

The best solution may be using virtual memory to expand the allocation size as needed, as explained here. Then, since we can guarantee that deallocations happen in the reverse order of allocations, allocating memory would only require increasing a head pointer, and freeing decreasing it.
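
A rough sketch of that idea, assuming the driver-side virtual memory management API (cuMemAddressReserve / cuMemCreate / cuMemMap); the pool structure and names are illustrative, not code from this repository:

```cpp
// Growable, LIFO pool on top of CUDA virtual memory management (sketch).
// Assumes the driver is initialized and a context is current.
#include <cuda.h>
#include <cstddef>

struct vmm_pool {
    CUdeviceptr base     = 0;  // start of the reserved virtual range
    size_t      reserved = 0;  // size of the virtual reservation
    size_t      mapped   = 0;  // physical memory mapped so far (high-water mark)
    size_t      head     = 0;  // current allocation offset (LIFO)
    int         device   = 0;
};

static void vmm_pool_init(vmm_pool & pool, int device, size_t max_size) {
    pool.device   = device;
    pool.reserved = max_size; // assumed to be granularity-aligned
    // Reserving virtual address space costs no physical memory.
    cuMemAddressReserve(&pool.base, max_size, 0, 0, 0);
}

static void * vmm_pool_alloc(vmm_pool & pool, size_t size) {
    // Grow the physically mapped region only when the request doesn't fit
    // (alignment handling omitted for brevity).
    if (pool.head + size > pool.mapped) {
        CUmemAllocationProp prop = {};
        prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        prop.location.id   = pool.device;

        size_t granularity = 0;
        cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
        size_t grow = ((pool.head + size - pool.mapped + granularity - 1) / granularity) * granularity;

        CUmemGenericAllocationHandle handle;
        cuMemCreate(&handle, grow, &prop, 0);
        cuMemMap(pool.base + pool.mapped, grow, 0, handle, 0);

        CUmemAccessDesc access = {};
        access.location = prop.location;
        access.flags    = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
        cuMemSetAccess(pool.base + pool.mapped, grow, &access, 1);

        pool.mapped += grow;
    }
    void * ptr = (void *) (pool.base + pool.head);
    pool.head += size;
    return ptr;
}

static void vmm_pool_free(vmm_pool & pool, size_t size) {
    // Frees come in reverse order of allocation, so this is just a pointer decrement.
    pool.head -= size;
}
```

Because physical memory is only committed up to the high-water mark, the pool never over-allocates the way per-request cudaMalloc calls do, and no free list or size bookkeeping is needed.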

young-developer commented Nov 2, 2023

> This is a bit slower than the current pool, for me with 7B it reduces TG performance by ~2%, but it may still be worth doing for the memory savings.
>
> The best solution may be using virtual memory to expand the allocation size as needed, as explained here. Then, since we can guarantee that deallocations happen in the reverse order of allocations, allocating memory would only require increasing a head pointer, and freeing decreasing it.

I am curious what the performance is when there is more than one GPU. Did you test with multiple GPUs?

slaren commented Nov 2, 2023

I do not have a good way to test this with multiple GPUs.

ggerganov left a comment

> but it may still be worth doing for the memory savings.

Why does it use less memory? Is it because of the "lookahead" over-allocation in the custom implementation?

@young-developer

Custom mempool:

main: n_kv_max = 2048, is_pp_shared = 0, n_gpu_layers = 999, mmq = 0

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.090 | 1419.62 | 2.114 | 60.54 | 2.204 | 116.13 |
| 128 | 128 | 2 | 512 | 0.101 | 2538.85 | 2.710 | 94.45 | 2.811 | 182.13 |
| 128 | 128 | 4 | 1024 | 0.177 | 2893.67 | 2.780 | 184.16 | 2.957 | 346.28 |
| 128 | 128 | 8 | 2048 | 0.336 | 3043.48 | 4.028 | 254.24 | 4.364 | 469.27 |
| 128 | 256 | 1 | 384 | 0.078 | 1646.22 | 4.229 | 60.53 | 4.307 | 89.16 |
| 128 | 256 | 2 | 768 | 0.098 | 2603.29 | 5.443 | 94.06 | 5.542 | 138.59 |
| 128 | 256 | 4 | 1536 | 0.166 | 3078.05 | 5.530 | 185.17 | 5.696 | 269.64 |
| 256 | 128 | 1 | 384 | 0.099 | 2589.47 | 2.134 | 59.98 | 2.233 | 171.97 |
| 256 | 128 | 2 | 768 | 0.164 | 3124.62 | 2.740 | 93.44 | 2.904 | 264.50 |
| 256 | 128 | 4 | 1536 | 0.342 | 2993.52 | 2.774 | 184.59 | 3.116 | 492.96 |
| 256 | 256 | 1 | 512 | 0.099 | 2587.45 | 4.209 | 60.82 | 4.308 | 118.84 |
| 256 | 256 | 2 | 1024 | 0.163 | 3140.01 | 5.429 | 94.30 | 5.593 | 183.10 |
| 256 | 256 | 4 | 2048 | 0.344 | 2973.30 | 5.589 | 183.21 | 5.934 | 345.15 |
| 512 | 128 | 1 | 640 | 0.164 | 3121.09 | 2.104 | 60.85 | 2.268 | 282.24 |
| 512 | 128 | 2 | 1280 | 0.345 | 2967.24 | 2.760 | 92.75 | 3.105 | 412.22 |
| 512 | 256 | 1 | 768 | 0.163 | 3137.25 | 4.269 | 59.97 | 4.432 | 173.29 |
| 512 | 256 | 2 | 1536 | 0.343 | 2983.73 | 5.503 | 93.04 | 5.846 | 262.73 |

llama_print_timings: load time = 3058.06 ms
llama_print_timings: sample time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
llama_print_timings: prompt eval time = 48646.15 ms / 15888 tokens ( 3.06 ms per token, 326.60 tokens per second)
llama_print_timings: eval time = 19057.85 ms / 1152 runs ( 16.54 ms per token, 60.45 tokens per second)
llama_print_timings: total time = 70679.85 ms

CUDA mempool:

main: n_kv_max = 2048, is_pp_shared = 0, n_gpu_layers = 999, mmq = 0

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 128 | 128 | 1 | 256 | 0.094 | 1363.41 | 2.097 | 61.04 | 2.191 | 116.84 |
| 128 | 128 | 2 | 512 | 0.102 | 2498.24 | 2.685 | 95.35 | 2.787 | 183.69 |
| 128 | 128 | 4 | 1024 | 0.169 | 3031.54 | 2.722 | 188.08 | 2.891 | 354.18 |
| 128 | 128 | 8 | 2048 | 0.332 | 3084.63 | 3.936 | 260.18 | 4.268 | 479.89 |
| 128 | 256 | 1 | 384 | 0.077 | 1664.54 | 4.141 | 61.82 | 4.218 | 91.04 |
| 128 | 256 | 2 | 768 | 0.097 | 2640.48 | 5.417 | 94.51 | 5.514 | 139.27 |
| 128 | 256 | 4 | 1536 | 0.163 | 3133.11 | 5.521 | 185.46 | 5.685 | 270.19 |
| 256 | 128 | 1 | 384 | 0.098 | 2616.49 | 2.165 | 59.13 | 2.263 | 169.71 |
| 256 | 128 | 2 | 768 | 0.160 | 3197.82 | 2.755 | 92.92 | 2.915 | 263.45 |
| 256 | 128 | 4 | 1536 | 0.347 | 2955.18 | 2.766 | 185.10 | 3.113 | 493.48 |
| 256 | 256 | 1 | 512 | 0.100 | 2558.06 | 4.213 | 60.76 | 4.314 | 118.70 |
| 256 | 256 | 2 | 1024 | 0.164 | 3127.84 | 5.456 | 93.84 | 5.620 | 182.21 |
| 256 | 256 | 4 | 2048 | 0.341 | 3002.69 | 5.574 | 183.71 | 5.915 | 346.23 |
| 512 | 128 | 1 | 640 | 0.165 | 3111.52 | 2.182 | 58.67 | 2.346 | 272.79 |
| 512 | 128 | 2 | 1280 | 0.339 | 3017.32 | 2.760 | 92.75 | 3.099 | 412.99 |
| 512 | 256 | 1 | 768 | 0.164 | 3124.96 | 4.254 | 60.18 | 4.418 | 173.84 |
| 512 | 256 | 2 | 1536 | 0.340 | 3008.84 | 5.458 | 93.81 | 5.798 | 264.91 |

llama_print_timings: load time = 3078.71 ms
llama_print_timings: sample time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
llama_print_timings: prompt eval time = 48390.87 ms / 15888 tokens ( 3.05 ms per token, 328.33 tokens per second)
llama_print_timings: eval time = 19050.17 ms / 1152 runs ( 16.54 ms per token, 60.47 tokens per second)
llama_print_timings: total time = 70436.36 ms

slaren commented Nov 2, 2023

The current implementation just makes an allocation with the same size as requested, and then tries to reuse these allocations when possible. So, for example, with cuBLAS it will allocate enough memory to convert the first weight that is used to F16, but when we need to convert a larger weight to F16, that allocation will be too small and a new allocation will be required, nearly doubling the memory usage. The worst case is probably when evaluating with increasingly larger batch sizes: since the previous allocations for src1/dst will be too small to reuse, it will allocate new buffers with each evaluation. It's just really bad and gets far more use than was ever intended.
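
For illustration, a simplified sketch of the reuse behavior described above (hypothetical names, not the actual ggml-cuda.cu code): a cached buffer is reused only if it is at least as large as the request, otherwise a new buffer of exactly the requested size is allocated, which is why growing batch sizes keep adding allocations instead of reusing the existing ones.

```cpp
// Simplified sketch of a legacy-style custom pool (hypothetical implementation).
#include <cuda_runtime.h>
#include <cstddef>

#define MAX_BUFFERS 256

struct pool_buffer {
    void * ptr  = nullptr;
    size_t size = 0;
};

static pool_buffer g_buffers[MAX_BUFFERS];

static void * legacy_pool_malloc(size_t size, size_t * actual_size) {
    for (int i = 0; i < MAX_BUFFERS; ++i) {
        pool_buffer & b = g_buffers[i];
        if (b.ptr != nullptr && b.size >= size) {
            // Reuse a cached buffer that fits the request.
            void * ptr   = b.ptr;
            *actual_size = b.size;
            b.ptr  = nullptr;
            b.size = 0;
            return ptr;
        }
    }
    // Nothing fits: allocate a new buffer of exactly the requested size.
    void * ptr = nullptr;
    cudaMalloc(&ptr, size);
    *actual_size = size;
    return ptr;
}

static void legacy_pool_free(void * ptr, size_t size) {
    // Return the buffer to the first free slot so it can be reused later.
    for (int i = 0; i < MAX_BUFFERS; ++i) {
        pool_buffer & b = g_buffers[i];
        if (b.ptr == nullptr) {
            b.ptr  = ptr;
            b.size = size;
            return;
        }
    }
    cudaFree(ptr); // table full: give the memory back to the driver
}
```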

young-developer commented Nov 2, 2023

We can always add a toggle to switch between the two implementations.
Another advantage: if NVIDIA improves the memory pool (performance, etc.), we get those improvements without any changes on our side.

ggerganov merged commit d606905 into ggerganov:master on Nov 2, 2023
32 checks passed
cebtenzzre commented Nov 2, 2023

This PR broke llama.cpp on my GTX 970 (Maxwell, compute 5.2):

CUDA error 801 at ggml-cuda.cu:6792: operation not supported
current device: 0

young-developer commented Nov 3, 2023

> This PR broke llama.cpp on my GTX 970 (Maxwell, compute 5.2):
>
> CUDA error 801 at ggml-cuda.cu:6792: operation not supported
> current device: 0

I will try to add additional checks for the CUDA pool. There is a check, but for some reason your GPU can get a pool yet can't allocate from it, so I assume it is not supported by your GPU and it should fall back to the old implementation in that case. I will check it.

young-developer deleted the cuda-memory-pool branch on November 3, 2023 11:57
@young-developer

@cebtenzzre I added an additional check on the device properties: #3931
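
For illustration, a runtime capability check along these lines can gate the pool and trigger the fallback (a sketch of one possible approach; see #3931 for the actual change):

```cpp
// Sketch of a runtime check before enabling the CUDA memory pool.
// cudaDevAttrMemoryPoolsSupported is 1 when cudaMallocAsync/cudaFreeAsync
// and memory pools are usable on the device.
#include <cuda_runtime.h>

static bool device_supports_mem_pools(int device) {
    int supported = 0;
    cudaError_t err = cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, device);
    return err == cudaSuccess && supported != 0;
}

// Usage (illustrative): fall back to the old custom pool when unsupported.
//   use_cuda_mem_pool[device] = device_supports_mem_pools(device);
```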

slaren added a commit that referenced this pull request Nov 4, 2023
Revert "cuda : use CUDA memory pool with async memory allocation/deallocation when available (#3903)"

This reverts commit d606905.

ggml-ci
ggerganov pushed a commit that referenced this pull request Nov 5, 2023
* Revert "cuda : add ROCM aliases for CUDA pool stuff (#3918)"

This reverts commit 629f917.

* Revert "cuda : use CUDA memory pool with async memory allocation/deallocation when available (#3903)"

This reverts commit d606905.

ggml-ci
olexiyb pushed a commit to Sanctum-AI/llama.cpp that referenced this pull request Nov 23, 2023
cuda : use CUDA memory pool with async memory allocation/deallocation when available (ggerganov#3903)

* Using cuda memory pools for async alloc/dealloc.

* If the CUDA device doesn't support memory pools then use the old implementation.

* Removed redundant cublasSetStream

---------

Co-authored-by: Oleksii Maryshchenko <[email protected]>
olexiyb pushed a commit to Sanctum-AI/llama.cpp that referenced this pull request Nov 23, 2023
* Revert "cuda : add ROCM aliases for CUDA pool stuff (ggerganov#3918)"

This reverts commit 629f917.

* Revert "cuda : use CUDA memory pool with async memory allocation/deallocation when available (ggerganov#3903)"

This reverts commit d606905.

ggml-ci