
Custom RoPE + better memory management for CUDA #2295

Merged 2 commits on Jul 21, 2023

Conversation

ikawrakow (Contributor)

This PR does two things:

  • Implement the customizable RoPE as per #2054 (Implement customizable RoPE)
  • Change ggml_cuda_pool_malloc() to be more forward-looking when allocating memory (a sketch of the idea follows this list). With this change I'm able to run with 8k contexts on CUDA. Without the change, the maximum context before running out of VRAM is somewhere between 5k and 6k for 7B, and between 3.5k and 4k for 13B.
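
A minimal sketch of the pooling idea behind the second point, with simplified names and structure (this is not the actual ggml-cuda.cu code): freed buffers stay in the pool for reuse, and when none is large enough the new buffer is allocated with a little headroom so that slightly larger follow-up requests can reuse it instead of triggering another cudaMalloc.

```cpp
// Illustrative sketch only -- not the actual ggml-cuda.cu implementation.
#include <cuda_runtime.h>
#include <cstddef>

struct cuda_buffer { void * ptr = nullptr; size_t size = 0; };
static cuda_buffer g_pool[256];  // simplified fixed-size pool of freed buffers

static void * pool_malloc(size_t size, size_t * actual_size) {
    // 1) reuse the first free pooled buffer that is large enough
    for (auto & b : g_pool) {
        if (b.ptr != nullptr && b.size >= size) {
            void * ptr = b.ptr;
            *actual_size = b.size;
            b.ptr = nullptr;  // mark the slot as taken
            return ptr;
        }
    }
    // 2) otherwise allocate with ~5% headroom (the margin this PR settles on),
    //    so the buffer can serve slightly larger requests later
    const size_t look_ahead = size + size/20;
    void * ptr = nullptr;
    if (cudaMalloc(&ptr, look_ahead) != cudaSuccess) {
        return nullptr;
    }
    *actual_size = look_ahead;
    return ptr;
}
```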

I'm finding that the --rope-freq-base option works much better than --rope-freq-scale. This graph, computed with the changes in this PR, shows wikitext perplexity for contexts up to 8192 for the LLaMA-1 7B, 13B, and 30B models. Shown is the ratio of the perplexity at a given context length to the perplexity at context = 2048 (max. training context for the LLaMA-1 models).

[Figure ppl_vs_ctx: ratio of perplexity at a given context length to perplexity at context 2048, for LLaMA-1 7B, 13B, and 30B]

The dependence of the (approximately) best RoPE base frequency on context length (best meaning the one that results in the lowest perplexity) is shown in the next graph. It can be fit with a second-order polynomial as

base frequency  = 10000 * (-0.13436 + 0.80541 * x + 0.28833 * x^2) for 7B
base frequency  = 10000 * (-0.41726 + 1.1792 * x + 0.16915 * x^2) for 13B

where x = context size / 2048.
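
A hypothetical helper (not part of this PR) that just evaluates the fits above; the function name and the standalone main are made up for illustration:

```cpp
// Evaluate the second-order fits above to pick a --rope-freq-base value
// for a given context length. Coefficients are the ones quoted in the PR.
#include <cstdio>

static float rope_base_from_ctx(int n_ctx, bool is_13b) {
    const float x = n_ctx / 2048.0f;  // context size relative to the 2048-token training context
    const float p = is_13b
        ? -0.41726f + 1.17920f*x + 0.16915f*x*x   // 13B fit
        : -0.13436f + 0.80541f*x + 0.28833f*x*x;  //  7B fit
    return 10000.0f * p;
}

int main() {
    printf(" 7B @ 8192: %.0f\n", rope_base_from_ctx(8192, false));  // ~77000
    printf("13B @ 8192: %.0f\n", rope_base_from_ctx(8192, true));   // ~70000
    return 0;
}
```

For 13B at a context of 8192 (x = 4) the fit gives roughly 70000, which is the --rope-freq-base value used in the perplexity runs below.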

[Figure rope_frequency: best RoPE base frequency vs. context length, with the second-order polynomial fits]

jxy (Contributor) commented Jul 20, 2023

Does the first figure show the ppl from the optimized base frequency from the polynomial fit? Which quant did you use for that? How does it change with the new llama 2 models?

ikawrakow (Contributor, Author)

> Does the first figure show the ppl from the optimized base frequency from the polynomial fit?

No, not from the fit. I manually optimized the base frequency for a set of context lengths, and the first figure shows the result. I then used those optimized base frequencies for the second figure and the fits. But just using the fits gets you very close to the manually optimized base frequency.

> Which quant did you use for that?

Q6_K, which is known to match the perplexity of the fp16 models within 0.1%. To give you specific numbers, the Q6_K perplexity at a context length of 2048 is 5.2856 for 7B, 4.7094 for 13B, and 3.7552 for 30B. Together with the ratios in the first figure, these let you recover actual perplexities for other context lengths.
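
Concretely, with ratio(n_ctx) denoting the value read off the first figure at the context length of interest, the absolute perplexities are approximately

ppl_7B(n_ctx) ≈ ratio(n_ctx) * 5.2856,  ppl_13B(n_ctx) ≈ ratio(n_ctx) * 4.7094,  ppl_30B(n_ctx) ≈ ratio(n_ctx) * 3.7552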

> How does it change with the new llama 2 models?

I haven't gotten around to trying LLaMA-2 yet (but I'm curious too and will in the next few days).

JohannesGaessler (Collaborator)

The memory allocation change would work to reduce VRAM usage for LLaMA, but it does not universally reduce VRAM usage: if the largest temporary buffer comes first, VRAM usage will be 25% higher. I don't know the allocation patterns of other ggml projects, so I can't comment on whether this PR would be beneficial or harmful for them. I think a better solution would be to add a function that lets you specify the maximum expected buffer size (though this would not be mutually exclusive with this PR); calling that function would then allocate such a buffer in the pool. The best solution, I think, would be to eliminate as many temporary buffers as possible by writing matrix multiplication kernels that can directly use quantized data (still WIP).

Related: CUDA has functions cudaMallocAsync and cudaFreeAsync that internally also use a buffer pool. But when I tried those the performance was worse than the implementation in ggml-cuda.cu.
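
For reference, minimal illustrative usage of those two calls (not the code that was benchmarked; the function name is made up):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Stream-ordered allocation: cudaMallocAsync / cudaFreeAsync draw from a
// driver-managed memory pool, with allocation and free ordered on the stream.
void scratch_example(cudaStream_t stream, size_t nbytes) {
    void * buf = nullptr;
    cudaMallocAsync(&buf, nbytes, stream);  // becomes usable in stream order
    // ... enqueue kernels on `stream` that use `buf` ...
    cudaFreeAsync(buf, stream);             // returned to the pool, also in stream order
}
```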

ikawrakow (Contributor, Author)

@JohannesGaessler

I fully agree with you that this is not the ultimate solution to managing VRAM usage. It is just a quick band-aid for now so one can run on CUDA with larger contexts. Here is what I get on master (after adding logging to ggml_cuda_pool_malloc()):

./perplexity -m q6k_13.bin -f tests/wikitext-2-raw/wiki.test.raw -s 1234 -t 16 -ngl 40 -c 8192 --rope-freq-base 70000
main: warning: model might not support context sizes greater than 2048 tokens (8192 specified);expect poor results
main: build = 856 (1cdbbbb)
main: seed  = 1234
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4080, compute capability 8.9
llama.cpp: loading model from q6k_13.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 8192
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: freq_base  = 70000.0
llama_model_load_internal: freq_scale = 1
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 2762.47 MB (+ 1608.00 MB per state)
llama_model_load_internal: allocating batch_size x (640 kB + n_ctx x 160 B) = 960 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 40 repeating layers to GPU
llama_model_load_internal: offloaded 40/43 layers to GPU
llama_model_load_internal: total VRAM used: 10888 MB
llama_new_context_with_model: kv self size  = 6400.00 MB

system_info: n_threads = 16 / 32 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
perplexity: calculating perplexity over 40 chunks, batch_size=512
ggml_cuda_pool_malloc: 0 buffers using 0 MiB. Allocating additional 10 MiB
ggml_cuda_pool_malloc: 1 buffers using 10 MiB. Allocating additional 100 MiB
ggml_cuda_pool_malloc: 0 buffers using 0 MiB. Allocating additional 10 MiB
ggml_cuda_pool_malloc: 0 buffers using 0 MiB. Allocating additional 40 MiB
ggml_cuda_pool_malloc: 4 buffers using 160 MiB. Allocating additional 270 MiB
ggml_cuda_pool_malloc: 4 buffers using 160 MiB. Allocating additional 625 MiB
ggml_cuda_pool_malloc: 3 buffers using 60 MiB. Allocating additional 62 MiB
ggml_cuda_pool_malloc: 4 buffers using 182 MiB. Allocating additional 120 MiB
ggml_cuda_pool_malloc: 5 buffers using 222 MiB. Allocating additional 120 MiB
ggml_cuda_pool_malloc: 6 buffers using 342 MiB. Allocating additional 160 MiB
ggml_cuda_pool_malloc: 7 buffers using 462 MiB. Allocating additional 160 MiB
ggml_cuda_pool_malloc: 8 buffers using 622 MiB. Allocating additional 200 MiB
ggml_cuda_pool_malloc: 9 buffers using 782 MiB. Allocating additional 200 MiB
ggml_cuda_pool_malloc: 10 buffers using 982 MiB. Allocating additional 240 MiB
ggml_cuda_pool_malloc: 11 buffers using 1182 MiB. Allocating additional 240 MiB
ggml_cuda_pool_malloc: 12 buffers using 1422 MiB. Allocating additional 280 MiB
ggml_cuda_pool_malloc: 13 buffers using 1662 MiB. Allocating additional 280 MiB
ggml_cuda_pool_malloc: 14 buffers using 1942 MiB. Allocating additional 320 MiB
ggml_cuda_pool_malloc: 15 buffers using 2222 MiB. Allocating additional 320 MiB
ggml_cuda_pool_malloc: 16 buffers using 2542 MiB. Allocating additional 360 MiB
ggml_cuda_pool_malloc: 17 buffers using 2862 MiB. Allocating additional 360 MiB
CUDA error 2 at /home/iwan/other/llama.cpp/ggml-cuda.cu:2444: out of memory

And here is what I get with this PR after enabling logging:

iwan@tdcu-7950X:~/other/llama.cpp/cuda$ ./bin/perplexity -m q6k_13.bin -f ../tests/wikitext-2-raw/wiki.test.raw -s 1234 -t 16 -ngl 41 -c 8192 --rope-freq-base 70000
main: warning: model might not support context sizes greater than 2048 tokens (8192 specified);expect poor results
main: build = 856 (1cdbbbb)
main: seed  = 1234
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4080, compute capability 8.9
llama.cpp: loading model from q6k_13.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 8192
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: freq_base  = 70000.0
llama_model_load_internal: freq_scale = 1
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 2634.27 MB (+ 1608.00 MB per state)
llama_model_load_internal: allocating batch_size x (640 kB + n_ctx x 160 B) = 960 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 40 repeating layers to GPU
llama_model_load_internal: offloading non-repeating layers to GPU
llama_model_load_internal: offloaded 41/43 layers to GPU
llama_model_load_internal: total VRAM used: 11016 MB
llama_new_context_with_model: kv self size  = 6400.00 MB

system_info: n_threads = 16 / 32 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
perplexity: calculating perplexity over 40 chunks, batch_size=512
ggml_cuda_pool_malloc: 0 buffers, max_size = 0 MB, tot_size = 0 MB, requested 10 MB
ggml_cuda_pool_malloc: 1 buffers, max_size = 12 MB, tot_size = 12 MB, requested 100 MB
ggml_cuda_pool_malloc: 0 buffers, max_size = 0 MB, tot_size = 0 MB, requested 10 MB
ggml_cuda_pool_malloc: 0 buffers, max_size = 0 MB, tot_size = 0 MB, requested 40 MB
ggml_cuda_pool_malloc: 4 buffers, max_size = 125 MB, tot_size = 200 MB, requested 270 MB
ggml_cuda_pool_malloc: 5 buffers, max_size = 337 MB, tot_size = 537 MB, requested 625 MB
perplexity: 49.89 seconds per pass - ETA 33 minutes
[1]6.4151,[2]5.0308,[3]5.1318,[4]5.3116,^C

I don't think the behavior is model specific.

This is sufficient, it seems. We end up using about 200 MB less VRAM that way when running the 13B model with context 8192.

ggerganov (Owner) left a comment

I think the memory pool change is simple enough and it's unlikely to cause a disruption across other projects. If this is the only concern, I recommend we merge the PR.

ikawrakow (Contributor, Author)

I have reduced the 25% being added to the requested size to 5%. This is enough to run up to 8k context.
