minor fixes
slaren committed Jul 15, 2023
1 parent 09ab5c1 commit 83595ec
Showing 2 changed files with 30 additions and 22 deletions.
8 changes: 4 additions & 4 deletions ggml-backend.c
@@ -14,7 +14,7 @@ struct ggml_buffer ggml_backend_alloc_buffer(struct ggml_backend * backend, size
buffer.mem_size = ggml_tensor_overhead() * max_tensors;
buffer.mem_buffer = malloc(buffer.mem_size);
buffer.backend = backend;
// size += 128 * max_tensors; // alignment overhead
size += 128 * max_tensors; // alignment overhead
buffer.backend_buffer = backend->interface->alloc_buffer(backend->context, size);
return buffer;
}
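
The re-enabled line reserves room for per-tensor alignment padding in the backend buffer. A minimal sketch of the worst case it accounts for, assuming 128-byte alignment; align_up and the tensor sizes below are illustrative, not part of the ggml API:

    #include <cstddef>
    #include <cstdio>

    static size_t align_up(size_t offset, size_t alignment) {
        return (offset + alignment - 1) & ~(alignment - 1);
    }

    int main() {
        const size_t alignment = 128;
        const size_t tensor_sizes[] = {100, 260, 512, 33}; // hypothetical tensor sizes in bytes
        size_t offset = 0;
        for (size_t sz : tensor_sizes) {
            offset = align_up(offset, alignment); // each tensor may need up to alignment-1 padding bytes
            offset += sz;
        }
        // worst case overhead is (alignment - 1) bytes per tensor, which "size += 128 * max_tensors" covers
        printf("bytes needed with padding: %zu\n", offset);
        return 0;
    }
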
@@ -172,7 +172,7 @@ static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_context_t ctx, struct
}

static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_context_t ctx, struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_set_tensor(dst, src->data, 0, ggml_nbytes(src));
ggml_backend_set_tensor_async(dst, src->data, 0, ggml_nbytes(src));

UNUSED(ctx);
}
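
With ggml_backend_set_tensor_async, the copy is only guaranteed to be queued when the call returns, so the destination may not be valid until the backend is synchronized. A minimal sketch of that queue-then-flush contract, using a hypothetical Backend type rather than the real ggml structures:

    #include <functional>
    #include <utility>
    #include <vector>

    struct Backend {
        std::vector<std::function<void()>> pending; // operations queued but not yet executed

        void set_tensor_async(std::function<void()> copy) {
            pending.push_back(std::move(copy)); // returns immediately; the copy runs later
        }
        void synchronize() {
            for (auto & op : pending) op(); // flush every queued operation
            pending.clear();
        }
    };

    int main() {
        Backend backend;
        int src = 42, dst = 0;
        backend.set_tensor_async([&] { dst = src; }); // dst is still 0 at this point
        backend.synchronize();                        // now dst == 42
        return 0;
    }
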
@@ -409,7 +409,7 @@ void ggml_graph_splits_compute(struct ggml_graph_splits * splits) {
ggml_backend_cpy_tensor(split->dst_inputs[j], split->src_inputs[j]);
}
}
ggml_backend_synchronize(split->dst_inputs[0]->backend);
// ggml_backend_synchronize(split->dst_inputs[0]->backend);
copy_us += ggml_time_us() - copy_start_us;

#if 0
@@ -419,7 +419,7 @@ void ggml_graph_splits_compute(struct ggml_graph_splits * splits) {
#endif
uint64_t start = ggml_time_us();
ggml_backend_graph_compute(split->dst_inputs[0]->backend, split->graph);
ggml_backend_synchronize(split->dst_inputs[0]->backend);
//ggml_backend_synchronize(split->dst_inputs[0]->backend);
uint64_t end = ggml_time_us();
if (strcmp(ggml_backend_name(split->dst_inputs[0]->backend), "CPU") == 0) {
compute_cpu_us += end - start;
44 changes: 26 additions & 18 deletions llama.cpp
@@ -621,8 +621,9 @@ struct llama_model_loader {
}
LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already

bool is_cpu = lt.ggml_tensor->backend == &model->backend_cpu; // TODO
bool is_cpu = lt.ggml_tensor->backend == &model->backend_cpu;

// select buffer to load data into
if (!use_mmap) {
if (is_cpu) {
lt.data = (uint8_t *) lt.ggml_tensor->data;
@@ -638,7 +639,7 @@ struct llama_model_loader {
if (is_cpu) {
if (use_mmap) {
lt.ggml_tensor->data = lt.data;
// TODO: this assumes that the data is contiguous, which may not always be the case
// TODO: this assumes that the data to lock is contiguous, which may not always be the case
if (lmlock) {
lock_size += lt.size;
lmlock->grow_to(lock_size);
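
The reworded TODO refers to this mlock path: the loader keeps a single locked region and grows it as tensors are mapped in, which is only correct if the locked data is one contiguous range. A rough sketch of such a growable lock, assuming a POSIX system; LockedRegion is illustrative and not the llama_mlock class from llama.cpp:

    #include <sys/mman.h>
    #include <cstdio>
    #include <cstdlib>

    // growable locked region, one mlock call per growth step; page alignment is glossed over
    struct LockedRegion {
        void * addr;
        size_t size = 0; // bytes locked so far

        bool grow_to(size_t target) {
            if (target <= size) return true;
            // only correct if [addr, addr + target) is one contiguous range of model data
            if (mlock((char *) addr + size, target - size) != 0) return false;
            size = target;
            return true;
        }
    };

    int main() {
        void * data = std::malloc(1 << 20);            // stand-in for the mmapped model file
        LockedRegion region{data};
        if (!region.grow_to(4096))  perror("mlock");   // lock the first tensor
        if (!region.grow_to(65536)) perror("mlock");   // grow as more tensors are mapped
        printf("locked %zu bytes\n", region.size);
        munlock(data, region.size);
        std::free(data);
        return 0;
    }
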
@@ -1199,6 +1200,10 @@ static ggml_graph_splits llama_build_graph(
inpL = ggml_get_rows(ctx_i, model.tok_embeddings, token_in);
}

// reuse the scale tensor for all layers since it requires a memory transfer
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head));
ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");

struct ggml_tensor * cur = nullptr;
for (int il = 0; il < n_layer; ++il) {
struct ggml_context * ctx_l = ctx_ls[il];
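
Creating KQ_scale once in ctx_kv and reusing it across layers means a non-CPU backend only has to transfer the constant a single time instead of once per layer. A minimal sketch of the hoisting pattern; upload_constant is a made-up stand-in for allocating a constant on the compute backend:

    #include <cmath>
    #include <cstdio>

    // stand-in for creating a constant tensor on the compute backend; counts the transfers
    static int n_uploads = 0;
    static float upload_constant(float value) { ++n_uploads; return value; }

    int main() {
        const int n_layer = 32, n_embd = 4096, n_head = 32;

        // hoisted out of the loop: one transfer, shared by every layer
        const float kq_scale = upload_constant(1.0f / std::sqrt(float(n_embd) / n_head));

        float acc = 0.0f;
        for (int il = 0; il < n_layer; ++il) {
            acc += kq_scale; // each layer scales its KQ product with the same constant
        }
        printf("transfers: %d (instead of %d), acc = %f\n", n_uploads, n_layer, acc);
        return 0;
    }
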
@@ -1239,9 +1244,6 @@ static ggml_graph_splits llama_build_graph(
struct ggml_tensor * Vcur = ggml_transpose(ctx_l, ggml_reshape_2d(ctx_l, tmpv, n_embd, N));
ggml_set_name(Vcur, "Vcur");

//ggml_graph_splits_add(&splits, &Kcur, ctx_kv, "Kcur");
//ggml_graph_splits_add(&splits, &Vcur, ctx_kv, "Vcur");
//ggml_graph_splits_add(&splits, &Qcur, ctx_kv, "Qcur");
ggml_tensor ** attn_inputs[] = {&Kcur, &Vcur, &Qcur, NULL};
ggml_graph_splits_add_n(&splits, attn_inputs, ctx_kv, "l%d_attn", il);

@@ -1288,9 +1290,6 @@ static ggml_graph_splits llama_build_graph(
ggml_set_name(KQ, "KQ");

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head));
ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");

// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx_kv, KQ, KQ_scale);
ggml_set_name(KQ_scaled, "KQ_scaled");
@@ -1367,7 +1366,7 @@ static ggml_graph_splits llama_build_graph(
cur = ggml_mul_mat(ctx_l,
model.layers[il].w1,
cur);
ggml_set_name(cur, "result_w2");
ggml_set_name(cur, "result_w1");

// SILU activation
cur = ggml_silu(ctx_l, cur);
@@ -1503,6 +1502,12 @@ static bool llama_eval_internal(

LLAMA_ASSERT(lctx.graph_logits != nullptr);


// for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(&model.backend_cpu), n_threads);

struct ggml_graph_splits splits = llama_build_graph(lctx, N, n_past, embd_input);

// TODO: use backend functions
@@ -1514,11 +1519,7 @@ static bool llama_eval_internal(
ggml_backend_set_tensor(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in));
}

// for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;

ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(&model.backend_cpu), n_threads);

// run the computation
ggml_graph_splits_compute(&splits);
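
The thread-count adjustment now runs before the graph is built and is applied to the CPU backend through ggml_backend_cpu_set_n_threads. The heuristic itself is the one visible in the diff; a small self-contained sketch, where pick_n_threads is an illustrative helper:

    #include <cstdio>

    // pick the CPU thread count for one eval: large batches handled by BLAS keep a single
    // thread so the other threads do not spin-wait on the BLAS calls
    static int pick_n_threads(int n_tokens, bool has_blas, int n_threads_requested) {
        return (n_tokens >= 32 && has_blas) ? 1 : n_threads_requested;
    }

    int main() {
        printf("%d\n", pick_n_threads(512, true, 8)); // 1: big prompt on the BLAS path
        printf("%d\n", pick_n_threads(1, true, 8));   // 8: single-token eval
        return 0;
    }
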
@@ -1545,21 +1546,28 @@ static bool llama_eval_internal(

if (lctx.logits_all) {
logits_out.resize(n_vocab * N);
ggml_backend_get_tensor(lctx.graph_logits, logits_out.data(), 0, N*n_vocab*sizeof(float));
ggml_backend_get_tensor_async(lctx.graph_logits, logits_out.data(), 0, N*n_vocab*sizeof(float));
} else {
// return result for just the last token
logits_out.resize(n_vocab);
ggml_backend_get_tensor(lctx.graph_logits, logits_out.data(), 0, n_vocab*sizeof(float));
ggml_backend_get_tensor_async(lctx.graph_logits, logits_out.data(), 0, n_vocab*sizeof(float));
}
}

// extract embeddings
if (!lctx.embedding.empty()) {
auto & embedding_out = lctx.embedding;
embedding_out.resize(n_embd);
ggml_backend_get_tensor(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
ggml_backend_get_tensor_async(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
}

#ifdef GGML_USE_CUDA
// wait for the async copy to finish
if (lctx.model.n_gpu_layers > 0) {
ggml_backend_synchronize(const_cast<ggml_backend*>(&lctx.model.backend_cuda));
}
#endif
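
Because the logits and embeddings are now fetched with ggml_backend_get_tensor_async, the copies may still be in flight when the vectors are about to be used; the #ifdef block above makes the host wait on the CUDA backend first. A minimal sketch of that issue-then-wait ordering, using std::async as a stand-in for an asynchronous backend copy:

    #include <cstring>
    #include <future>
    #include <vector>

    int main() {
        std::vector<float> backend_logits(4, 1.5f); // stand-in for the tensor held by the backend
        std::vector<float> logits_out(4);

        // the "async get": the copy runs in the background, like ggml_backend_get_tensor_async
        auto copy = std::async(std::launch::async, [&] {
            std::memcpy(logits_out.data(), backend_logits.data(), logits_out.size() * sizeof(float));
        });

        // logits_out must not be read yet; this wait plays the role of ggml_backend_synchronize
        copy.wait();
        // now logits_out holds the copied values
        return 0;
    }
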

// measure the performance only for the single-token evals
if (N == 1) {
lctx.t_eval_us += ggml_time_us() - t_start_us;
@@ -2543,7 +2551,7 @@ struct llama_context * llama_new_context_with_model(
// initialize the graph input/output buffers
// input buffer
{
size_t buf_input_size = 1024;
size_t buf_input_size = 0;
buf_input_size += hparams.n_ctx * ggml_type_size(GGML_TYPE_F32); // input tokens
// TODO: input embeddings should be optional to save memory
buf_input_size += hparams.n_embd * hparams.n_ctx * ggml_type_size(GGML_TYPE_F32); // input embeddings
@@ -2562,7 +2570,7 @@ struct llama_context * llama_new_context_with_model(
}
// output buffer
{
size_t buf_output_size = 1024;
size_t buf_output_size = 0;
if (params.logits_all) {
buf_output_size += hparams.n_ctx * hparams.n_vocab * ggml_type_size(GGML_TYPE_F32);
} else {
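Dropping the extra 1024 bytes sizes the input and output buffers to exactly what the graph tensors need. A worked example of the arithmetic under hypothetical hyperparameters (n_ctx = 512, n_embd = 4096, n_vocab = 32000):

    #include <cstddef>
    #include <cstdio>

    int main() {
        // hypothetical hyperparameters, only for the arithmetic
        const size_t n_ctx = 512, n_embd = 4096, n_vocab = 32000;
        const size_t f32 = sizeof(float);

        size_t buf_input = n_ctx * f32            // input tokens
                         + n_embd * n_ctx * f32;  // input embeddings

        size_t buf_output_all  = n_ctx * n_vocab * f32; // logits for every position (logits_all)
        size_t buf_output_last = n_vocab * f32;         // logits for just the last token

        printf("input: %zu bytes\n", buf_input);
        printf("output: %zu bytes (logits_all) vs %zu bytes\n", buf_output_all, buf_output_last);
        return 0;
    }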
