diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index 29bd84f..631d3ed 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -1613,66 +1613,24 @@ bool LlamaEngine::HandleLlamaCppChatCompletion(
   if (IsLlamaServerModel(model)) {
     llama_server_map_.at(model).q->runTaskInQueue(
         [this, cb = std::move(callback), json_body, model] {
-          auto include_usage = [&json_body]() -> bool {
-            auto stream = (*json_body).get("stream", false).asBool();
-            if (stream) {
-              if (json_body->isMember("stream_options") &&
-                  !(*json_body)["stream_options"].isNull()) {
-                return (*json_body)["stream_options"]
-                    .get("include_usage", false)
-                    .asBool();
-              }
+          auto oaicompat = [&json_body]() -> bool {
+            if (json_body->isMember("logprobs") &&
+                (*json_body)["logprobs"].asBool()) {
               return false;
             }
-          }();
-
-          //
-          auto& s = llama_server_map_.at(model);
-          httplib::Client cli(s.host + ":" + std::to_string(s.port));
-          auto data = ConvertJsonCppToNlohmann(*json_body);
-          auto data_str = data.dump();
-          LOG_DEBUG << "data_str: " << data_str;
-          cli.set_read_timeout(std::chrono::seconds(60));
-          // std::cout << "> ";
-          httplib::Request req;
-          req.headers = httplib::Headers();
-          req.set_header("Content-Type", "application/json");
-          req.method = "POST";
-          req.path = "/v1/chat/completions";
-          req.body = data_str;
-          req.content_receiver = [cb, include_usage](
-                                     const char* data, size_t data_length,
-                                     uint64_t offset, uint64_t total_length) {
-            std::string s(data, data_length);
-            Json::Value resp_data;
-            resp_data["data"] = s;
-            Json::Value status;
-
-            if (s.find("[DONE]") != std::string::npos) {
-              LOG_DEBUG << "[DONE]";
-              status["is_done"] = true;
-              status["has_error"] = false;
-              status["is_stream"] = true;
-              status["status_code"] = k200OK;
-              cb(std::move(status), std::move(resp_data));
+            if (json_body->isMember("prompt") &&
+                !(*json_body)["prompt"].asString().empty()) {
               return false;
             }
-
-            // For openai api compatibility
-            if (!include_usage &&
-                s.find("completion_tokens") != std::string::npos) {
-              return true;
-            }
-
-            status["is_done"] = false;
-            status["has_error"] = false;
-            status["is_stream"] = true;
-            status["status_code"] = k200OK;
-            cb(std::move(status), std::move(resp_data));
-            LOG_DEBUG << s;
             return true;
-          };
-          cli.send(req);
+          }();
+          if (oaicompat) {
+            HandleOpenAiChatCompletion(json_body,
+                                       const_cast<http_callback&&>(cb), model);
+          } else {
+            HandleNonOpenAiChatCompletion(
+                json_body, const_cast<http_callback&&>(cb), model);
+          }
         });
     LOG_DEBUG << "Done HandleChatCompletion";
     return true;
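// Routing sketch (illustrative, not part of the patch): a plain chat request stays
// on the OpenAI-compatible path, while asking for logprobs or supplying a raw
// "prompt" falls back to the non-OpenAI /v1/completions path.
//
//   {"messages": [...], "stream": true}        -> HandleOpenAiChatCompletion
//   {"messages": [...], "logprobs": true}      -> HandleNonOpenAiChatCompletion
//   {"prompt": "Hello", "n_probs": 5}          -> HandleNonOpenAiChatCompletion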
@@ -1680,6 +1638,322 @@ bool LlamaEngine::HandleLlamaCppChatCompletion(
   return false;
 }
 
+void LlamaEngine::HandleOpenAiChatCompletion(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& cb,
+    const std::string& model) {
+  auto is_stream = (*json_body).get("stream", false).asBool();
+  auto include_usage = [&json_body, is_stream]() -> bool {
+    if (is_stream) {
+      if (json_body->isMember("stream_options") &&
+          !(*json_body)["stream_options"].isNull()) {
+        return (*json_body)["stream_options"]
+            .get("include_usage", false)
+            .asBool();
+      }
+      return false;
+    }
+    return false;
+  }();
+
+  auto n = [&json_body, is_stream]() -> int {
+    if (is_stream)
+      return 1;
+    return (*json_body).get("n", 1).asInt();
+  }();
+
+  auto& s = llama_server_map_.at(model);
+
+  // Format logit_bias
+  if (json_body->isMember("logit_bias")) {
+    auto logit_bias =
+        llama::inferences::ChatCompletionRequest::ConvertLogitBiasToArray(
+            (*json_body)["logit_bias"]);
+    (*json_body)["logit_bias"] = logit_bias;
+  }
+
+  httplib::Client cli(s.host + ":" + std::to_string(s.port));
+  auto data = ConvertJsonCppToNlohmann(*json_body);
+
+  // llama.cpp server only supports n = 1
+  data["n"] = 1;
+  auto data_str = data.dump();
+  LOG_INFO << "data_str: " << data_str;
+  cli.set_read_timeout(std::chrono::seconds(60));
+  if (is_stream) {
+    // std::cout << "> ";
+    httplib::Request req;
+    req.headers = httplib::Headers();
+    req.set_header("Content-Type", "application/json");
+    req.method = "POST";
+    req.path = "/v1/chat/completions";
+    req.body = data_str;
+    req.content_receiver = [cb, include_usage, n, is_stream](
+                               const char* data, size_t data_length,
+                               uint64_t offset, uint64_t total_length) {
+      std::string s(data, data_length);
+      if (s.find("[DONE]") != std::string::npos) {
+        LOG_DEBUG << "[DONE]";
+        cb(ResStatus(IsDone{true}, HasError{false}, IsStream{true},
+                     StatusCode{k200OK})
+               .ToJson(),
+           ResStreamData(s).ToJson());
+        return false;
+      }
+
+      // For openai api compatibility
+      if (!include_usage && s.find("completion_tokens") != std::string::npos) {
+        return true;
+      }
+
+      cb(ResStatus(IsDone{false}, HasError{false}, IsStream{true},
+                   StatusCode{k200OK})
+             .ToJson(),
+         ResStreamData(s).ToJson());
+      LOG_DEBUG << s;
+      return true;
+    };
+    cli.send(req);
+    LOG_DEBUG << "Sent";
+  } else {
+    Json::Value result;
+    // multiple choices
+    for (int i = 0; i < n; i++) {
+      auto res = cli.Post("/v1/chat/completions", httplib::Headers(),
+                          data_str.data(), data_str.size(),
+                          "application/json");
+      if (res) {
+        LOG_INFO << res->body;
+        auto r = ParseJsonString(res->body);
+        if (i == 0) {
+          result = r;
+        } else {
+          r["choices"][0]["index"] = i;
+          result["choices"].append(r["choices"][0]);
+          result["usage"]["completion_tokens"] =
+              result["usage"]["completion_tokens"].asInt() +
+              r["usage"]["completion_tokens"].asInt();
+          result["usage"]["prompt_tokens"] =
+              result["usage"]["prompt_tokens"].asInt() +
+              r["usage"]["prompt_tokens"].asInt();
+          result["usage"]["total_tokens"] =
+              result["usage"]["total_tokens"].asInt() +
+              r["usage"]["total_tokens"].asInt();
+        }
+
+        if (i == n - 1) {
+          cb(ResStatus(IsDone{true}, HasError{false}, IsStream{false},
+                       StatusCode{k200OK})
+                 .ToJson(),
+             std::move(result));
+        }
+
+      } else {
+        std::cout << "Error" << std::endl;
+        cb(ResStatus(IsDone{true}, HasError{true}, IsStream{false},
+                     StatusCode{k500InternalServerError})
+               .ToJson(),
+           Json::Value());
+        break;
+      }
+    }
+  }
+}
+
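// Usage reporting sketch (illustrative): on the streaming path the final usage
// chunk is only forwarded when the client opts in, mirroring the OpenAI API;
// without the flag, chunks containing "completion_tokens" are dropped before they
// reach the callback.
//
//   {"messages": [...], "stream": true,
//    "stream_options": {"include_usage": true}}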
+void LlamaEngine::HandleNonOpenAiChatCompletion(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& cb,
+    const std::string& model) {
+  LOG_INFO << "Handle non OpenAI";
+  auto is_stream = (*json_body).get("stream", false).asBool();
+  auto include_usage = [&json_body, is_stream]() -> bool {
+    if (is_stream) {
+      if (json_body->isMember("stream_options") &&
+          !(*json_body)["stream_options"].isNull()) {
+        return (*json_body)["stream_options"]
+            .get("include_usage", false)
+            .asBool();
+      }
+      return false;
+    }
+    return false;
+  }();
+
+  auto n = [&json_body, is_stream]() -> int {
+    if (is_stream)
+      return 1;
+    return (*json_body).get("n", 1).asInt();
+  }();
+
+  auto& s = llama_server_map_.at(model);
+
+  // Format logit_bias
+  if (json_body->isMember("logit_bias")) {
+    auto logit_bias =
+        llama::inferences::ChatCompletionRequest::ConvertLogitBiasToArray(
+            (*json_body)["logit_bias"]);
+    (*json_body)["logit_bias"] = logit_bias;
+  }
+
+  httplib::Client cli(s.host + ":" + std::to_string(s.port));
+  auto get_message = [](const Json::Value& msg_content) -> std::string {
+    if (msg_content.isArray()) {
+      for (const auto& mc : msg_content) {
+        if (mc["type"].asString() == "text") {
+          return mc["text"].asString();
+        }
+      }
+    } else {
+      return msg_content.asString();
+    }
+    return "";
+  };
+
+  // If a prompt is provided, use it as-is; otherwise build one from the messages
+  if (!json_body->isMember("prompt") ||
+      (*json_body)["prompt"].asString().empty()) {
+    std::string formatted_output;
+    for (const auto& message : (*json_body)["messages"]) {
+      std::string input_role = message["role"].asString();
+      std::string role;
+      if (input_role == "user") {
+        role = s.user_prompt;
+      } else if (input_role == "assistant") {
+        role = s.ai_prompt;
+      } else if (input_role == "system") {
+        role = s.system_prompt;
+      } else {
+        role = input_role;
+      }
+
+      if (auto content = get_message(message["content"]); !content.empty()) {
+        formatted_output += role + content;
+      }
+    }
+    formatted_output += s.ai_prompt;
+    (*json_body)["prompt"] = formatted_output;
+  }
+
+  auto data = ConvertJsonCppToNlohmann(*json_body);
+
+  // llama.cpp server only supports n = 1
+  data["n"] = 1;
+  auto data_str = data.dump();
+  LOG_INFO << "data_str: " << data_str;
+  cli.set_read_timeout(std::chrono::seconds(60));
+  int n_probs = json_body->get("n_probs", 0).asInt();
+  if (is_stream) {
+    // std::cout << "> ";
+    httplib::Request req;
+    req.headers = httplib::Headers();
+    req.set_header("Content-Type", "application/json");
+    req.method = "POST";
+    req.path = "/v1/completions";
+    req.body = data_str;
+    req.content_receiver = [cb, include_usage, n, is_stream, n_probs, model](
+                               const char* data, size_t data_length,
+                               uint64_t offset, uint64_t total_length) {
+      std::string s(data, data_length);
+      LOG_INFO << s;
+      if (s.size() > 6) {
+        s = s.substr(6);
+      }
+      auto json_data = ParseJsonString(s);
+
+      // DONE
+      if (json_data.isMember("timings")) {
+        std::optional<Usage> u;
+        if (include_usage) {
+          u = Usage{json_data["tokens_evaluated"].asInt(),
+                    json_data["tokens_predicted"].asInt()};
+        }
+        const std::string str =
+            "data: " +
+            CreateReturnJson(llama_utils::generate_random_string(20), model,
+                             "", "stop", include_usage, u) +
+            "\n\n" + "data: [DONE]" + "\n\n";
+
+        cb(ResStatus(IsDone{true}, HasError{false}, IsStream{is_stream},
+                     StatusCode{k200OK})
+               .ToJson(),
+           ResStreamData(str).ToJson());
+        return false;
+      }
+
+      json logprobs;
+      if (n_probs > 0) {
+        logprobs =
+            ConvertJsonCppToNlohmann(json_data["completion_probabilities"]);
+      }
+      std::string to_send = json_data.get("content", "").asString();
+      const std::string str =
+          "data: " +
+          CreateReturnJson(llama_utils::generate_random_string(20), model,
+                           to_send, "", include_usage, std::nullopt,
+                           logprobs) +
+          "\n\n";
+      cb(ResStatus(IsDone{false}, HasError{false}, IsStream{true},
+                   StatusCode{k200OK})
+             .ToJson(),
+         ResStreamData(str).ToJson());
+
+      return true;
+    };
+    cli.send(req);
+    LOG_DEBUG << "Sent";
+  } else {
+    Json::Value result;
+    int prompt_tokens = 0;
+    int predicted_tokens = 0;
+    // multiple choices
+    for (int i = 0; i < n; i++) {
+      auto res = cli.Post("/v1/completions", httplib::Headers(),
+                          data_str.data(), data_str.size(),
+                          "application/json");
+      if (res) {
+        LOG_INFO << res->body;
+        auto r = ParseJsonString(res->body);
+        json logprobs;
+        prompt_tokens += r["tokens_evaluated"].asInt();
+        predicted_tokens += r["tokens_predicted"].asInt();
+        std::string to_send = r["content"].asString();
+        llama_utils::ltrim(to_send);
+        if (n_probs > 0) {
+          logprobs = ConvertJsonCppToNlohmann(r["completion_probabilities"]);
+        }
+
+        if (i == 0) {
+          result = CreateFullReturnJson(
+              llama_utils::generate_random_string(20), model, to_send, "_",
+              prompt_tokens, predicted_tokens, Json::Value("stop"), logprobs);
+        } else {
+          auto choice = CreateFullReturnJson(
+              llama_utils::generate_random_string(20), model, to_send, "_",
+              prompt_tokens, predicted_tokens, Json::Value("stop"),
+              logprobs)["choices"][0];
+          choice["index"] = i;
+          result["choices"].append(choice);
+          result["usage"]["completion_tokens"] = predicted_tokens;
+          result["usage"]["prompt_tokens"] = prompt_tokens;
+          result["usage"]["total_tokens"] = predicted_tokens + prompt_tokens;
+        }
+
+        if (i == n - 1) {
+          cb(ResStatus(IsDone{true}, HasError{false}, IsStream{false},
+                       StatusCode{k200OK})
+                 .ToJson(),
+             std::move(result));
+        }
+
+      } else {
+        LOG_ERROR << "Error";
+        cb(ResStatus(IsDone{true}, HasError{true}, IsStream{false},
+                     StatusCode{k500InternalServerError})
+               .ToJson(),
+           Json::Value());
+        break;
+      }
+    }
+  }
+}
+
 bool LlamaEngine::HandleLlamaCppEmbedding(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback,
@@ -1689,7 +1963,6 @@ bool LlamaEngine::HandleLlamaCppEmbedding(
       [this, cb = std::move(callback), json_body, model] {
         auto& s = llama_server_map_.at(model);
         httplib::Client cli(s.host + ":" + std::to_string(s.port));
-        httplib::Params params;
 
         auto data = ConvertJsonCppToNlohmann(*json_body);
         auto data_str = data.dump();
@@ -1699,21 +1972,16 @@ bool LlamaEngine::HandleLlamaCppEmbedding(
                             data_str.size(), "application/json");
         if (res) {
           // std::cout << res->body << std::endl;
-          Json::Value root = ParseJsonString(res->body);
-          Json::Value status;
-          status["is_done"] = true;
-          status["has_error"] = false;
-          status["is_stream"] = false;
-          status["status_code"] = k200OK;
-          cb(std::move(status), std::move(root));
+          cb(ResStatus(IsDone{true}, HasError{false}, IsStream{false},
+                       StatusCode{k200OK})
+                 .ToJson(),
+             ParseJsonString(res->body));
         } else {
           std::cout << "Error" << std::endl;
-          Json::Value status;
-          status["is_done"] = true;
-          status["has_error"] = true;
-          status["is_stream"] = false;
-          status["status_code"] = k500InternalServerError;
-          cb(std::move(status), Json::Value());
+          cb(ResStatus(IsDone{true}, HasError{true}, IsStream{false},
+                       StatusCode{k500InternalServerError})
+                 .ToJson(),
+             Json::Value());
         }
       });
   LOG_INFO << "Done HandleEmbedding";
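// Prompt templating sketch for HandleNonOpenAiChatCompletion (illustrative; the
// role strings below are placeholders, the real values come from the model's
// ServerConfig): messages are flattened into a single prompt and the AI prompt is
// appended so the server completes the assistant turn.
//
//   user_prompt = "USER: ", ai_prompt = "ASSISTANT: ", system_prompt = "SYSTEM: "
//   messages    = [{"role": "system", "content": "Be brief."},
//                  {"role": "user",   "content": "Hi"}]
//   prompt      = "SYSTEM: Be brief.USER: HiASSISTANT: "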
diff --git a/src/llama_engine.h b/src/llama_engine.h
index fcc3378..36b54a0 100644
--- a/src/llama_engine.h
+++ b/src/llama_engine.h
@@ -10,6 +10,8 @@
 #include "trantor/utils/ConcurrentTaskQueue.h"
 #include "trantor/utils/Logger.h"
 
+using http_callback = std::function<void(Json::Value&&, Json::Value&&)>;
+
 class LlamaEngine : public EngineI {
  public:
   constexpr static auto kEngineName = "cortex.llamacpp";
@@ -73,6 +75,13 @@
       std::function<void(Json::Value&&, Json::Value&&)>&& callback,
       const std::string& model);
 
+  void HandleOpenAiChatCompletion(std::shared_ptr<Json::Value> json_body,
+                                  http_callback&& callback,
+                                  const std::string& model);
+  void HandleNonOpenAiChatCompletion(std::shared_ptr<Json::Value> json_body,
+                                     http_callback&& callback,
+                                     const std::string& model);
+
   bool HandleLlamaCppEmbedding(
       std::shared_ptr<Json::Value> json_body,
       std::function<void(Json::Value&&, Json::Value&&)>&& callback,
@@ -81,6 +90,60 @@
   bool IsLlamaServerModel(const std::string& model) const;
 
  private:
+  struct IsDone {
+    bool is_done;
+    bool operator()() { return is_done; }
+  };
+  struct HasError {
+    bool has_error;
+    bool operator()() { return has_error; }
+  };
+  struct IsStream {
+    bool is_stream;
+    bool operator()() { return is_stream; }
+  };
+  struct StatusCode {
+    int status_code;
+    int operator()() { return status_code; }
+  };
+  struct ResStatus {
+   private:
+    IsDone is_done;
+    HasError has_error;
+    IsStream is_stream;
+    StatusCode status_code;
+
+   public:
+    ResStatus(IsDone is_done, HasError has_error, IsStream is_stream,
+              StatusCode status_code)
+        : is_done(is_done),
+          has_error(has_error),
+          is_stream(is_stream),
+          status_code(status_code) {}
+
+    Json::Value ToJson() {
+      Json::Value status;
+      status["is_done"] = is_done();
+      status["has_error"] = has_error();
+      status["is_stream"] = is_stream();
+      status["status_code"] = status_code();
+      return status;
+    }
+  };
+
+  struct ResStreamData {
+   private:
+    std::string s;
+
+   public:
+    ResStreamData(std::string s) : s(std::move(s)) {}
+    Json::Value ToJson() {
+      Json::Value d;
+      d["data"] = s;
+      return d;
+    }
+  };
+
   struct ServerInfo {
     LlamaServerContext ctx;
     std::unique_ptr<trantor::ConcurrentTaskQueue> q;
@@ -99,6 +162,10 @@
 
   struct ServerConfig {
     std::unique_ptr<trantor::ConcurrentTaskQueue> q;
+    std::string user_prompt;
+    std::string ai_prompt;
+    std::string system_prompt;
+    std::string pre_prompt;
    std::string host;
    int port;
 #if defined(_WIN32) || defined(_WIN64)
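// Usage sketch (illustrative, not part of the patch): the single-field wrapper
// structs exist so that call sites read like named parameters instead of a row of
// bare booleans, and ResStreamData plays the same role for the streamed payload.
//
//   cb(ResStatus(IsDone{true}, HasError{false}, IsStream{false},
//                StatusCode{k200OK})
//          .ToJson(),
//      ParseJsonString(res->body));
//
//   cb(ResStatus(IsDone{false}, HasError{false}, IsStream{true},
//                StatusCode{k200OK})
//          .ToJson(),
//      ResStreamData(chunk).ToJson());  // chunk: a hypothetical SSE "data:" string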