diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 97325b5bd634f..1010f04b7bdd1 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1149,7 +1149,8 @@ int main(int argc, char ** argv) {
 
         // warmup run
         if (t.n_prompt > 0) {
-            test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
+            //test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
+            test_prompt(ctx, std::min(t.n_prompt, 32), 0, t.n_batch, t.n_threads);
        }
        if (t.n_gen > 0) {
            test_gen(ctx, 1, 0, t.n_threads);
diff --git a/ggml-alloc.c b/ggml-alloc.c
index 89b85d34870d7..beb557997cb84 100644
--- a/ggml-alloc.c
+++ b/ggml-alloc.c
@@ -319,6 +319,13 @@ struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
     return alloc->buffer;
 }
 
+void ggml_tallocr_set_buffer(ggml_tallocr_t talloc, struct ggml_backend_buffer * buffer) {
+    talloc->buffer = buffer;
+    talloc->base = ggml_backend_buffer_get_base(buffer);
+    talloc->alignment = ggml_backend_buffer_get_alignment(buffer);
+    ggml_tallocr_reset(talloc);
+}
+
 void ggml_tallocr_free(ggml_tallocr_t alloc) {
     if (alloc == NULL) {
         return;
diff --git a/ggml-alloc.h b/ggml-alloc.h
index 4e59975213406..08c3d84d36d8c 100644
--- a/ggml-alloc.h
+++ b/ggml-alloc.h
@@ -59,6 +59,7 @@ GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft);
 GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
 
 GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
+GGML_API void ggml_tallocr_set_buffer(ggml_tallocr_t talloc, struct ggml_backend_buffer * buffer);
 GGML_API void   ggml_tallocr_free       (ggml_tallocr_t talloc);
 GGML_API bool   ggml_tallocr_is_measure (ggml_tallocr_t talloc);
 
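The ggml-alloc changes above add a way to retarget an existing tensor allocator to a different backend buffer (and reset it). The following is a minimal sketch of how a caller might cycle one tallocr over a small pool of buffers; it uses only existing ggml-alloc/ggml-backend helpers (ggml_backend_cpu_buffer_type, ggml_backend_buft_alloc_buffer, ggml_tallocr_new_from_buffer), and the pool size and buffer size are made-up example values, not anything defined by this patch.

// sketch: rotate one tallocr over a small pool of CPU buffers, so that each
// iteration allocates its input tensors from a buffer that is not still in
// use by a previous, possibly in-flight, computation
#include "ggml-alloc.h"
#include "ggml-backend.h"

static void rotate_input_buffers_example(void) {
    const int    n_buffers = 4;            // example pool size
    const size_t buf_size  = 16*1024*1024; // example per-buffer size in bytes

    ggml_backend_buffer_t bufs[4];
    for (int i = 0; i < n_buffers; ++i) {
        bufs[i] = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), buf_size);
    }

    // one allocator, retargeted before each step
    ggml_tallocr_t talloc = ggml_tallocr_new_from_buffer(bufs[0]);

    for (int step = 0; step < 8; ++step) {
        // ggml_tallocr_set_buffer() also resets the allocator
        ggml_tallocr_set_buffer(talloc, bufs[step % n_buffers]);
        // ... allocate this step's input tensors with ggml_tallocr_alloc(talloc, tensor) ...
    }

    ggml_tallocr_free(talloc);
    for (int i = 0; i < n_buffers; ++i) {
        ggml_backend_buffer_free(bufs[i]);
    }
}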
diff --git a/llama.cpp b/llama.cpp
index 2190ea7aa92c2..ec09c2ded8da9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1663,7 +1663,9 @@ struct llama_context {
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
     // allocator for the input tensors
-    ggml_tallocr * alloc = nullptr;
+    ggml_tallocr * alloc_cpu = nullptr;
+
+    std::vector<ggml_backend_buffer_t> buf_cpu_ub;
 
     // temporary buffer for copying data to/from the backend
     std::vector<std::vector<uint8_t>> buf_copy;
@@ -3208,7 +3210,8 @@ static bool llm_load_tensors(
     const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
 
     // there is very little benefit to offloading the input layer, so always keep it on the CPU
-    model.buft_input = llama_default_buffer_type_cpu(true);
+    //model.buft_input = llama_default_buffer_type_cpu(true);
+    model.buft_input = llama_default_buffer_type_offload(main_gpu);
 
     model.buft_layer.resize(n_layer);
 
@@ -5955,7 +5958,7 @@ static struct ggml_cgraph * llama_build_graph(
     const auto & model = lctx.model;
 
     // check if we should build the worst-case graph (for memory measurement)
-    const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);
+    const bool worst_case = ggml_tallocr_is_measure(lctx.alloc_cpu);
 
     // keep track of the input that has already been allocated
     bool alloc_inp_tokens   = false;
@@ -5978,9 +5981,9 @@ static struct ggml_cgraph * llama_build_graph(
         //
         if (!alloc_inp_tokens && strcmp(name, "inp_tokens") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+            ggml_tallocr_alloc(lctx.alloc_cpu, cur);
 
-            if (!ggml_tallocr_is_measure(lctx.alloc) && batch.token) {
+            if (!ggml_tallocr_is_measure(lctx.alloc_cpu) && batch.token) {
                 const int64_t n_tokens = cur->ne[0];
 
                 ggml_backend_tensor_set(cur, batch.token, 0, n_tokens*ggml_element_size(cur));
@@ -5990,9 +5993,9 @@ static struct ggml_cgraph * llama_build_graph(
         }
 
         if (!alloc_inp_embd && strcmp(name, "inp_embd") == 0 && batch.embd) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+            ggml_tallocr_alloc(lctx.alloc_cpu, cur);
 
-            if (!ggml_tallocr_is_measure(lctx.alloc) && batch.embd) {
+            if (!ggml_tallocr_is_measure(lctx.alloc_cpu) && batch.embd) {
                 const int64_t n_embd   = cur->ne[0];
                 const int64_t n_tokens = cur->ne[1];
@@ -6003,9 +6006,9 @@ static struct ggml_cgraph * llama_build_graph(
         }
 
         if (!alloc_inp_pos && strcmp(name, "inp_pos") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+            ggml_tallocr_alloc(lctx.alloc_cpu, cur);
 
-            if (!ggml_tallocr_is_measure(lctx.alloc) && batch.pos) {
+            if (!ggml_tallocr_is_measure(lctx.alloc_cpu) && batch.pos) {
                 const int64_t n_tokens = cur->ne[0];
 
                 static_assert(std::is_same<llama_pos, int32_t>::value, "llama_pos must be int32_t");
@@ -6016,9 +6019,9 @@ static struct ggml_cgraph * llama_build_graph(
         }
 
         if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+            ggml_tallocr_alloc(lctx.alloc_cpu, cur);
 
-            if (!ggml_tallocr_is_measure(lctx.alloc)) {
+            if (!ggml_tallocr_is_measure(lctx.alloc_cpu)) {
                 const int64_t n_kv     = cur->ne[0];
                 const int64_t n_tokens = cur->ne[1];
@@ -6056,9 +6059,9 @@ static struct ggml_cgraph * llama_build_graph(
         }
 
         if (!alloc_inp_K_shift && strcmp(name, "K_shift") == 0) {
-            ggml_tallocr_alloc(lctx.alloc, cur);
+            ggml_tallocr_alloc(lctx.alloc_cpu, cur);
 
-            if (!ggml_tallocr_is_measure(lctx.alloc)) {
+            if (!ggml_tallocr_is_measure(lctx.alloc_cpu)) {
                 const int64_t n_ctx = cur->ne[0];
 
                 int32_t * data;
@@ -6161,10 +6164,11 @@
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
-    const uint32_t n_tokens = batch.n_tokens;
+           llama_batch   all_batch) {
+
+    const uint32_t n_tokens_all = all_batch.n_tokens;
 
-    if (n_tokens == 0) {
+    if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
     }
@@ -6173,12 +6177,11 @@
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    const auto n_batch = cparams.n_batch;
+    //const auto n_batch = cparams.n_batch;
 
-    GGML_ASSERT(n_tokens <= n_batch);
+    GGML_ASSERT((!all_batch.token && all_batch.embd) || (all_batch.token && !all_batch.embd)); // NOLINT
 
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    GGML_ASSERT(n_tokens_all <= cparams.n_ctx);
 
     const int64_t t_start_us = ggml_time_us();
 
@@ -6188,205 +6191,255 @@ static int llama_decode_internal(
     //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
 #endif
 
-    GGML_ASSERT(n_threads > 0);
-
     auto & kv_self = lctx.kv_self;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
-    // helpers for smoother batch API transition
-    // after deprecating the llama_eval calls, these will be removed
-    std::vector<llama_pos> pos;
-
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
 
-    if (batch.pos == nullptr) {
-        pos.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
-        }
+    auto & logits_out = lctx.logits;
 
-        batch.pos = pos.data();
+    if (all_batch.logits) {
+        logits_out.resize(n_vocab * n_tokens_all);
+    } else if (lctx.logits_all) {
+        logits_out.resize(n_vocab * n_tokens_all);
+    } else {
+        logits_out.resize(n_vocab);
     }
 
-    if (batch.seq_id == nullptr) {
-        n_seq_id.resize(n_tokens);
-        seq_id.resize(n_tokens);
-        seq_id_arr.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            n_seq_id[i] = 1;
-            seq_id[i].resize(1);
-            seq_id[i][0] = batch.all_seq_id;
-            seq_id_arr[i] = seq_id[i].data();
-        }
+#ifndef NDEBUG
+    auto & logits_valid = lctx.logits_valid;
+    logits_valid.clear();
+    logits_valid.resize(n_tokens_all);
 
-        batch.n_seq_id = n_seq_id.data();
-        batch.seq_id = seq_id_arr.data();
-    }
+    logits_out.clear();
+#endif
 
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (kv_self.head > kv_self.used + 2*n_tokens) {
-        kv_self.head = 0;
-    }
+    const uint32_t n_microbatch = 256;
+
+    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_microbatch) {
+        const uint32_t n_tokens = std::min(n_microbatch, n_tokens_all - cur_token);
+
+        llama_batch batch = {
+            /* .n_tokens   = */ (int32_t) n_tokens,
+            /* .token      = */ all_batch.token    ? all_batch.token    + cur_token        : nullptr,
+            /* .embd       = */ all_batch.embd     ? all_batch.embd     + cur_token*n_embd : nullptr,
+            /* .pos        = */ all_batch.pos      ? all_batch.pos      + cur_token        : nullptr,
+            /* .n_seq_id   = */ all_batch.n_seq_id ? all_batch.n_seq_id + cur_token        : nullptr,
+            /* .seq_id     = */ all_batch.seq_id   ? all_batch.seq_id   + cur_token        : nullptr,
+            /* .logits     = */ all_batch.logits   ? all_batch.logits   + cur_token        : nullptr,
+            /* .all_pos_0  = */ all_batch.all_pos_0 + (llama_pos) cur_token*all_batch.all_pos_1,
+            /* .all_pos_1  = */ all_batch.all_pos_1,
+            /* .all_seq_id = */ all_batch.all_seq_id,
+        };
 
-    if (!llama_kv_cache_find_slot(kv_self, batch)) {
-        return 1;
-    }
+        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        GGML_ASSERT(n_threads > 0);
 
-    // a heuristic, to avoid attending the full cache if it is not yet utilized
-    // after enough generations, the benefit from this heuristic disappears
-    // if we start defragmenting the cache, the benefit from this will be more important
-    kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
-    //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        // helpers for smoother batch API transition
+        // after deprecating the llama_eval calls, these will be removed
+        std::vector<llama_pos> pos;
 
-    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+        std::vector<int32_t>                   n_seq_id;
+        std::vector<llama_seq_id *>            seq_id_arr;
+        std::vector<std::vector<llama_seq_id>> seq_id;
 
-    ggml_backend_sched_reset(lctx.sched);
+        if (batch.pos == nullptr) {
+            pos.resize(n_tokens);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
+            }
 
-    ggml_cgraph * gf = llama_build_graph(lctx, batch);
+            batch.pos = pos.data();
+        }
 
-    // the output is always the last tensor in the graph
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+        if (batch.seq_id == nullptr) {
+            n_seq_id.resize(n_tokens);
+            seq_id.resize(n_tokens);
+            seq_id_arr.resize(n_tokens);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                n_seq_id[i] = 1;
+                seq_id[i].resize(1);
+                seq_id[i][0] = batch.all_seq_id;
+                seq_id_arr[i] = seq_id[i].data();
+            }
 
-    // the embeddings could be the second to last tensor, or the third to last tensor
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-    if (strcmp(embeddings->name, "result_norm") != 0) {
-        embeddings = gf->nodes[gf->n_nodes - 3];
-        GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-    }
+            batch.n_seq_id = n_seq_id.data();
+            batch.seq_id = seq_id_arr.data();
+        }
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*n_tokens) {
+            kv_self.head = 0;
+        }
 
-    // for big prompts, if BLAS is enabled, it is better to use only one thread
-    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-    // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
-    //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
-    //       with the BLAS calls. need a better solution
-    if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
-        n_threads = std::min(4, n_threads);
-    }
+        if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find a slot in the cache", __func__);
+            return 1;
+        }
 
-    const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
-    if (ggml_cpu_has_cublas() && fully_offloaded) {
-        n_threads = 1;
-    }
+        // a heuristic, to avoid attending the full cache if it is not yet utilized
+        // after enough generations, the benefit from this heuristic disappears
+        // if we start defragmenting the cache, the benefit from this will be more important
+        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        //kv_self.n = llama_kv_cache_cell_max(kv_self);
 
-#ifdef GGML_USE_MPI
-    const int64_t n_layer = hparams.n_layer;
-    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
-#endif
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(lctx.backend_metal)) {
-        ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
-    }
-#endif
+        int i_ub = cur_token / n_microbatch;
+        size_t n_buf = lctx.buf_cpu_ub.size();
+        if (i_ub != 0 && i_ub % n_buf == 0) {
+            // sync all backends
+            printf("not enough buffers, syncing now\n");
+            // TODO: ggml_backend_sched_synchronize()
+            for (auto * backend : lctx.backends) {
+                ggml_backend_synchronize(backend);
+            }
+        }
 
-    if (lctx.backend_cpu != nullptr) {
-        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
-    }
-    ggml_backend_sched_graph_compute(lctx.sched, gf);
+        ggml_tallocr_set_buffer(lctx.alloc_cpu, lctx.buf_cpu_ub[i_ub % n_buf]);
 
-    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+        ggml_backend_sched_reset(lctx.sched);
 
-#ifdef GGML_USE_MPI
-    ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
-#endif
+        ggml_cgraph * gf = llama_build_graph(lctx, batch);
 
-    // update the kv ring buffer
-    {
-        if (kv_self.has_shift) {
-            kv_self.has_shift = false;
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = 0;
-            }
+        // the output is always the last tensor in the graph
+        struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+        GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+
+        // the embeddings could be the second to last tensor, or the third to last tensor
+        struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+        if (strcmp(embeddings->name, "result_norm") != 0) {
+            embeddings = gf->nodes[gf->n_nodes - 3];
+            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
        }
 
-        kv_self.head += n_tokens;
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        // Ensure kv cache head points to a valid index.
-        if (kv_self.head >= kv_self.size) {
-            kv_self.head = 0;
+        // for big prompts, if BLAS is enabled, it is better to use only one thread
+        // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
+        // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
+        //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
+        //       with the BLAS calls. need a better solution
+        if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
+            n_threads = std::min(4, n_threads);
        }
-    }
 
-#ifdef GGML_PERF
-    // print timing information per ggml operation (for debugging purposes)
-    // requires GGML_PERF to be defined
-    ggml_graph_print(gf);
-#endif
+    #ifdef GGML_USE_MPI
+        const int64_t n_layer = hparams.n_layer;
+        ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
+    #endif
 
-    // plot the computation graph in dot format (for debugging purposes)
-    //if (n_past%100 == 0) {
-    //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
-    //}
+    #ifdef GGML_USE_METAL
+        if (lctx.backend_metal != nullptr) {
+            ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
+        }
+    #endif
 
-    // extract logits
-    // TODO: do not compute and extract logits if only embeddings are needed
-    //       need to update the graphs to skip "result_output"
-    {
-        auto & logits_out = lctx.logits;
+        if (lctx.backend_cpu != nullptr) {
+            ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
+        }
 
-#ifndef NDEBUG
-        auto & logits_valid = lctx.logits_valid;
-        logits_valid.clear();
-        logits_valid.resize(n_tokens);
+        ggml_backend_sched_graph_compute(lctx.sched, gf);
 
-        logits_out.clear();
-#endif
+        // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 
-        ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
-        GGML_ASSERT(res_backend != nullptr);
-        if (batch.logits) {
-            logits_out.resize(n_vocab * n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
+    #ifdef GGML_USE_MPI
+        ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
+    #endif
+
+        // update the kv ring buffer
+        {
+            if (kv_self.has_shift) {
+                kv_self.has_shift = false;
+                for (uint32_t i = 0; i < kv_self.size; ++i) {
+                    kv_self.cells[i].delta = 0;
                }
-                ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
-#ifndef NDEBUG
-                logits_valid[i] = true;
-#endif
            }
-        } else if (lctx.logits_all) {
-            logits_out.resize(n_vocab * n_tokens);
-            ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
-#ifndef NDEBUG
-            std::fill(logits_valid.begin(), logits_valid.end(), true);
-#endif
-        } else {
-            logits_out.resize(n_vocab);
-            ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
-#ifndef NDEBUG
-            logits_valid[0] = true;
-#endif
+
+            kv_self.head += n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
+
+    #ifdef GGML_PERF
+        // print timing information per ggml operation (for debugging purposes)
+        // requires GGML_PERF to be defined
+        ggml_graph_print(gf);
+    #endif
+
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
+
+        // extract logits
+        // TODO: do not compute and extract logits if only embeddings are needed
+        //       need to update the graphs to skip "result_output"
+        {
+            ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
+            GGML_ASSERT(res_backend != nullptr);
+            if (batch.logits) {
+                //logits_out.resize(n_vocab * n_tokens);
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    if (batch.logits[i] == 0) {
+                        continue;
+                    }
+                    ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + n_vocab*(cur_token + i), n_vocab*i*sizeof(float), n_vocab*sizeof(float));
+    #ifndef NDEBUG
+                    logits_valid[i] = true;
+    #endif
+                }
+            } else if (lctx.logits_all) {
+                //logits_out.resize(n_vocab * n_tokens);
+                //ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
+                ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + cur_token*n_vocab, 0, n_vocab*n_tokens*sizeof(float));
    #ifndef NDEBUG
+                std::fill(logits_valid.begin(), logits_valid.end(), true);
+    #endif
+            } else {
+                if (cur_token + n_tokens >= n_tokens_all) {
+                    //logits_out.resize(n_vocab);
+                    ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
+                }
+                //ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
    #ifndef NDEBUG
+                logits_valid[0] = true;
    #endif
+            }
+            //ggml_backend_synchronize(res_backend);
        }
 
-    // extract embeddings
-    if (!lctx.embedding.empty()) {
-        auto & embedding_out = lctx.embedding;
+        // FIXME
+        // extract embeddings
+        if (!lctx.embedding.empty()) {
+            GGML_ASSERT(!"not implemented");
+            auto & embedding_out = lctx.embedding;
 
-        embedding_out.resize(n_embd);
-        ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
-        ggml_backend_synchronize(embeddings_backend);
+            embedding_out.resize(n_embd);
+            ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
+            ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
+            //ggml_backend_synchronize(embeddings_backend);
+        }
+    }
+
+    // TODO: ggml_backend_sched_synchronize()
+    for (auto * backend : lctx.backends) {
+        ggml_backend_synchronize(backend);
    }
 
     // measure the performance only for the single-token evals
-    if (n_tokens == 1) {
+    if (n_tokens_all == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
         lctx.n_eval++;
     }
-    else if (n_tokens > 1) {
+    else if (n_tokens_all > 1) {
         lctx.t_p_eval_us += ggml_time_us() - t_start_us;
-        lctx.n_p_eval += n_tokens;
+        lctx.n_p_eval += n_tokens_all;
     }
 
     // get a more accurate load time, upon first eval
@@ -9402,7 +9455,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
 
         ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
-        ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+        ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
 
         // build worst-case graph
         int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
@@ -9415,7 +9468,17 @@ struct llama_context * llama_new_context_with_model(
         // note: the number of splits during measure is higher than during inference due to the kv shift
         int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
         LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits);
-        ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+        ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+
+        // duplicate cpu buffers for microbatching
+        ggml_backend_buffer_t buf_cpu = ggml_tallocr_get_buffer(ctx->alloc_cpu);
+        size_t buf_size = ggml_backend_buffer_get_size(buf_cpu);
+        ctx->buf_cpu_ub.push_back(buf_cpu);
+        int n_ub = 64;
+        for (int i = 1; i < n_ub; ++i) {
+            ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_size);
+            ctx->buf_cpu_ub.push_back(buf);
+        }
 
         for (ggml_backend_t backend : ctx->backends) {
             ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend);
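For reference, the micro-batch bookkeeping used in llama_decode_internal above can be reduced to the following standalone sketch: it splits a batch of n_tokens_all tokens into fixed-size windows, picks a CPU input buffer round-robin, and flags the points where the patch synchronizes all backends before a buffer is reused. The concrete numbers are example values only (the patch itself hard-codes n_microbatch = 256 and allocates 64 CPU buffers).

// sketch only; mirrors the loop structure of the patch, not a public API
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_tokens_all = 1000; // example: tokens in the full batch
    const uint32_t n_microbatch = 256;  // tokens per micro-batch
    const uint32_t n_buf        = 4;    // example: CPU input buffers in the pool

    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_microbatch) {
        const uint32_t n_tokens = std::min(n_microbatch, n_tokens_all - cur_token);
        const uint32_t i_ub     = cur_token / n_microbatch;

        // when the pool wraps around, all backends are synchronized before the
        // buffer is reused, since an earlier graph may still be reading from it
        const bool need_sync = i_ub != 0 && i_ub % n_buf == 0;

        printf("micro-batch %u: tokens [%u, %u), cpu buffer %u%s\n",
               i_ub, cur_token, cur_token + n_tokens, i_ub % n_buf,
               need_sync ? " (sync before reuse)" : "");
    }
    return 0;
}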