ggml backends interface, ggml-cuda refactor #2230

Closed

Conversation

@slaren (Collaborator) commented Jul 15, 2023

Continued in #2239

This PR adds a common interface to the compute backends.

Breaking changes

  • ggml_context allocates memory from a ggml_buffer that contains a buffer in device memory for the tensor data, and a buffer in system memory for the tensor structs (and in the future also other data such as the graphs)
  • The data member of ggml_tensor is a backend-specific pointer that should not be accessed directly. To access the data, ggml_backend_set_tensor and ggml_backend_get_tensor must be used instead. Functions such as ggml_new_f32 and ggml_set_f32 can also be used as before.
    • Technically, if you are using the CPU backend, you can still access the data member directly, but you shouldn't do that if you want to support other backends
    • I will probably change the name to something else to prevent current code from compiling without changes
  • Added a small params buffer to ggml_tensor for the op parameters that are currently stored in a separate tensor. For example, for ggml_rope this buffer is used to store the values n_past, n_dims, mode, n_ctx. The goal is to make these parameters easily accessible from the CPU, and to reduce the overhead of creating a new tensor for them (see the sketch below).
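
A minimal sketch of what the last point means in practice, assuming a small fixed-size params field on the tensor (the field name, size and helper below are illustrative, not the exact layout in this PR):

// sketch only: ggml_tensor gains a small parameter buffer that the CPU can
// read directly, instead of a separate 1-D tensor holding the op parameters
struct ggml_tensor_sketch {
    // ... existing ggml_tensor fields ...
    int32_t params[8]; // op parameters (e.g. rope: n_past, n_dims, mode, n_ctx)
};

// hypothetical helper: store the rope parameters in the op tensor itself
static void ggml_rope_set_params_sketch(struct ggml_tensor_sketch * t,
                                        int n_past, int n_dims, int mode, int n_ctx) {
    t->params[0] = n_past;
    t->params[1] = n_dims;
    t->params[2] = mode;
    t->params[3] = n_ctx;
}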

Brief example:

// initialize a backend
struct ggml_backend backend = ggml_backend_cpu_init();
// for CUDA:
// struct ggml_backend backend = ggml_backend_cuda_init();

// create a buffer
size_t ctx_size = 4096; // buffer size for the tensor data
size_t num_tensors = 10; // maximum number of tensors that can be allocated
struct ggml_buffer buf = ggml_backend_alloc_buffer(&backend, ctx_size, num_tensors);

// create a context using the buffer
struct ggml_init_params ggml_params = ggml_init_params_default();
ggml_params.buffer = &buf;
ggml_context * ctx = ggml_init(ggml_params);

// use the context to create a computation graph
struct ggml_tensor * x = ggml_new_f32(ctx, 2.0f);
struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
struct ggml_cgraph gf = ggml_build_forward(f);

// set the value of the input tensors
float a_val = 3.0f;
ggml_backend_set_tensor(a, &a_val, 0, sizeof(float));
// alternatively:
float b_val = 5.0f;
ggml_set_f32(b, b_val);

// run the computation
ggml_backend_graph_compute(&backend, &gf);

// get the result
float result;
ggml_backend_get_tensor(f, &result, 0, sizeof(float));
// alternatively:
// result = ggml_get_f32_1d(f, 0);

Backend implementation

Backends should implement the functions defined in the ggml_backend_interface struct. Currently there are implementations for the CPU and CUDA backends.
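
To give an idea of the shape of the interface, implementing a backend roughly amounts to filling in a table of function pointers along these lines (a simplified sketch; the authoritative definition is the ggml_backend_interface struct in ggml-backend.h, and the names below are only modeled on the public ggml_backend_* functions shown above):

// simplified sketch of a backend vtable; see ggml-backend.h for the real one
struct ggml_backend_interface_sketch {
    const char * (*get_name)(struct ggml_backend * backend);
    void         (*free)    (struct ggml_backend * backend);

    // buffer management: device memory for the tensor data
    struct ggml_buffer (*alloc_buffer)(struct ggml_backend * backend, size_t size, size_t max_tensors);
    void               (*free_buffer) (struct ggml_backend * backend, struct ggml_buffer * buffer);

    // tensor data access (host <-> device copies)
    void (*set_tensor)(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
    void (*get_tensor)(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

    // graph execution
    void (*graph_compute)(struct ggml_backend * backend, struct ggml_cgraph * cgraph);
};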

Computation using multiple backends

It is still possible to offload some parts of the graph to the GPU while keeping others on the CPU. This is done using ggml_graph_splits. See the llama.cpp code for an example; I will update this later with more details.
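
As a very rough sketch of the idea (the function names here are hypothetical; the actual split API is the one used in this PR's llama.cpp changes):

// hypothetical sketch: each split is built against one backend, and the splits
// helper copies tensor data across the backend boundary before computing
struct ggml_graph_splits splits = ggml_graph_splits_init();          // hypothetical
ggml_graph_splits_add(&splits, &graph_gpu_layers, &backend_cuda);    // hypothetical: layers kept on the GPU
ggml_graph_splits_add(&splits, &graph_cpu_layers, &backend_cpu);     // hypothetical: remaining layers on the CPU
ggml_graph_splits_compute(&splits);                                  // hypothetical: runs each split in sequence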

Notes/limitations

  • Only the CPU and CUDA backends are currently supported.
  • Only the bare minimum necessary to run llama is currently implemented, so don't look too closely at the details of the code for now
  • When partially offloading a model to the GPU, internally this is handled by splitting the graphs into multiple parts, and running each of them in sequence in a different backend. Because the KV memory is either on the CPU or on the GPU, this means that there are at least as many graph executions as there are layers. The CPU backend creates and destroys the threads with every graph launch, so this can have a non-negligible overhead. Eventually this will be improved by using a thread pool.
  • On the plus side, the CPU threads are no longer spinning while the GPU backend is running
  • In the long term, the goal is to support automatic fallback to the CPU when an operation is not implemented in a backend, but that's going to take a while. This means that backends like OpenCL that mostly only implement matrix multiplication cannot be used for now. I am expecting that this backend will be replaced by a Vulkan backend that implements full GPU offloading, and won't be necessary anymore.
  • For prompt processing, it is usually worth uploading the weights every time even when using partial offloading, but that is currently not supported, so prompt processing will use the same CPU/GPU layer split as generation. Eventually I would like to support this again.

@JohannesGaessler (Collaborator) left a comment:

I'll try to review this PR regularly but feel free to @ me if you want me to look at something in particular.

LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler
ifdef LLAMA_DEBUG
NVCCFLAGS += -lineinfo
Collaborator:

Depending on what the intended use of -lineinfo is, this could be added unconditionally (I don't think it makes a difference for performance). When I use Nsight Compute I typically use it without LLAMA_DEBUG, since that also affects compiler optimizations.

}

tensor->data = (char*)cpu_buffer->data + cpu_buffer->offset;
cpu_buffer->offset = aligned_offset(cpu_buffer->data, cpu_buffer->offset + ggml_nbytes(tensor), TENSOR_ALIGNMENT);
Collaborator:

I don't know the GCC equivalent but CUDA has a compiler hint that lets you specify memory alignment. Perhaps something like this could be useful?

Collaborator Author:

That looks interesting, but I imagine that you would have to use it in the kernels for the compiler to actually notice anything. It may be worth checking whether adding that to all the data pointers in the kernels improves performance. The tensor allocator in the ggml-cuda backend always aligns pointers to 128 bytes, so it is safe to assume that tensor data pointers are aligned to at least that.
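
For reference, the GCC/Clang equivalent is __builtin_assume_aligned, which nvcc also accepts in device code. A sketch of what annotating kernel pointers with the 128-byte guarantee mentioned above could look like (the kernel is a made-up element-wise example, not one from ggml-cuda):

// hypothetical kernel, only to illustrate the alignment hint
__global__ void scale_f32_hint(const float * x, float * dst, const float v, const int k) {
    // promise the compiler that the tensor data pointers are 128-byte aligned,
    // which the ggml-cuda tensor allocator guarantees
    const float * xa = (const float *) __builtin_assume_aligned(x,   128);
    float       * da = (float *)       __builtin_assume_aligned(dst, 128);

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < k) {
        da[i] = xa[i] * v;
    }
}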

Collaborator Author (@slaren, Jul 15, 2023):

About using this in the CPU code: there are AVX instructions for both aligned and unaligned loads and stores. We always use the unaligned instructions because, in practice, on current CPUs it doesn't seem to make much of a difference in performance, and depending on the types and row sizes the data may not always be aligned.

Collaborator:

You could maybe add __forceinline__ to the helper functions but I would assume the compiler does it anyways.

Collaborator Author:

I am not convinced that it makes a difference; I trust that the compiler is smart enough to inline device functions. So unless there is a measurable difference, my preference is to leave these decisions to the compiler.

// reduce warps
T warp_reduction = warp_reduce_all<op_t>(val);

__syncthreads();
Collaborator:

I don't think synchronization is needed here. Each warp writes to a different location, so there should be no data race.

Collaborator Author:

There is a weird data race that can happen here; I think it is related to the read of lane_result later. I am not sure if any of the kernels here are affected by this, but it was an issue in a different kernel that is not here yet.
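
For context, the usual two-stage block reduction looks roughly like this (a generic sketch, not the exact kernel in this PR). The per-warp writes indeed don't race with each other, but the barrier is still needed before another warp reads the per-warp results:

// generic block-wide sum reduction sketch (assumes blockDim.x is a multiple of 32)
__device__ float block_reduce_sum_sketch(float val) {
    __shared__ float lane_result[32];   // one partial result per warp

    const int lane_id = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    // stage 1: reduce within each warp
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, offset);
    }
    if (lane_id == 0) {
        lane_result[warp_id] = val; // each warp writes its own slot, no race here
    }

    // required: warp 0 is about to read slots written by the other warps
    __syncthreads();

    // stage 2: the first warp reduces the per-warp partials
    val = (threadIdx.x < blockDim.x / 32) ? lane_result[lane_id] : 0.0f;
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, offset);
    }
    return val; // the block-wide sum is valid in warp 0
}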

const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
Collaborator:

Remember to re-add the 3b fix added in #2144.

ggml_type_name(t0), ggml_type_name(t1), ggml_type_name(t2));
}

GGML_ASSERT(dispatch.d[t0][t1][t2] && "Unsupported type combination");
Collaborator:

What does this check do compared to the previous one?

Collaborator Author:

The previous check prints the types so that I can tell what the problem is, and it will eventually be removed. The GGML_ASSERT is there to crash the program when that happens and will remain. I will probably change this part significantly to account for different numbers of arguments to ops.
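
For readers following along, the construct being discussed is essentially a table of kernel-launching functions indexed by the operand types, along the lines of this sketch (the field and typedef names are illustrative, not the PR's exact code):

// illustrative sketch of a type-indexed dispatch table
typedef void (*ggml_cuda_op_fn_sketch)(const struct ggml_tensor * src0,
                                       const struct ggml_tensor * src1,
                                       struct ggml_tensor * dst);

struct ggml_cuda_dispatch_sketch {
    // one entry per (src0 type, src1 type, dst type) combination, NULL if unsupported
    ggml_cuda_op_fn_sketch d[GGML_TYPE_COUNT][GGML_TYPE_COUNT][GGML_TYPE_COUNT];
};

// usage: look up the function for the actual types and fail loudly if it is missing
// ggml_cuda_op_fn_sketch fn = dispatch.d[t0][t1][t2];
// GGML_ASSERT(fn && "Unsupported type combination");
// fn(src0, src1, dst);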

Comment on lines -2624 to -2625
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
Collaborator:

Do you intend to actually allow the use of split tensors for src1 and dst?

Collaborator Author:

Yes, all of that code has been removed for now for simplicity, but I will add it again. I think I will do this with a special type of ggml_buffer that allocates split tensors.

Collaborator:

What is the exact design for the VRAM scratch buffer? Is it reset manually, as is the case for the RAM scratch buffers?

Collaborator Author:

This will be addressed as proposed here: ggml-org/ggml#288
After that, scratch buffers will not be necessary, since the compute buffers will already be as small as they can be.

@ggerganov ggerganov added high priority Very important issue refactoring Refactoring labels Jul 15, 2023
@slaren slaren added the breaking change Changes that break ABIs, APIs, file formats, or other forms of backwards compatibility. label Jul 15, 2023
@slaren (Collaborator Author) commented Jul 15, 2023

Updated the description with more details.
@ggerganov @JohannesGaessler @0cc4m @niansa @evanmiller and anybody else who may be interested in working on a backend, please take a look at the backend interface in ggml-backend.h and let me know if that would work for your case. There are implementations for the CPU backend in ggml-backend.c, and for the CUDA backend in ggml-cuda.cu.

@JohannesGaessler (Collaborator):

The CPU backend creates and destroys the threads with every graph launch, so this can have a non-negligible overhead. Eventually this will be improved by using a thread pool.

Do you mean to say that you will implement it? I'm asking because I was thinking of doing this myself and would like to avoid duplicate work.

@slaren (Collaborator Author) commented Jul 15, 2023

@JohannesGaessler if you want to do it yourself, that would be great! I have plenty to do already.
Needless to say, any help is welcome.

@0cc4m (Collaborator) commented Jul 15, 2023

Please correct me if I'm wrong, but it looks like ggml-backend.h is only for backends that can run an entire graph (or an entire part of one). While this may be the end goal for most backends, they are built up through many intermediate steps that don't fit into this framework. OpenCL cannot offload a full graph, and neither can my Vulkan approach (for now).

Do you have a plan for how to accommodate backends in development and "partial backends", so to speak?

@slaren (Collaborator Author) commented Jul 15, 2023

@0cc4m I covered it briefly in the description:

In the long term, the goal is to support automatic fallback to the CPU when an operation is not implemented in a backend, but that's going to take a while. This means that backends like OpenCL that mostly only implement matrix multiplication cannot be used for now. I am expecting that this backend will be replaced by a Vulkan backend that implements full GPU offloading, and won't be necessary anymore.

For now, this is not supported, but I would like to support it eventually.

@0cc4m (Collaborator) commented Jul 15, 2023

It may be replaced eventually by a Vulkan backend, or it may stay alongside it. But even the Vulkan backend will stay partial for quite a while. Is your plan to implement this fallback before this PR is ready to merge?

@slaren (Collaborator Author) commented Jul 15, 2023

I will think about it. My plan is to do it later, since the scope of this PR is already quite large as it is, but maybe I can support that case without too many changes.

@0cc4m (Collaborator) commented Jul 15, 2023

Removing other backends in the name of a CUDA refactor is not acceptable. I support the goal of this PR, but it has to be less disruptive.

@niansa (Contributor) commented Jul 15, 2023

Mmmh the idea really makes a lot of sense and I like it, but I feel like different implementations may have different needs that aren't covered by this and are still going to fall back to #ifdefing in llama.cpp or ggml.c...

@slaren (Collaborator Author) commented Jul 15, 2023

@niansa I would like to avoid that, which is why I am asking for your feedback. If there is something that isn't covered by this API, please let me know.

@JohannesGaessler (Collaborator):

I think as of right now the OpenCL implementation still has all data except for the weights in RAM. In other words, the results get copied to RAM after each operation. Would it be possible to just forward the non GPU accelerated tensors to the CPU backend for the OpenCL backend?

@slaren (Collaborator Author) commented Jul 15, 2023

Fallback to CPU would work exactly like that; the hard part is doing it in a way that doesn't litter the user code with backend-specific details, while also keeping the implementation of the backends as simple as possible.
Really, this could already be done with this implementation if you just create a new graph split for each matrix multiplication. What the splits interface does is just to copy the data between backends as required. But obviously that would make the graph building code unreadable, so it needs to be done in a more automated way. It can be done, but the scope of the PR needs to be kept reasonable, or it will never be completed.

@henk717 commented Jul 15, 2023

OpenCL is currently relied on by AMD and Intel users for speed increases using their dedicated GPUs. Removing it would block off entire vendors on Windows and should not be considered acceptable. Users with slower CPUs in particular would be stuck on older versions of the project.

We have a lot of AMD users currently relying on Llamacpp to use their GPU for AI as an easy solution.

@JohannesGaessler (Collaborator):

What would prevent us from keeping the status quo for OpenCL? From what I can tell, the entry points for the OpenCL kernels are still in the corresponding functions in ggml.c. So couldn't you run a CPU graph with some OpenCL tensors in that graph for the offloaded weights?

@slaren (Collaborator Author) commented Jul 15, 2023

It could work if we keep two versions of the model loading code in llama.cpp, one for the new interface and another for OpenCL. It's something I would prefer to avoid; ultimately the goal is to simplify the code and remove as many of the special cases as possible, but as a temporary solution it may be better than removing the OpenCL backend.

@slaren (Collaborator Author) commented Jul 15, 2023

It would be interesting to see performance on native Linux, and also on native Windows. For me, under WSL with an RTX 3080, this is significantly faster than master. I think this is mainly because it requires less synchronization, so kernel launches can be queued and this helps hide the latency of launching a kernel. I expect that the effect will be minimal, if any, under native Linux, though.

7B q4_0:

PR:
-ngl 99: llama_print_timings:        eval time =  1163.00 ms /   127 runs   (    9.16 ms per token,   109.20 tokens per second)
-ngl 30: llama_print_timings:        eval time =  2083.09 ms /   127 runs   (   16.40 ms per token,    60.97 tokens per second)
-ngl 20: llama_print_timings:        eval time =  6841.63 ms /   127 runs   (   53.87 ms per token,    18.56 tokens per second)
-ngl 10: llama_print_timings:        eval time = 13264.34 ms /   127 runs   (  104.44 ms per token,     9.57 tokens per second)


master:
-ngl 99: llama_print_timings:        eval time =  3140.24 ms /   127 runs   (   24.73 ms per token,    40.44 tokens per second)
-ngl 30: llama_print_timings:        eval time =  4398.86 ms /   127 runs   (   34.64 ms per token,    28.87 tokens per second)
-ngl 20: llama_print_timings:        eval time =  8466.26 ms /   127 runs   (   66.66 ms per token,    15.00 tokens per second)
-ngl 10: llama_print_timings:        eval time = 13212.88 ms /   127 runs   (  104.04 ms per token,     9.61 tokens per second)

@YellowRoseCx (Contributor):

It could work if we keep two versions of the model loading code in llama.cpp, one for the new interface and another for OpenCL. It's something I would prefer to avoid; ultimately the goal is to simplify the code and remove as many of the special cases as possible, but as a temporary solution it may be better than removing the OpenCL backend.

Maintaining two versions of the model loading code in llama.cpp may not be ideal, but it would prevent disruptions and provide a smoother transition instead of leaving OpenCL users stuck with older versions of the project.

@JohannesGaessler (Collaborator):

I think this is mainly because it requires less synchronization, so kernel launches can be queued and this helps hide the latency of launching a kernel.

With all layers offloaded there should only be 2 calls to cudaDeviceSynchronize though: one at the end and another one when the results of embeddings get copied to RAM. If anything I think the difference comes from concurrent kernel execution.

@slaren (Collaborator Author) commented Jul 15, 2023

The concurrent kernel execution is not enabled currently:

static const int GGML_CUDA_MAX_SUBSTREAMS = 1;
static const bool GGML_CUDA_SEQ_COMPUTE = true;

This is because currently it creates way too many streams and events, and it may actually harm performance. But it will help once that is fixed.
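
For readers unfamiliar with the substreams idea, the general pattern being discussed is the standard CUDA one of putting independent ops on separate streams and expressing dependencies with events, so that launches can overlap instead of running strictly in sequence. A self-contained sketch (not this PR's code):

#include <cuda_runtime.h>

__global__ void add_one(float * x, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}

// two independent branches run on two streams; a later op that depends on
// branch A waits on an event instead of a full device synchronization
void compute_two_branches(float * d_a, float * d_b, int n) {
    cudaStream_t s0, s1;
    cudaEvent_t a_done;
    cudaStreamCreate(&s0);
    cudaStreamCreate(&s1);
    cudaEventCreateWithFlags(&a_done, cudaEventDisableTiming);

    add_one<<<(n + 255) / 256, 256, 0, s0>>>(d_a, n); // branch A on stream 0
    cudaEventRecord(a_done, s0);

    add_one<<<(n + 255) / 256, 256, 0, s1>>>(d_b, n); // branch B overlaps with A
    cudaStreamWaitEvent(s1, a_done, 0);               // next op on s1 depends on A
    add_one<<<(n + 255) / 256, 256, 0, s1>>>(d_a, n); // uses branch A's result

    cudaStreamSynchronize(s0);
    cudaStreamSynchronize(s1);
    cudaEventDestroy(a_done);
    cudaStreamDestroy(s0);
    cudaStreamDestroy(s1);
}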

@JohannesGaessler (Collaborator) commented Jul 15, 2023

Regardless of the reasons, simple testing seems to confirm that the refactor provides a speedup:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 | 7b q4_0 | tg128 | 121.64 | 132.45 | 1.09 |
| RTX 3090 | 13b q4_0 | tg128 | 69.59 | 74.07 | 1.06 |

Edit: my numbers are for native Linux 6.3.5-2.

@JohannesGaessler (Collaborator) commented Jul 15, 2023

I forgot to mention: I can't compile the PR unless I modify it. The problem seems to be line 17 in ggml-cuda-kern.h, more precisely the intrinsic __halves2half2. If I replace it with make_half2 I can compile the code.

Compilation log
/home/johannesg/Projects/llama.cpp [git::cuda-backend *] [johannesg@johannes-pc] [0:08]
> make clean && make LLAMA_CUDA=1 LLAMA_CUDA_DMMV_X=64 LLAMA_CUDA_MMV_Y=2 LLAMA_CUDA_DMMV_F16=1 libllama.so quantize main perplexity  
I llama.cpp build info: 
I UNAME_S:  Linux
I UNAME_P:  unknown
I UNAME_M:  x86_64
I CFLAGS:   -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS
I CXXFLAGS: -I. -I./examples -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS
I LDFLAGS:  
I CC:       cc (GCC) 13.1.1 20230429
I CXX:      g++ (GCC) 13.1.1 20230429

rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server simple vdot train-text-from-scratch embd-input-test build-info.h
removed 'common.o'
removed 'ggml-cuda.o'
removed 'ggml.o'
removed 'k_quants.o'
removed 'llama.o'
removed 'libllama.so'
removed 'main'
removed 'quantize'
removed 'perplexity'
removed 'build-info.h'
I llama.cpp build info: 
I UNAME_S:  Linux
I UNAME_P:  unknown
I UNAME_M:  x86_64
I CFLAGS:   -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include
I CXXFLAGS: -I. -I./examples -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include
I LDFLAGS:  -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L/opt/cuda/targets/x86_64-linux/lib
I CC:       cc (GCC) 13.1.1 20230429
I CXX:      g++ (GCC) 13.1.1 20230429
I NVCC:     

g++ -I. -I./examples -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include -c llama.cpp -o llama.o
cc  -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include   -c ggml.c -o ggml.o
cc -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include   -c -o k_quants.o k_quants.c
nvcc --forward-unknown-to-host-compiler -arch=native -DGGML_CUDA_DMMV_X=64 -DGGML_CUDA_MMV_Y=2 -DGGML_CUDA_DMMV_F16 -DK_QUANTS_PER_ITERATION=2 -I. -I./examples -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include -Wno-pedantic -c ggml-cuda.cu -o ggml-cuda.o
cc  -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include   -c ggml-backend.c -o ggml-backend.o
g++ -I. -I./examples -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include -c examples/common.cpp -o common.o
llama.cpp: In function ‘size_t llama_get_state_size(const llama_context*)’:
llama.cpp:2985:1: warning: no return statement in function returning non-void [-Wreturn-type]
 2985 | }
      | ^
llama.cpp:2955:58: warning: unused parameter ‘ctx’ [-Wunused-parameter]
 2955 | size_t llama_get_state_size(const struct llama_context * ctx) {
      |                             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~
llama.cpp: In function ‘size_t llama_copy_state_data(llama_context*, uint8_t*)’:
llama.cpp:3086:1: warning: no return statement in function returning non-void [-Wreturn-type]
 3086 | }
      | ^
llama.cpp:2988:53: warning: unused parameter ‘ctx’ [-Wunused-parameter]
 2988 | size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
      |                              ~~~~~~~~~~~~~~~~~~~~~~~^~~
llama.cpp:2988:68: warning: unused parameter ‘dst’ [-Wunused-parameter]
 2988 | size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
      |                                                          ~~~~~~~~~~^~~
llama.cpp: In function ‘size_t llama_set_state_data(llama_context*, uint8_t*)’:
llama.cpp:3195:1: warning: no return statement in function returning non-void [-Wreturn-type]
 3195 | }
      | ^
llama.cpp:3089:52: warning: unused parameter ‘ctx’ [-Wunused-parameter]
 3089 | size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
      |                             ~~~~~~~~~~~~~~~~~~~~~~~^~~
llama.cpp:3089:67: warning: unused parameter ‘src’ [-Wunused-parameter]
 3089 | size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
      |                                                         ~~~~~~~~~~^~~
llama.cpp: In function ‘int llama_eval_export(llama_context*, const char*)’:
llama.cpp:3328:46: warning: unused parameter ‘ctx’ [-Wunused-parameter]
 3328 | int llama_eval_export(struct llama_context * ctx, const char * fname) {
      |                       ~~~~~~~~~~~~~~~~~~~~~~~^~~
llama.cpp:3328:64: warning: unused parameter ‘fname’ [-Wunused-parameter]
 3328 | int llama_eval_export(struct llama_context * ctx, const char * fname) {
      |                                                   ~~~~~~~~~~~~~^~~~~
ggml-cuda-kern.h(17): error: calling a __device__ function("__halves2half2(    ::__half,     ::__half)") from a __host__ __device__ function("make_vec2_t") is not allowed
  template<> inline __attribute__((host)) __attribute__((device)) vec2_t<half> make_vec2_t(const half & x, const half & y) { return __halves2half2(x, y); }
                                                                                                                                    ^

1 error detected in the compilation of "ggml-cuda.cu".
make: *** [Makefile:205: ggml-cuda.o] Error 2
make: *** Waiting for unfinished jobs....
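
Based on the error line above, the workaround is to build the half2 with make_half2 instead of the __halves2half2 intrinsic (which is what fails when called from a __host__ __device__ function on CUDA 12.1):

// before (from the error above): fails on CUDA 12.1
// template<> inline __host__ __device__ vec2_t<half> make_vec2_t(const half & x, const half & y) { return __halves2half2(x, y); }

// after: compiles, per the report above
template<> inline __host__ __device__ vec2_t<half> make_vec2_t(const half & x, const half & y) { return make_half2(x, y); }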

@slaren (Collaborator Author) commented Jul 15, 2023

Weird, what version of the CUDA Toolkit are you using? It works for me with 12.2:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jun_13_19:16:58_PDT_2023
Cuda compilation tools, release 12.2, V12.2.91
Build cuda_12.2.r12.2/compiler.32965470_0

@JohannesGaessler (Collaborator):

I'm on version 12.1:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

@slaren (Collaborator Author) commented Jul 15, 2023

@YellowRoseCx (Contributor):

I know the ROCm port is not merged yet, but I can also see that this would completely break CUDA on AMD GPUs through ROCm, because the hipBLAS API doesn't support CUBLAS_COMPUTE_16F or CUBLAS_COMPUTE_32F_FAST_TF32.

@YellowRoseCx YellowRoseCx mentioned this pull request Jul 15, 2023
@slaren (Collaborator Author) commented Jul 15, 2023

Just define them to HIPBLAS_R_16F and HIPBLAS_R_32F.
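
In other words, something along these lines in the ROCm compatibility layer (a sketch of the suggestion above; where exactly these mappings live is up to the port):

// sketch: map the CUDA compute-type names used by ggml-cuda to hipBLAS datatypes
#define CUBLAS_COMPUTE_16F           HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F_FAST_TF32 HIPBLAS_R_32F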

@ggerganov (Owner) commented Jul 16, 2023

I will look into migrating the Metal and MPI backends to fit the proposed backend interface soon and will open a PR against this one. @slaren It might be a good idea to move the branch into this repo, so that anyone who wants to propose a change can PR it here - otherwise we'll have to keep an eye out for PRs into your fork.

I started implementing a custom cloud CI yesterday, and I hope I am able to finish it today. After that I will start looking into this work in more detail.

Edit: on second thought, maybe keep the PR as it is, because we would lose the discussion if you recreate the PR. Whatever you decide is better

@slaren (Collaborator Author) commented Jul 16, 2023

I think it is more important to keep all the PRs here, so I'll open a new PR to move the branch to this repo. I'll add a link to this PR so that the current discussion isn't lost.

@slaren (Collaborator Author) commented Jul 16, 2023

Continued in #2239

@slaren slaren closed this Jul 16, 2023
@slaren slaren mentioned this pull request Jul 22, 2023
@JohannesGaessler JohannesGaessler mentioned this pull request Aug 11, 2023