diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt
index f13bb51bc..11e2c384b 100644
--- a/engine/cli/CMakeLists.txt
+++ b/engine/cli/CMakeLists.txt
@@ -79,6 +79,7 @@ add_executable(${TARGET_NAME} main.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
+  ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc
 )
 
 target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)
diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc
index d806a4f13..9041e7e07 100644
--- a/engine/cli/commands/model_start_cmd.cc
+++ b/engine/cli/commands/model_start_cmd.cc
@@ -11,7 +11,8 @@ namespace commands {
 bool ModelStartCmd::Exec(const std::string& host, int port,
-                         const std::string& model_handle) {
+                         const std::string& model_handle,
+                         bool print_success_log) {
   std::optional model_id =
       SelectLocalModel(model_service_, model_handle);
@@ -37,9 +38,12 @@ bool ModelStartCmd::Exec(const std::string& host, int port,
                        data_str.size(), "application/json");
   if (res) {
     if (res->status == httplib::StatusCode::OK_200) {
-      CLI_LOG(model_id.value() << " model started successfully. Use `"
-              << commands::GetCortexBinary() << " run "
-              << *model_id << "` for interactive chat shell");
+      if (print_success_log) {
+        CLI_LOG(model_id.value()
+                << " model started successfully. Use `"
+                << commands::GetCortexBinary() << " run " << *model_id
+                << "` for interactive chat shell");
+      }
       return true;
     } else {
       auto root = json_helper::ParseJsonString(res->body);
diff --git a/engine/cli/commands/model_start_cmd.h b/engine/cli/commands/model_start_cmd.h
index 40c485a9f..ffd63d611 100644
--- a/engine/cli/commands/model_start_cmd.h
+++ b/engine/cli/commands/model_start_cmd.h
@@ -9,7 +9,8 @@ class ModelStartCmd {
   explicit ModelStartCmd(const ModelService& model_service)
       : model_service_{model_service} {};
 
-  bool Exec(const std::string& host, int port, const std::string& model_handle);
+  bool Exec(const std::string& host, int port, const std::string& model_handle,
+            bool print_success_log = true);
 
  private:
   ModelService model_service_;
diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc
index 73aa5c362..c80f12de1 100644
--- a/engine/cli/commands/run_cmd.cc
+++ b/engine/cli/commands/run_cmd.cc
@@ -3,6 +3,7 @@
 #include "config/yaml_config.h"
 #include "cortex_upd_cmd.h"
 #include "database/models.h"
+#include "model_start_cmd.h"
 #include "model_status_cmd.h"
 #include "server_start_cmd.h"
 #include "utils/cli_selection_utils.h"
@@ -82,7 +83,7 @@ void RunCmd::Exec(bool run_detach) {
   if (!model_id.has_value()) {
     return;
   }
-  
+
   cortex::db::Models modellist_handler;
   config::YamlHandler yaml_handler;
   auto address = host_ + ":" + std::to_string(port_);
@@ -139,12 +140,10 @@
         !commands::ModelStatusCmd(model_service_)
              .IsLoaded(host_, port_, *model_id)) {
-      auto result = model_service_.StartModel(host_, port_, *model_id);
-      if (result.has_error()) {
-        CLI_LOG("Error: " + result.error());
-        return;
-      }
-      if (!result.value()) {
+      auto res =
+          commands::ModelStartCmd(model_service_)
+              .Exec(host_, port_, *model_id, false /*print_success_log*/);
+      if (!res) {
        CLI_LOG("Error: Failed to start model");
        return;
      }
    }
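Note: with the new print_success_log flag, `cortex run` no longer calls ModelService::StartModel directly; it reuses the same ModelStartCmd that backs `cortex models start`, only with the success banner suppressed. A minimal sketch of the two call sites, assuming the signatures shown in the diff above (host/port/model handle come from the surrounding CLI context and are not defined here):

    // Illustrative sketch only, not part of the patch.
    commands::ModelStartCmd starter(model_service_);

    // `cortex models start <model>`: keep the default success log.
    starter.Exec(host_, port_, model_handle);

    // `cortex run <model>`: start quietly; RunCmd prints its own messages.
    starter.Exec(host_, port_, model_handle, false /*print_success_log*/);
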
diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc
index 87094528b..cfc6be6e3 100644
--- a/engine/controllers/server.cc
+++ b/engine/controllers/server.cc
@@ -10,7 +10,8 @@ using namespace inferences;
 using json = nlohmann::json;
 
 namespace inferences {
-server::server() {
+server::server(std::shared_ptr<services::InferenceService> inference_service)
+    : inference_svc_(inference_service) {
 #if defined(_WIN32)
   SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
 #endif
@@ -25,7 +26,7 @@ void server::ChatCompletion(
   auto json_body = req->getJsonObject();
   bool is_stream = (*json_body).get("stream", false).asBool();
   auto q = std::make_shared<services::SyncQueue>();
-  auto ir = inference_svc_.HandleChatCompletion(q, json_body);
+  auto ir = inference_svc_->HandleChatCompletion(q, json_body);
   if (ir.has_error()) {
     auto err = ir.error();
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err));
@@ -47,7 +48,7 @@ void server::Embedding(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
   LOG_TRACE << "Start embedding";
   auto q = std::make_shared<services::SyncQueue>();
-  auto ir = inference_svc_.HandleEmbedding(q, req->getJsonObject());
+  auto ir = inference_svc_->HandleEmbedding(q, req->getJsonObject());
   if (ir.has_error()) {
     auto err = ir.error();
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err));
@@ -64,7 +65,7 @@ void server::Embedding(const HttpRequestPtr& req,
 void server::UnloadModel(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.UnloadModel(req->getJsonObject());
+  auto ir = inference_svc_->UnloadModel(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<drogon::HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -74,7 +75,7 @@ void server::UnloadModel(
 void server::ModelStatus(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.GetModelStatus(req->getJsonObject());
+  auto ir = inference_svc_->GetModelStatus(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<drogon::HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -84,7 +85,7 @@ void server::ModelStatus(
 void server::GetModels(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
   LOG_TRACE << "Start to get models";
-  auto ir = inference_svc_.GetModels(req->getJsonObject());
+  auto ir = inference_svc_->GetModels(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<drogon::HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -95,7 +96,7 @@ void server::GetModels(const HttpRequestPtr& req,
 void server::GetEngines(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.GetEngines(req->getJsonObject());
+  auto ir = inference_svc_->GetEngines(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(ir);
   callback(resp);
 }
@@ -103,7 +104,7 @@ void server::GetEngines(
 void server::FineTuning(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.FineTuning(req->getJsonObject());
+  auto ir = inference_svc_->FineTuning(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<drogon::HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -113,7 +114,7 @@ void server::FineTuning(
 void server::LoadModel(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.LoadModel(req->getJsonObject());
+  auto ir = inference_svc_->LoadModel(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<drogon::HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -124,7 +125,7 @@ void server::LoadModel(const HttpRequestPtr& req,
 void server::UnloadEngine(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  auto ir = inference_svc_.UnloadEngine(req->getJsonObject());
+  auto ir = inference_svc_->UnloadEngine(req->getJsonObject());
   auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
   resp->setStatusCode(
       static_cast<drogon::HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
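Note: the controller now receives its InferenceService through the constructor and holds it behind a std::shared_ptr instead of owning an instance by value, so the HTTP handlers and ModelService can operate on the same object. A minimal sketch of this constructor-injection pattern in drogon, using hypothetical names (CounterService/Counter stand in for InferenceService/server; this is not the project's code):

    #include <drogon/HttpController.h>
    #include <functional>
    #include <memory>
    #include <string>

    // Hypothetical service shared between a controller and other components.
    struct CounterService {
      int hits = 0;
    };

    // `false` disables drogon's auto-creation, so the controller can be built
    // with constructor arguments and registered by hand.
    class Counter : public drogon::HttpController<Counter, false> {
     public:
      explicit Counter(std::shared_ptr<CounterService> svc) : svc_(std::move(svc)) {}

      METHOD_LIST_BEGIN
      ADD_METHOD_TO(Counter::Hit, "/hit", drogon::Get);
      METHOD_LIST_END

      void Hit(const drogon::HttpRequestPtr&,
               std::function<void(const drogon::HttpResponsePtr&)>&& callback) {
        auto resp = drogon::HttpResponse::newHttpResponse();
        resp->setBody(std::to_string(++svc_->hits));  // same instance other code sees
        callback(resp);
      }

     private:
      std::shared_ptr<CounterService> svc_;
    };

    // Registered explicitly from the composition root, e.g.:
    //   drogon::app().registerController(
    //       std::make_shared<Counter>(std::make_shared<CounterService>()));

Because the controller no longer has a default constructor, drogon's automatic creation is turned off and main.cc (next file) constructs and registers the instance itself.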
diff --git a/engine/controllers/server.h b/engine/controllers/server.h
index 5959c7a8c..15844b403 100644
--- a/engine/controllers/server.h
+++ b/engine/controllers/server.h
@@ -1,10 +1,12 @@
+
+#pragma once
 #include 
 #include 
+#include 
 #if defined(_WIN32)
 #define NOMINMAX
 #endif
-#pragma once
 #include 
@@ -31,12 +33,12 @@ using namespace drogon;
 
 namespace inferences {
-class server : public drogon::HttpController<server>,
+class server : public drogon::HttpController<server, false>,
               public BaseModel,
               public BaseChatCompletion,
               public BaseEmbedding {
  public:
-  server();
+  server(std::shared_ptr<services::InferenceService> inference_service);
   ~server();
 
   METHOD_LIST_BEGIN
   // list path definitions here;
@@ -100,6 +102,6 @@ class server : public drogon::HttpController,
                services::SyncQueue& q);
 
  private:
-  services::InferenceService inference_svc_;
+  std::shared_ptr<services::InferenceService> inference_svc_;
 };
 };  // namespace inferences
diff --git a/engine/main.cc b/engine/main.cc
index 8c2f62a03..1e97384c8 100644
--- a/engine/main.cc
+++ b/engine/main.cc
@@ -5,6 +5,7 @@
 #include "controllers/events.h"
 #include "controllers/models.h"
 #include "controllers/process_manager.h"
+#include "controllers/server.h"
 #include "cortex-common/cortexpythoni.h"
 #include "services/model_service.h"
 #include "utils/archive_utils.h"
@@ -88,21 +89,25 @@ void RunServer(std::optional port) {
   auto event_queue_ptr = std::make_shared();
   cortex::event::EventProcessor event_processor(event_queue_ptr);
 
+  auto inference_svc = std::make_shared<services::InferenceService>();
   auto download_service = std::make_shared(event_queue_ptr);
   auto engine_service = std::make_shared(download_service);
-  auto model_service = std::make_shared<ModelService>(download_service);
+  auto model_service =
+      std::make_shared<ModelService>(download_service, inference_svc);
 
   // initialize custom controllers
   auto engine_ctl = std::make_shared(engine_service);
   auto model_ctl = std::make_shared(model_service, engine_service);
   auto event_ctl = std::make_shared(event_queue_ptr);
   auto pm_ctl = std::make_shared();
+  auto server_ctl = std::make_shared<inferences::server>(inference_svc);
 
   drogon::app().registerController(engine_ctl);
   drogon::app().registerController(model_ctl);
   drogon::app().registerController(event_ctl);
   drogon::app().registerController(pm_ctl);
+  drogon::app().registerController(server_ctl);
 
   LOG_INFO << "Server started, listening at: " << config.apiServerHost << ":"
            << config.apiServerPort;
diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h
index 05147c84f..26cee5157 100644
--- a/engine/services/inference_service.h
+++ b/engine/services/inference_service.h
@@ -1,8 +1,11 @@
 #pragma once
 
+#include 
+#include 
 #include 
+#include 
+#include 
 #include 
-#include "common/base.h"
 #include "cortex-common/EngineI.h"
 #include "cortex-common/cortexpythoni.h"
 #include "utils/dylib.h"
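Note: in main.cc the InferenceService is created once and handed both to ModelService and to the new server controller, so the REST handlers and the model-management code share a single set of loaded engines. Since the controller is registered manually, registration has to happen before drogon::app().run(). A trimmed sketch of that wiring, reusing the names from the diff and omitting the unrelated controllers and services (download_service is assumed to be set up as earlier in RunServer):

    // Composition-root sketch; error handling omitted.
    auto inference_svc = std::make_shared<services::InferenceService>();
    auto model_service =
        std::make_shared<ModelService>(download_service, inference_svc);
    auto server_ctl = std::make_shared<inferences::server>(inference_svc);

    drogon::app().registerController(server_ctl);  // must precede app().run()
    drogon::app().run();
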
diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc
index fb6875119..ae3316c12 100644
--- a/engine/services/model_service.cc
+++ b/engine/services/model_service.cc
@@ -11,6 +11,7 @@
 #include "utils/engine_constants.h"
 #include "utils/file_manager_utils.h"
 #include "utils/huggingface_utils.h"
+#include "utils/json_helper.h"
 #include "utils/logging_utils.h"
 #include "utils/result.hpp"
 #include "utils/string_utils.h"
@@ -611,28 +612,21 @@ cpp::result ModelService::StartModel(
       json_data["ai_prompt"] = mc.ai_template;
     }
 
-    auto data_str = json_data.toStyledString();
-    CTL_INF(data_str);
-    cli.set_read_timeout(std::chrono::seconds(60));
-    auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
-                        data_str.data(), data_str.size(), "application/json");
-    if (res) {
-      if (res->status == httplib::StatusCode::OK_200) {
-        return true;
-      } else if (res->status == httplib::StatusCode::Conflict_409) {
-        CTL_INF("Model '" + model_handle + "' is already loaded");
-        return true;
-      } else {
-        auto root = json_helper::ParseJsonString(res->body);
-        CTL_ERR("Model failed to load with status code: " << res->status);
-        return cpp::fail("Model failed to start: " + root["message"].asString());
-      }
+    CTL_INF(json_data.toStyledString());
+    assert(!!inference_svc_);
+    auto ir =
+        inference_svc_->LoadModel(std::make_shared<Json::Value>(json_data));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    auto data = std::get<1>(ir);
+    if (status == httplib::StatusCode::OK_200) {
+      return true;
+    } else if (status == httplib::StatusCode::Conflict_409) {
+      CTL_INF("Model '" + model_handle + "' is already loaded");
+      return true;
     } else {
-      auto err = res.error();
-      CTL_ERR("HTTP error: " << httplib::to_string(err));
-      return cpp::fail("HTTP error: " + httplib::to_string(err));
+      CTL_ERR("Model failed to start with status code: " << status);
+      return cpp::fail("Model failed to start: " + data["message"].asString());
     }
-
   } catch (const std::exception& e) {
     return cpp::fail("Fail to load model with ID '" + model_handle +
                      "': " + e.what());
@@ -663,25 +657,18 @@ cpp::result ModelService::StopModel(
     Json::Value json_data;
     json_data["model"] = model_handle;
     json_data["engine"] = mc.engine;
-    auto data_str = json_data.toStyledString();
-    CTL_INF(data_str);
-    cli.set_read_timeout(std::chrono::seconds(60));
-    auto res = cli.Post("/inferences/server/unloadmodel", httplib::Headers(),
-                        data_str.data(), data_str.size(), "application/json");
-    if (res) {
-      if (res->status == httplib::StatusCode::OK_200) {
-        return true;
-      } else {
-        CTL_ERR("Model failed to unload with status code: " << res->status);
-        return cpp::fail("Model failed to unload with status code: " +
-                         std::to_string(res->status));
-      }
+    CTL_INF(json_data.toStyledString());
+    assert(inference_svc_);
+    auto ir =
+        inference_svc_->UnloadModel(std::make_shared<Json::Value>(json_data));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    auto data = std::get<1>(ir);
+    if (status == httplib::StatusCode::OK_200) {
+      return true;
     } else {
-      auto err = res.error();
-      CTL_ERR("HTTP error: " << httplib::to_string(err));
-      return cpp::fail("HTTP error: " + httplib::to_string(err));
+      CTL_ERR("Model failed to stop with status code: " << status);
+      return cpp::fail("Model failed to stop: " + data["message"].asString());
     }
-
   } catch (const std::exception& e) {
     return cpp::fail("Fail to unload model with ID '" + model_handle +
                      "': " + e.what());
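Note: StartModel and StopModel no longer make an HTTP round-trip to /inferences/server/loadmodel; they call the shared InferenceService in-process and translate its (status, body) result into the existing cpp::result interface. A sketch of that translation, under the assumption that the service returns a pair-like value indexed with std::get<0> (status JSON) and std::get<1> (payload JSON), as the diff does; InterpretLoadResult and the InferRes alias are hypothetical and only mirror the handling in StartModel above:

    // Assumed shape of the result; the real type lives in inference_service.h.
    using InferRes = std::pair<Json::Value, Json::Value>;  // {status, data}

    cpp::result<bool, std::string> InterpretLoadResult(
        const InferRes& ir, const std::string& model_handle) {
      auto status = std::get<0>(ir)["status_code"].asInt();
      const auto& data = std::get<1>(ir);
      if (status == httplib::StatusCode::OK_200) {
        return true;
      }
      if (status == httplib::StatusCode::Conflict_409) {
        CTL_INF("Model '" + model_handle + "' is already loaded");  // treated as success
        return true;
      }
      return cpp::fail("Model failed to start: " + data["message"].asString());
    }
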
diff --git a/engine/services/model_service.h b/engine/services/model_service.h
index 822b376ae..5adc5a01e 100644
--- a/engine/services/model_service.h
+++ b/engine/services/model_service.h
@@ -5,6 +5,8 @@
 #include 
 #include "config/model_config.h"
 #include "services/download_service.h"
+#include "services/inference_service.h"
+
 class ModelService {
  public:
   constexpr auto static kHuggingFaceHost = "huggingface.co";
@@ -12,6 +14,12 @@ class ModelService {
   explicit ModelService(std::shared_ptr download_service)
       : download_service_{download_service} {};
 
+  explicit ModelService(
+      std::shared_ptr<DownloadService> download_service,
+      std::shared_ptr<services::InferenceService> inference_service)
+      : download_service_{download_service},
+        inference_svc_(inference_service) {};
+
   /**
    * Return model id if download successfully
    */
@@ -67,4 +75,5 @@ class ModelService {
                            const std::string& modelName);
 
   std::shared_ptr download_service_;
+  std::shared_ptr<services::InferenceService> inference_svc_;
 };
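Note: the original single-argument constructor is kept, so callers without an in-process inference service still compile; only the server path passes the second argument, and when it is never supplied inference_svc_ stays null, which is what the assert(inference_svc_) checks in StartModel/StopModel guard against. Hypothetical call sites, assuming the two constructors above and variables created elsewhere:

    // CLI-style construction: no in-process inference, model start/stop goes
    // through the HTTP API via ModelStartCmd.
    ModelService cli_service(download_service);

    // Server-side construction: shares the same InferenceService instance that
    // the drogon controller uses.
    ModelService server_service(download_service, inference_svc);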