From 09562678d9d803f8d1dc1531db3933d0c269db94 Mon Sep 17 00:00:00 2001
From: ziadb
Date: Sun, 26 Nov 2023 22:28:59 -0500
Subject: [PATCH 01/10] * add multiprompt support

---
 examples/server/server.cpp | 144 ++++++++++++++++++++++++++++++++++---
 1 file changed, 133 insertions(+), 11 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 50f124b13e849..ec9f1ad577e29 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -24,6 +24,7 @@
 #include 
 #include 
 #include 
+#include <unordered_set>
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
@@ -155,15 +156,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::unordered_set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -406,6 +415,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -512,7 +524,7 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token = true;
 
-    int32_t id_gen;
+    std::atomic<int32_t> id_gen;
     int32_t n_ctx;  // total context for all clients / slots
 
     // system prompt
@@ -529,8 +541,10 @@ struct llama_server_context
 
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
+    std::vector<task_multi> queue_multitasks;
     std::mutex mutex_tasks;
     std::mutex mutex_results;
+    std::mutex mutex_multitasks;
 
     ~llama_server_context()
     {
@@ -1112,17 +1126,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_multitasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_multitasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
@@ -1167,6 +1204,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
 
@@ -1206,6 +1244,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1251,6 +1290,16 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // if this task has a multitask associated with it, then we update the multitask
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+        else // otherwise update the results queue
+        {
+
+        }
+
         queue_results.push_back(res);
     }
 
@@ -1259,6 +1308,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1285,9 +1335,8 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
@@ -1295,6 +1344,17 @@ struct llama_server_context
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            auto id = split_multiprompt_task_into_subtasks(task);
+            return id;
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
+        std::lock_guard<std::mutex> lock(mutex_tasks);
         queue_tasks.push_back(task);
         return task.id;
     }
@@ -1313,8 +1373,17 @@ struct llama_server_context
 
         for (int i = 0; i < (int) queue_results.size(); i++)
         {
+            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+            if (queue_results[i].multitask_id == task_id)
+            {
+                update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                queue_results.erase(queue_results.begin() + i);
+                continue;
+            }
+
             if (queue_results[i].id == task_id)
             {
+                assert(queue_results[i].multitask_id == -1);
                 task_result res = queue_results[i];
                 queue_results.erase(queue_results.begin() + i);
                 return res;
@@ -1404,6 +1473,25 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task_into_subtasks(task_server& task)
+    {
+        auto prompt_count = task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            subtask_ids[i] = request_completion(subtask_data, task.infill_mode, task.embedding_mode, multitask_id);
+        }
+
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1419,7 +1507,7 @@ struct llama_server_context
                 {
                     LOG_TEE("slot unavailable\n");
                     // send error result
-                    send_error(task.id, "slot unavailable");
+                    send_error(task, "slot unavailable");
                     return;
                 }
 
@@ -1433,11 +1521,12 @@ struct llama_server_context
                 slot->infill = task.infill_mode;
                 slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;
+                slot->multitask_id = task.multitask_id;
 
                 if (!launch_slot_with_data(slot, task.data))
                 {
                     // send error result
-                    send_error(task.id, "internal_error");
+                    send_error(task, "internal_error");
                     break;
                 }
             } break;
@@ -1453,6 +1542,39 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        std::lock_guard<std::mutex> lock_multitasks(mutex_multitasks);
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result{};
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
result_jsons{}; + for (auto& subres : queue_iterator->results) + { + result_jsons.push_back(subres.result_json); + aggregate_result.error = aggregate_result.error && subres.error; + } + aggregate_result.result_json = json{ "results", result_jsons }; + + std::lock_guard lock(mutex_results); + queue_results.push_back(aggregate_result); + + queue_iterator = queue_multitasks.erase(queue_iterator); + } + else + { + ++queue_iterator; + } + } } bool update_slots() { @@ -2596,7 +2718,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - const int task_id = llama.request_completion(data, false, false); + const int task_id = llama.request_completion(data, false, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); @@ -2685,7 +2807,7 @@ int main(int argc, char **argv) { json data = oaicompat_completion_params_parse(json::parse(req.body)); - const int task_id = llama.request_completion(data, false, false); + const int task_id = llama.request_completion(data, false, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; @@ -2754,7 +2876,7 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - const int task_id = llama.request_completion(data, true, false); + const int task_id = llama.request_completion(data, true, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); @@ -2858,7 +2980,7 @@ int main(int argc, char **argv) { prompt = ""; } - const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true); + const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1); task_result result = llama.next_result(task_id); return res.set_content(result.result_json.dump(), "application/json"); }); From ff67c764c48a8f23d4bd47f43231f9de50ed89fe Mon Sep 17 00:00:00 2001 From: ziadb Date: Sun, 26 Nov 2023 22:36:19 -0500 Subject: [PATCH 02/10] * cleanup --- examples/server/server.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ec9f1ad577e29..a97c7be90e420 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1290,15 +1290,11 @@ struct llama_server_context res.result_json["model"] = slot.oaicompat_model; } - // if this task has a multitask associated with it, then we update the multitask + // parent multitask, if any, needs to be updated if (slot.multitask_id != -1) { update_multi_task(slot.multitask_id, slot.task_id, res); } - else // otherwise update the results queue - { - - } queue_results.push_back(res); } @@ -1349,7 +1345,7 @@ struct llama_server_context // when a completion task's prompt array is not a singleton, we split it into multiple requests if (task.data.at("prompt").size() > 1) { - auto id = split_multiprompt_task_into_subtasks(task); + auto id = request_multiprompt_task(task); return id; } @@ -1473,21 +1469,23 @@ struct llama_server_context queue_tasks.push_back(task); } - int split_multiprompt_task_into_subtasks(task_server& task) + int split_multiprompt_task(task_server& multiprompt_task) { - auto prompt_count = task.data.at("prompt").size(); + auto prompt_count = multiprompt_task.data.at("prompt").size(); 
         assert(prompt_count > 1);
 
         int multitask_id = id_gen++;
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
         {
-            json subtask_data = task.data;
+            json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
-            subtask_ids[i] = request_completion(subtask_data, task.infill_mode, task.embedding_mode, multitask_id);
+            // subtasks inherit everything else (infill mode, embedding mode, etc.)
+            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
         }
 
+        // queue up the multitask so we can track its subtask progression
         add_multi_task(multitask_id, subtask_ids);
         return multitask_id;
     }

From 5906fb442b4b10822ca63f833635cfc40144ddca Mon Sep 17 00:00:00 2001
From: ziadb
Date: Sun, 26 Nov 2023 22:38:29 -0500
Subject: [PATCH 03/10] * more cleanup

---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a97c7be90e420..2b1ba4eca234a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1345,7 +1345,7 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.at("prompt").size() > 1)
         {
-            auto id = request_multiprompt_task(task);
+            auto id = split_multiprompt_task(task);
             return id;
         }
 

From e2ee37761ef89e84a30d2950844d8fdfbe1e6b15 Mon Sep 17 00:00:00 2001
From: ziadb
Date: Mon, 27 Nov 2023 18:12:58 -0500
Subject: [PATCH 04/10] * remove atomicity of id_gen, and change lock_guard to
 unique_lock on completion requests

---
 examples/server/server.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2b1ba4eca234a..443b3c00a1920 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -524,7 +524,7 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token = true;
 
-    std::atomic<int32_t> id_gen;
+    int32_t id_gen;
    int32_t n_ctx;  // total context for all clients / slots
 
     // system prompt
@@ -542,7 +542,7 @@ struct llama_server_context
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
     std::vector<task_multi> queue_multitasks;
-    std::mutex mutex_tasks;
+    std::mutex mutex_tasks; // also guards id_gen
     std::mutex mutex_results;
     std::mutex mutex_multitasks;
 
@@ -1333,6 +1333,7 @@ struct llama_server_context
 
     int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
@@ -1346,12 +1346,12 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.at("prompt").size() > 1)
         {
+            lock.unlock(); // entering new func scope
             auto id = split_multiprompt_task(task);
             return id;
         }
 
         // otherwise, it's a single-prompt task, we actually queue it
-        std::lock_guard<std::mutex> lock(mutex_tasks);
         queue_tasks.push_back(task);
         return task.id;
     }

From 38ce5d02e0fe7da2ab3344f34952699406778b49 Mon Sep 17 00:00:00 2001
From: ziadb
Date: Wed, 29 Nov 2023 16:56:43 -0500
Subject: [PATCH 05/10] * remove all references to mutex_multitasks

---
 examples/server/server.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 443b3c00a1920..c79b54940d970 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -542,9 +542,8 @@ struct llama_server_context
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
     std::vector<task_multi> queue_multitasks;
-    std::mutex mutex_tasks; // also guards id_gen
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;
-    std::mutex mutex_multitasks;
 
     ~llama_server_context()
     {
@@ -1140,7 +1139,7 @@ struct llama_server_context
 
     void add_multi_task(int id, std::vector<int>& sub_ids)
     {
-        std::lock_guard<std::mutex> lock(mutex_multitasks);
+        std::lock_guard<std::mutex> lock(mutex_tasks);
         task_multi multi;
         multi.id = id;
         std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
@@ -1149,7 +1148,7 @@ struct llama_server_context
 
     void update_multi_task(int multitask_id, int subtask_id, task_result& result)
     {
-        std::lock_guard<std::mutex> lock(mutex_multitasks);
+        std::lock_guard<std::mutex> lock(mutex_tasks);
         for (auto& multitask : queue_multitasks)
         {
             if (multitask.id == multitask_id)
@@ -1543,7 +1542,6 @@ struct llama_server_context
         }
 
         // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
-        std::lock_guard<std::mutex> lock_multitasks(mutex_multitasks);
         auto queue_iterator = queue_multitasks.begin();
         while (queue_iterator != queue_multitasks.end())
         {

From 09da4b14f90add38e37a6fc758482e4aefec6230 Mon Sep 17 00:00:00 2001
From: Ziad Ben Hadj-Alouane
Date: Wed, 29 Nov 2023 22:03:50 -0500
Subject: [PATCH 06/10] Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel
---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c79b54940d970..e8726027dbc0b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1548,7 +1548,7 @@ struct llama_server_context
             if (queue_iterator->subtasks_remaining.empty())
             {
                 // all subtasks done == multitask is done
-                task_result aggregate_result{};
+                task_result aggregate_result;
                 aggregate_result.id = queue_iterator->id;
                 aggregate_result.stop = true;
                 aggregate_result.error = false;

From 0e1a5aa5faabaece1f6abfbf5a0f2f8a1fd254d7 Mon Sep 17 00:00:00 2001
From: Ziad Ben Hadj-Alouane
Date: Wed, 29 Nov 2023 22:03:55 -0500
Subject: [PATCH 07/10] Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel
---
 examples/server/server.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e8726027dbc0b..ebfc60f5f1d8e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -169,8 +169,8 @@ struct task_result {
 
 struct task_multi {
     int id;
-    std::unordered_set<int> subtasks_remaining{};
-    std::vector<task_result> results{};
+    std::unordered_set<int> subtasks_remaining;
+    std::vector<task_result> results;
 };
 
 // TODO: can become bool if we can't find use of more states

From 14785e11485c128109ff190128ee917e906279db Mon Sep 17 00:00:00 2001
From: Ziad Ben Hadj-Alouane
Date: Wed, 29 Nov 2023 22:04:00 -0500
Subject: [PATCH 08/10] Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel
---
 examples/server/server.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ebfc60f5f1d8e..d024cfbcbdfd0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1346,8 +1346,7 @@ struct llama_server_context
         if (task.data.at("prompt").size() > 1)
         {
             lock.unlock(); // entering new func scope
-            auto id = split_multiprompt_task(task);
-            return id;
+            return split_multiprompt_task(task);
         }
 
         // otherwise, it's a single-prompt task, we actually queue it

From 3b371e10c4f75f0e2b7e3b1955f4ec7d369d8999 Mon Sep 17 00:00:00 2001
From: Ziad Ben Hadj-Alouane
Date: Wed, 29 Nov 2023 22:04:08 -0500
Subject: [PATCH 09/10] Update examples/server/server.cpp

Co-authored-by: Jared Van Bortel
---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d024cfbcbdfd0..e124204916a4e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1553,7 +1553,7 @@ struct llama_server_context
                 aggregate_result.error = false;
 
                 // collect json results into one json result
-                std::vector<json> result_jsons{};
+                std::vector<json> result_jsons;
                 for (auto& subres : queue_iterator->results)
                 {
                     result_jsons.push_back(subres.result_json);

From 0f175a6084b931c3755d7e81b3162be58f189e7f Mon Sep 17 00:00:00 2001
From: ziadb
Date: Thu, 30 Nov 2023 17:09:59 -0500
Subject: [PATCH 10/10] * change to set

---
 examples/server/server.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c79b54940d970..52cb3f7595d0d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -24,7 +24,6 @@
 #include 
 #include 
 #include 
-#include <unordered_set>
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
@@ -169,7 +168,7 @@ struct task_result {
 
 struct task_multi {
     int id;
-    std::unordered_set<int> subtasks_remaining{};
+    std::set<int> subtasks_remaining{};
     std::vector<task_result> results{};
 };
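To exercise the feature end to end, the snippet below sketches a client that sends a multi-prompt request to /completion. It is illustrative only and not part of the patch series: it assumes a server built from this branch listening on localhost:8080, it uses cpp-httplib and nlohmann::json (the same libraries examples/server already vendors), and the exact layout of the aggregated "results" payload should be verified against the running server rather than taken from this sketch.

// client_example.cpp -- illustrative sketch, not part of the patches above.
// Assumes: llama.cpp server from this branch running on localhost:8080.
#include <iostream>

#include "httplib.h"  // cpp-httplib, vendored by examples/server
#include "json.hpp"   // nlohmann::json, vendored by examples/server

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080);

    // A "prompt" array with more than one element triggers the multiprompt path:
    // request_completion() splits it into one subtask per prompt, and the results
    // are aggregated once every subtask has finished.
    json body = {
        {"prompt", json::array({"Tell me a joke.", "Write a haiku about autumn."})},
        {"n_predict", 64},
        {"stream", false}
    };

    auto res = cli.Post("/completion", body.dump(), "application/json");
    if (!res || res->status != 200) {
        std::cerr << "request failed" << std::endl;
        return 1;
    }

    // Print the aggregated reply verbatim; it carries one result_json per prompt
    // (see the queue_multitasks handling in process_tasks()).
    std::cout << json::parse(res->body).dump(2) << std::endl;
    return 0;
}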