Custom RoPE + better memory management for CUDA #2295
Conversation
Does the first figure show the ppl with the optimized base frequency from the polynomial fit? Which quant did you use for that? How does it change with the new LLaMA-2 models?
No, not from the fit. I manually optimized for a set of context lengths, and the 1st figure shows that. I then used the optimized base frequency for the second figure and the fits. But just using the fits gets you very close to the manually optimized base frequency.
Haven't gotten around to trying LLaMA-2 yet (but I'm curious too and will in the next few days).
The memory allocation change would work to reduce VRAM usage for LLaMA, but it does not universally reduce VRAM usage. If the largest temporary buffer comes first, then VRAM usage will be 25% higher. I don't know the allocation patterns of other ggml projects, so I can't comment on whether this PR would be beneficial or harmful for them. I think a better solution would be to add a function that lets you specify the maximum expected buffer size (though this would not be mutually exclusive with this PR). Calling the function would then allocate such a buffer in the pool. The best solution, I think, would be to eliminate as many temporary buffers as possible by writing matrix multiplication kernels that can directly use quantized data (still WIP). Related: CUDA has functions …
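A minimal sketch of the kind of hint function described above, simplified to a single reserved buffer; the names cuda_pool_set_max_buffer_size, cuda_pool_malloc, and cuda_pool_free are hypothetical and not part of the ggml API:

```cpp
#include <cuda_runtime.h>
#include <cstddef>

struct pool_buffer {
    void * ptr    = nullptr;
    size_t size   = 0;
    bool   in_use = false;
};

static pool_buffer g_reserved;

// Pre-allocate one buffer that covers the largest expected request, so the
// order in which differently sized requests arrive no longer matters.
void cuda_pool_set_max_buffer_size(size_t max_size) {
    if (cudaMalloc(&g_reserved.ptr, max_size) == cudaSuccess) {
        g_reserved.size = max_size;
    }
}

// Requests that fit into the reserved buffer reuse it; only larger (or
// concurrent) requests fall through to a fresh cudaMalloc.
void * cuda_pool_malloc(size_t size) {
    if (!g_reserved.in_use && size <= g_reserved.size) {
        g_reserved.in_use = true;
        return g_reserved.ptr;
    }
    void * ptr = nullptr;
    cudaMalloc(&ptr, size);
    return ptr;
}

void cuda_pool_free(void * ptr) {
    if (ptr == g_reserved.ptr) {
        g_reserved.in_use = false; // mark the reserved buffer as available again
    } else {
        cudaFree(ptr);
    }
}
```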
I fully agree with you that this is not the ultimate solution to managing VRAM usage. It is just a quick band-aid for now so one can run on CUDA with larger contexts. Here is what I get on master (after adding logging to …):
And here is what I get with this PR after enabling logging:
I don't think the behavior is model-specific.
This seems to be sufficient. We end up using about 200 MB less VRAM that way when running the 13B model with a context of 8192.
I think the memory pool change is simple enough and it's unlikely to cause disruption across other projects. If this is the only concern, I recommend we merge the PR.
I have reduced the 25% being added to the requested size to 5%. This is enough to run with up to 8k context.
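A minimal sketch of the idea, not the actual ggml-cuda.cu implementation: when no pooled buffer is large enough, allocate the requested size plus a small headroom (5% here, down from the 25% in the first version of this PR) so that slightly larger future requests can reuse the same buffer instead of triggering another cudaMalloc.

```cpp
#include <cuda_runtime.h>
#include <cstddef>

constexpr int POOL_SIZE = 256;

struct cuda_buffer {
    void * ptr  = nullptr;
    size_t size = 0;
};

static cuda_buffer g_pool[POOL_SIZE];

static void * pool_malloc(size_t size, size_t * actual_size) {
    // 1) reuse an existing free buffer that is already big enough
    for (auto & b : g_pool) {
        if (b.ptr != nullptr && b.size >= size) {
            void * ptr   = b.ptr;
            *actual_size = b.size;
            b.ptr  = nullptr;
            b.size = 0;
            return ptr;
        }
    }
    // 2) otherwise allocate with ~5% headroom for future, slightly larger requests
    const size_t padded = size + size/20;
    void * ptr = nullptr;
    cudaMalloc(&ptr, padded);
    *actual_size = padded;
    return ptr;
}

static void pool_free(void * ptr, size_t size) {
    // return the buffer to the pool so later allocations can reuse it
    for (auto & b : g_pool) {
        if (b.ptr == nullptr) {
            b.ptr  = ptr;
            b.size = size;
            return;
        }
    }
    cudaFree(ptr); // pool is full; release back to the driver
}
```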
This PR does two things:
1. Adds custom RoPE via the --rope-freq-base and --rope-freq-scale command line options.
2. Changes ggml_cuda_pool_malloc() to be more forward looking when allocating memory. With this change I'm able to run with 8k contexts on CUDA. Without the change, the maximum context before running out of VRAM is somewhere between 5k and 6k for 7B, and between 3.5k and 4k for 13B.

I'm finding that the --rope-freq-base option works much better than --rope-freq-scale. This graph, computed with the changes in this PR, shows wikitext perplexity for contexts up to 8192 for the LLaMA-1 7B, 13B, and 30B models. Shown is the ratio of the perplexity at a given context length to the perplexity at context = 2048 (the maximum training context for the LLaMA-1 models).

The dependence of the (approximately) best RoPE base frequency on context length (best as in resulting in the lowest perplexity) is shown in the next graph. It can be fit with a second-order polynomial in x, where x = context size / 2048.
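As a sketch of the general shape of that fit (the fitted coefficients a_0, a_1, a_2 come from the graph and are not reproduced here):

```latex
% Second-order polynomial fit of the best RoPE base frequency as a function
% of the relative context size x; a_0, a_1, a_2 are the fitted coefficients.
f_{\mathrm{base}}(x) \approx a_0 + a_1 x + a_2 x^2,
\qquad x = \frac{n_{\mathrm{ctx}}}{2048}
```

For a given target context length, evaluating the fit gives a reasonable starting value to pass to --rope-freq-base.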