
ggml backends interface, ggml-cuda refactor #2239

Closed · wants to merge 21 commits into from

Conversation

@slaren (Collaborator) commented Jul 16, 2023

Previous discussion: #2230

This PR adds a common interface to the compute backends.

Breaking changes

  • ggml_context allocates memory from a ggml_buffer, which contains a buffer in device memory for the tensor data and a buffer in system memory for the tensor structs (and, in the future, other data such as the graphs)
  • The data member of ggml_tensor is a backend-specific pointer that should not be accessed directly. To access the data, use ggml_backend_set_tensor and ggml_backend_get_tensor instead. Functions such as ggml_new_f32 and ggml_set_f32 can also be used as before.
    • Technically, if you are using the CPU backend, you can still access the data member directly, but you shouldn't do that if you want to support other backends
    • I will probably rename this member to something else to prevent current code from compiling without changes
  • Added a small params buffer to ggml_tensor for the op parameters that are currently stored in a separate tensor. For example, ggml_rope uses this buffer to store the values n_past, n_dims, mode and n_ctx. The goal is to make these parameters easily accessible from the CPU and to reduce the overhead of creating a new tensor for them (see the sketch below).
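
A rough, hypothetical illustration of the idea only (the field name, its size, and the layout below are assumptions, not the exact API in this PR):

// hypothetical sketch: a small fixed-size buffer inside ggml_tensor for op parameters
struct ggml_tensor_sketch {
    // ... existing members (type, ne, nb, op, src0, src1, data, ...) ...
    int32_t params[8];   // inline op parameters (size chosen arbitrarily here)
};

// e.g. for ggml_rope, instead of packing the values into a separate 1d tensor:
// t->params[0] = n_past;
// t->params[1] = n_dims;
// t->params[2] = mode;
// t->params[3] = n_ctx;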

Brief example:

// initialize a backend
struct ggml_backend backend = ggml_backend_cpu_init();
// for CUDA:
// struct ggml_backend backend = ggml_backend_cuda_init();

// create a buffer
size_t ctx_size = 4096; // buffer size for the tensor data
size_t num_tensors = 10; // maximum number of tensors that can be allocated
struct ggml_buffer buf = ggml_backend_alloc_buffer(&backend, ctx_size, num_tensors);

// create a context using the buffer
struct ggml_init_params ggml_params = ggml_init_params_default();
ggml_params.buffer = &buf;
ggml_context * ctx = ggml_init(ggml_params);

// use the context to create a computation graph
struct ggml_tensor * x = ggml_new_f32(ctx, 2.0f);
struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
struct ggml_cgraph gf = ggml_build_forward(f);

// set the value of the input tensors
float a_val = 3.0f;
ggml_backend_set_tensor(a, &a_val, 0, sizeof(float));
// alternatively:
float b_val = 5.0f;
ggml_set_f32(b, b_val);

// run the computation
ggml_backend_graph_compute(&backend, &gf);

// get the result
float result;
ggml_backend_get_tensor(f, &result, 0, sizeof(float));
// alternatively:
// result = ggml_get_f32_1d(f, 0);

Backend implementation

Backends should implement the functions defined in the ggml_backend_interface struct. Currently there are implementations for the CPU and CUDA backends.
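
For orientation, a backend is essentially a table of function pointers plus some backend-specific context. The sketch below only illustrates the general shape of such an interface; the member names and signatures are assumptions and do not match the PR exactly:

// hypothetical sketch of a backend interface; names and signatures are illustrative only
struct ggml_backend_interface_sketch {
    const char * (*get_name)     (struct ggml_backend * backend);
    void         (*free)         (struct ggml_backend * backend);

    // buffer management in the backend's memory (device memory for GPU backends)
    void *       (*alloc_buffer) (struct ggml_backend * backend, size_t size);
    void         (*free_buffer)  (struct ggml_backend * backend, void * buffer);

    // data transfer between host memory and backend memory
    void         (*set_tensor)   (struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
    void         (*get_tensor)   (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

    // graph execution
    void         (*graph_compute)(struct ggml_backend * backend, struct ggml_cgraph * cgraph);
};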

Computation using multiple backends

It is still possible to offload some parts of the graph to the GPU while keeping others on the CPU. This is done using ggml_graph_splits. See the llama.cpp code for an example; I will update this later with more details (a rough sketch of the idea follows below).
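
As a very rough sketch of the concept only (the actual ggml_graph_splits API in this PR differs; the struct and steps below are assumptions), a split partitions the graph at the points where execution moves from one backend to another:

// hypothetical sketch of the graph-split idea; not the real API
struct ggml_graph_split_sketch {
    struct ggml_cgraph  * graph;    // sub-graph belonging to this split
    struct ggml_backend * backend;  // backend that executes this sub-graph
};

// execution then runs the splits in order, copying intermediate results
// between backends at the split boundaries:
//   for each split s:
//       copy the inputs of s to s->backend (e.g. via ggml_backend_set_tensor)
//       ggml_backend_graph_compute(s->backend, s->graph)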

Notes/limitations

  • Only the CPU and CUDA backends are currently supported.
  • Only the bare minimum necessary to run llama is currently implemented; don't look too closely at the details of the code for now
  • When partially offloading a model to the GPU, internally this is handled by splitting the graph into multiple parts and running each of them in sequence on a different backend. Because the KV memory is either on the CPU or on the GPU, this means that there are at least as many graph executions as there are layers. The CPU backend creates and destroys its threads with every graph launch, so this can have a non-negligible overhead. Eventually this will be improved by using a thread pool.
  • On the plus side, the CPU threads are no longer spinning while the GPU backend is running
  • In the long term, the goal is to support automatic fallback to the CPU when an operation is not implemented in a backend, but that is going to take a while. This means that backends like OpenCL that mostly only implement matrix multiplication cannot be used for now. I expect that this backend will eventually be replaced by a Vulkan backend that implements full GPU offloading, and won't be necessary anymore.
  • For prompt processing, it is usually worth uploading the weights every time, even when using partial offloading, but that is currently not supported, so prompt processing will use the same CPU/GPU layer split as generation. Eventually I would like to support this again.

@slaren added the high priority (Very important issue), breaking change (Changes that break ABIs, APIs, file formats, or other forms of backwards compatibility), and refactoring labels on Jul 16, 2023
@Dampfinchen

This sentence concerns me a bit: "For prompt processing, it is usually worth uploading the weights every time, even when using partial offloading, but that is currently not supported, so prompt processing will use the same CPU/GPU layer split as generation. Eventually I would like to support this again."

Does that mean that once this PR is merged, people using partial offloading (like me, running only 10 layers of a 13B model on the GPU) will experience significantly longer prompt evaluation times?

If that's the case, perhaps it would be better to wait until the performance is at least on the same level before merging. IMO, llama.cpp's most important feature, the one that sets it apart from GPTQ and BNB, is the ability to run models too large for the hardware at great performance; it would be sad to lose that.

Then of course there is the concern from AMD and Intel GPU owners. OpenCL has been an important backend for them to get decent speed on their hardware, so they would no longer benefit from CPU performance improvements and newer quantization formats in future releases if this PR gets merged, which would be a shame. Please keep in mind that the Vulkan backend might take a long time to develop, as it is a very complex task. If you are eager to supersede OpenCL in such a short time, perhaps you could lend a hand!

Aside from these concerns, I am very impressed by the awesome work you and others are doing and I fully support this PR! I'm just pointing out the effects it may have on people enjoying the software.

@slaren (Collaborator, Author) commented Jul 16, 2023

I understand your concerns. I will add OpenCL support back in the same way that it works currently. Once we get closer to merging this, I will run some tests to see exactly what the performance impact on prompt processing speed is and re-evaluate from there.

Don't get me wrong, I agree that these features are important, and I want to continue to support them. The issue is that PRs like this one, which touch a lot of code, can have a significant maintenance overhead, because other changes merged in the meantime will often cause conflicts that must be resolved manually. Ultimately, it is a matter of whether we want to prioritize development speed or maintaining a stable master branch.

@casper-hansen

Weighing in on the issue of breaking things versus shipping this feature faster:

I lean heavily towards shipping faster and fixing things afterwards. Developers should in general pin a commit (GPT4All does this) before this PR is merged, so that they have a stable version until performance stabilizes and the bugs are squashed.

The problem with trying to solve issues like slower prompt processing or CPU offloading first is that it hinders long-term progress for a short-term benefit.

On the plus side, this PR will enable a future where more models can be supported more easily thanks to the backend setup. The faster it is shipped, the faster we get performance optimizations and generally better tech to support our LLMs.

@slaren (Collaborator, Author) commented Jul 21, 2023

@JohannesGaessler how confident are you in the VRAM scratch buffer size estimation for context sizes > 2048? I have been working on ggml-org/ggml#288 and I get these sizes for the compute buffer (n_batch = 512):

| n_ctx | compute buffer (calculated) | CUDA VRAM scratch buffer (master) |
|-------|-----------------------------|-----------------------------------|
| 512   | 88.50 MB                    | 288 MB                            |
| 1024  | 152.00 MB                   | 320 MB                            |
| 2048  | 280.00 MB                   | 384 MB                            |
| 4096  | 536.00 MB                   | 512 MB                            |
| 8192  | 1048.00 MB                  | 768 MB                            |
| 16384 | 2072.00 MB                  | 1280 MB                           |

For smaller context sizes the calculated memory is smaller (which is expected, as this should be more efficient than scratch buffers), but for larger contexts it is bigger. There may still be bugs in my implementation, though; perplexity seems fine.

@ggerganov (Owner)

@slaren

I believe that @ikawrakow has also observed an issue with the currently estimated VRAM sizes for larger contexts (#2295; I haven't looked at the details yet, so I could be wrong). But I guess your sizes might actually be the correct ones.

@JohannesGaessler (Collaborator)

See #2056. I tested the VRAM scratch sizes with a precision of 1 MB, meaning I determined the lowest scratch size that still produces correct results. I tested up to 8192 context where possible. I observed that the amount of VRAM required increases linearly with the context size. I chose the VRAM scratch sizes on master to be the minimum required scratch size +25% at 2048 context. It's possible that the implementation on master is suboptimal and thus needs more VRAM for that reason.

@slaren (Collaborator, Author) commented Jul 21, 2023

Thank you. It is not clear to me how you estimated the memory, so it is possible that we are using different parameters. What I do is construct a graph with these parameters and calculate the size of the compute buffer required to execute it:

int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
int n_past = hparams.n_ctx - n_tokens;

This should be equivalent to computing a prompt of n_batch tokens at the latest possible position in the context window, which I think is the worst case.
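
For example, with n_batch = 512 and n_ctx = 2048, this works out to:

// n_ctx = 2048, n_batch = 512
int n_tokens = std::min(2048, 512); // = 512
int n_past   = 2048 - 512;          // = 1536
// i.e. a full 512-token batch evaluated at the very end of the context window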

@JohannesGaessler (Collaborator) commented Jul 21, 2023

What I did was hard-code a size for the VRAM scratch buffer, compile the code, and then test whether the results are correct. The values in the tables of the linked PR are the minimum hard-coded values that still produced correct results.

@slaren (Collaborator, Author) commented Jul 21, 2023

I understand that; what I am wondering about is how exactly you ran the tests. If what you did is different from this, then different results are expected:

This should be equivalent to computing a prompt of n_batch tokens at the latest possible position in the context window, which I think is the worst case.

@JohannesGaessler (Collaborator)

If I remember correctly, I was calculating perplexity on the first 100 lines of Wikitext as a test.

@slaren (Collaborator, Author) commented Jul 21, 2023

Ok, that should be equivalent when n_ctx is a multiple of n_batch. So I looked at which tensors were allocated at the point of peak memory usage and found that it was mostly KQ and KQ_scaled. Previously KQ_scaled was computed in place, but I had removed in-place operations. So I added automatic in-place operations to the allocator, and I get these results now:

| n_ctx | compute buffer (calculated) | VRAM scratch buffer |
|-------|-----------------------------|---------------------|
| 512   | 70.50 MB                    | 288 MB              |
| 1024  | 88.00 MB                    | 320 MB              |
| 2048  | 152.00 MB                   | 384 MB              |
| 4096  | 280.00 MB                   | 512 MB              |
| 8192  | 536.00 MB                   | 768 MB              |
| 16384 | 1048.00 MB                  | 1280 MB             |

So it should be a reduction in memory usage over master in all cases.
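
Roughly, the automatic in-place decision amounts to reusing a parent's buffer when nobody else still needs it. A minimal sketch, with the condition factored into a hypothetical helper (the exact check used by the allocator appears in the review excerpt further below):

// simplified sketch; the helper name is hypothetical, the condition mirrors the allocator's check
static bool can_compute_inplace(struct ggml_tensor * node, struct ggml_tensor * parent) {
    return parent->data       != NULL              // parent is already allocated
        && parent->n_children == 1                 // this node is the parent's only consumer
        && parent->n_views    == 0                 // no views alias the parent's memory
        && ggml_are_same_layout(node, parent)      // the output fits exactly in the parent's buffer
        && node->op           != GGML_OP_MUL_MAT;  // mul_mat is excluded: its output cannot reuse an input buffer
}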

ggml-backend.c (outdated review thread)
}

void ggml_graph_allocate_tensors_n(struct ggml_cgraph ** graphs, int n_graphs) {
}
static bool ggml_is_view(struct ggml_tensor * t) {
Collaborator

Is this name intentional? It seems confusing considering the collision with GGML_OP_VIEW.

Collaborator Author

I use view generally to mean any operation that shares the memory of its parent tensor. I think the name makes sense, but I am open to suggestions.
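
As a sketch of that definition (the ops listed are the usual memory-sharing ops in ggml; the exact list checked by ggml_is_view in this PR may differ):

// sketch: ops whose result shares the memory of a parent tensor
static bool ggml_is_view_sketch(struct ggml_tensor * t) {
    switch (t->op) {
        case GGML_OP_VIEW:
        case GGML_OP_RESHAPE:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            return true;   // these reinterpret the parent's data without copying it
        default:
            return false;
    }
}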

break;
}
// TODO: make a list of operations that can be safely made inplace
if (parent->data != NULL && parent->n_children == 1 && parent->n_views == 0 && ggml_are_same_layout(node, parent) && node->op != GGML_OP_MUL_MAT) {
Collaborator

It's not used in LLaMA but GGML_CONV_2D should also not be done inplace.

@slaren (Collaborator, Author) commented Jul 22, 2023

@ggerganov, or anyone familiar with the training code: is there any way to tell which tensors must be kept for the backward pass? Maybe by checking whether ggml_tensor::grad is not NULL? I am not sure how to rewrite this after removing the scratch buffers:

// tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode, 0)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
use_buf(-1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode, 0)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
use_buf(-1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
use_buf(-1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
use_buf(-1); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
use_buf(-1); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
use_buf(-1); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
use_buf(-1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
use_buf( 0); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
use_buf(-1); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
use_buf(-1); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch);
use_buf(-1); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch);
use_buf(-1); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch);
use_buf( 0); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch);

@ggerganov (Owner) commented Jul 22, 2023

If you keep all tensors with grad != NULL, you will be able to compute the backward pass.
However, this is not optimal, as it will keep pretty much everything from the forward pass, while, depending on the operations, the values of some tensors will not be needed.

For example, let's take a look at the backward pass for GGML_OP_ADD:

llama.cpp/ggml.c

Lines 15187 to 15195 in 24baa54

case GGML_OP_ADD:
{
if (src0->grad) {
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
}
if (src1->grad) {
src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
}
} break;

# x: src0
# y: src1
# z: tensor

# x': src0->grad
# y': src1->grad
# z': tensor->grad

# forward
z = x + y

# backward
x' += z'
y' += z'

I.e. to compute x' and y' in the backward pass, we don't need any of x, y or z from the forward pass.

In contrast, GGML_OP_MUL:

llama.cpp/ggml.c

Lines 15247 to 15263 in 24baa54

case GGML_OP_MUL:
{
if (src0->grad) {
src0->grad =
ggml_add_impl(ctx,
src0->grad,
ggml_mul(ctx, src1, tensor->grad),
inplace);
}
if (src1->grad) {
src1->grad =
ggml_add_impl(ctx,
src1->grad,
ggml_mul(ctx, src0, tensor->grad),
inplace);
}
} break;

# forward
z = x * y

# backward
x' += y*z'
y' += x*z'

So here we need x and y from the forward pass.

I think @xaedes has manually determined which tensors need to be kept. Not sure if there is a better solution, except for building the backwards graph and checking which tensors from the forward pass are used.
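
A rough sketch of what that automatic approach might look like, where gf and gb are the forward and backward graphs and mark_keep_if_forward is a hypothetical helper:

// hypothetical sketch: walk the backward graph gb and mark every tensor that also
// appears in the forward graph gf as a source of a backward node, so that the
// allocator does not free its data
for (int i = 0; i < gb->n_nodes; i++) {
    struct ggml_tensor * node = gb->nodes[i];
    if (node->src0 != NULL) { mark_keep_if_forward(gf, node->src0); }
    if (node->src1 != NULL) { mark_keep_if_forward(gf, node->src1); }
    // a real implementation would also have to handle ops with additional sources
}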

@slaren (Collaborator, Author) commented Jul 22, 2023

If it is determined manually, I could add a parameter to the allocator to specify the list of tensors that must not be freed. That should achieve the same goal of minimizing memory usage while still keeping the data of the important tensors. It could possibly be a list in ggml_cgraph; let's call it outputs, to represent the list of outputs of a graph (see the sketch at the end of this comment).

except for building the backwards graph and checking which tensors from the forward pass are used.

This would be interesting for a future refactor of the training code, but I don't understand it well enough to do this at this point.
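
Going back to the outputs idea above, a minimal sketch of what it could look like (the field names are assumptions):

// hypothetical sketch: a list attached to the graph of tensors whose data the
// allocator must not free or reuse, e.g. tensors needed later for the backward pass
struct ggml_cgraph_sketch {
    // ... existing members (n_nodes, nodes, leafs, ...) ...
    int                  n_outputs;
    struct ggml_tensor * outputs[GGML_MAX_NODES];
};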

@ggerganov (Owner)

Ok, we can do that for now - it should be better than the current approach, though it might be possible to improve it even further in the future to do everything automatically. Remind me, at what point does the allocator run - during ggml_build_forward()?

@slaren (Collaborator, Author) commented Jul 22, 2023

Currently, it is called manually, but the idea is that it should be done automatically after building the graph. Essentially, when you create a context, you can specify the tensor allocation mode. For the graph allocator (GGML_ALLOC_COMPUTE_SEQ), allocation should happen after building the graph in ggml_build_forward. The default is GGML_ALLOC_IMMEDIATE (same behavior as now).

llama.cpp/llama.cpp

Lines 1192 to 1197 in d273bfd

struct ggml_init_params params = ggml_init_params_default();
params.buffer = buf_compute;
params.alloc_mode = GGML_ALLOC_COMPUTE_SEQ;
//params.alloc_mode = GGML_ALLOC_IMMEDIATE;
params.compute_type = compute_type;
ggml_context * ctx_buf = ggml_init(params);

@slaren (Collaborator, Author) commented Aug 25, 2023

Closing this as it is too outdated to continue from this point.
