chore: refactor inference service #1536

Merged · 5 commits · Oct 23, 2024
1 change: 1 addition & 0 deletions engine/cli/CMakeLists.txt
@@ -79,6 +79,7 @@ add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc
)

target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)
12 changes: 8 additions & 4 deletions engine/cli/commands/model_start_cmd.cc
@@ -11,7 +11,8 @@

namespace commands {
bool ModelStartCmd::Exec(const std::string& host, int port,
const std::string& model_handle) {
const std::string& model_handle,
bool print_success_log) {
std::optional<std::string> model_id =
SelectLocalModel(model_service_, model_handle);

@@ -37,9 +38,12 @@ bool ModelStartCmd::Exec(const std::string& host, int port,
data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
CLI_LOG(model_id.value() << " model started successfully. Use `"
<< commands::GetCortexBinary() << " run "
<< *model_id << "` for interactive chat shell");
if (print_success_log) {
CLI_LOG(model_id.value()
<< " model started successfully. Use `"
<< commands::GetCortexBinary() << " run " << *model_id
<< "` for interactive chat shell");
}
return true;
} else {
auto root = json_helper::ParseJsonString(res->body);
3 changes: 2 additions & 1 deletion engine/cli/commands/model_start_cmd.h
@@ -9,7 +9,8 @@ class ModelStartCmd {
explicit ModelStartCmd(const ModelService& model_service)
: model_service_{model_service} {};

bool Exec(const std::string& host, int port, const std::string& model_handle);
bool Exec(const std::string& host, int port, const std::string& model_handle,
bool print_success_log = true);

private:
ModelService model_service_;
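The new `print_success_log` parameter defaults to `true`, so existing call sites keep the success banner unchanged; callers that print their own output (such as `RunCmd` in the next file) can pass `false`. A minimal sketch of both call styles, assuming an existing `ModelService` instance and hypothetical host, port, and model handle:

```cpp
// Sketch only: host, port, and model handle below are placeholders.
commands::ModelStartCmd start_cmd(model_service);

// Default behaviour: prints the "model started successfully" hint.
bool ok = start_cmd.Exec("127.0.0.1", 39281, "my-model");

// Quiet variant for callers that report status themselves.
bool ok_quiet = start_cmd.Exec("127.0.0.1", 39281, "my-model",
                               /*print_success_log=*/false);
```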
13 changes: 6 additions & 7 deletions engine/cli/commands/run_cmd.cc
@@ -3,6 +3,7 @@
#include "config/yaml_config.h"
#include "cortex_upd_cmd.h"
#include "database/models.h"
#include "model_start_cmd.h"
#include "model_status_cmd.h"
#include "server_start_cmd.h"
#include "utils/cli_selection_utils.h"
@@ -82,7 +83,7 @@ void RunCmd::Exec(bool run_detach) {
if (!model_id.has_value()) {
return;
}

cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
auto address = host_ + ":" + std::to_string(port_);
@@ -139,12 +140,10 @@ void RunCmd::Exec(bool run_detach) {
!commands::ModelStatusCmd(model_service_)
.IsLoaded(host_, port_, *model_id)) {

auto result = model_service_.StartModel(host_, port_, *model_id);
if (result.has_error()) {
CLI_LOG("Error: " + result.error());
return;
}
if (!result.value()) {
auto res =
commands::ModelStartCmd(model_service_)
.Exec(host_, port_, *model_id, false /*print_success_log*/);
if (!res) {
CLI_LOG("Error: Failed to start model");
return;
}
21 changes: 11 additions & 10 deletions engine/controllers/server.cc
@@ -10,7 +10,8 @@ using namespace inferences;
using json = nlohmann::json;
namespace inferences {

server::server() {
server::server(std::shared_ptr<services::InferenceService> inference_service)
: inference_svc_(inference_service) {
#if defined(_WIN32)
SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
#endif
@@ -25,7 +26,7 @@ void server::ChatCompletion(
auto json_body = req->getJsonObject();
bool is_stream = (*json_body).get("stream", false).asBool();
auto q = std::make_shared<services::SyncQueue>();
auto ir = inference_svc_.HandleChatCompletion(q, json_body);
auto ir = inference_svc_->HandleChatCompletion(q, json_body);
if (ir.has_error()) {
auto err = ir.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err));
@@ -47,7 +48,7 @@ void server::Embedding(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
LOG_TRACE << "Start embedding";
auto q = std::make_shared<services::SyncQueue>();
auto ir = inference_svc_.HandleEmbedding(q, req->getJsonObject());
auto ir = inference_svc_->HandleEmbedding(q, req->getJsonObject());
if (ir.has_error()) {
auto err = ir.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err));
@@ -64,7 +65,7 @@
void server::UnloadModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto ir = inference_svc_.UnloadModel(req->getJsonObject());
auto ir = inference_svc_->UnloadModel(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
resp->setStatusCode(
static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -74,7 +75,7 @@ void server::UnloadModel(
void server::ModelStatus(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto ir = inference_svc_.GetModelStatus(req->getJsonObject());
auto ir = inference_svc_->GetModelStatus(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
resp->setStatusCode(
static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -84,7 +85,7 @@ void server::ModelStatus(
void server::GetModels(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
LOG_TRACE << "Start to get models";
auto ir = inference_svc_.GetModels(req->getJsonObject());
auto ir = inference_svc_->GetModels(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
resp->setStatusCode(
static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -95,15 +96,15 @@ void server::GetModels(const HttpRequestPtr& req,
void server::GetEngines(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto ir = inference_svc_.GetEngines(req->getJsonObject());
auto ir = inference_svc_->GetEngines(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ir);
callback(resp);
}

void server::FineTuning(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto ir = inference_svc_.FineTuning(req->getJsonObject());
auto ir = inference_svc_->FineTuning(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
resp->setStatusCode(
static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -113,7 +114,7 @@

void server::LoadModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto ir = inference_svc_.LoadModel(req->getJsonObject());
auto ir = inference_svc_->LoadModel(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
resp->setStatusCode(
static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
@@ -124,7 +125,7 @@
void server::UnloadEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto ir = inference_svc_.UnloadEngine(req->getJsonObject());
auto ir = inference_svc_->UnloadEngine(req->getJsonObject());
auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(ir));
resp->setStatusCode(
static_cast<HttpStatusCode>(std::get<0>(ir)["status_code"].asInt()));
10 changes: 6 additions & 4 deletions engine/controllers/server.h
@@ -1,10 +1,12 @@

#pragma once
#include <nlohmann/json.hpp>
#include <string>
#include <memory>

#if defined(_WIN32)
#define NOMINMAX
#endif
#pragma once

#include <drogon/HttpController.h>

@@ -31,12 +33,12 @@ using namespace drogon;

namespace inferences {

class server : public drogon::HttpController<server>,
class server : public drogon::HttpController<server, false>,
public BaseModel,
public BaseChatCompletion,
public BaseEmbedding {
public:
server();
server(std::shared_ptr<services::InferenceService> inference_service);
~server();
METHOD_LIST_BEGIN
// list path definitions here;
@@ -100,6 +102,6 @@ class server : public drogon::HttpController<server>,
services::SyncQueue& q);

private:
services::InferenceService inference_svc_;
std::shared_ptr<services::InferenceService> inference_svc_;
};
}; // namespace inferences
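Drogon's second `HttpController` template parameter (`AutoCreation`) defaults to `true`, which makes the framework default-construct the controller on its own. Switching to `HttpController<server, false>` turns that off so the application can build the `server` instance with its `InferenceService` dependency and register it explicitly, which is what `main.cc` does below. A minimal sketch of that manual wiring:

```cpp
// Manual controller construction, enabled by disabling AutoCreation.
auto inference_svc = std::make_shared<services::InferenceService>();
auto server_ctl = std::make_shared<inferences::server>(inference_svc);
drogon::app().registerController(server_ctl);
```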
7 changes: 6 additions & 1 deletion engine/main.cc
@@ -5,6 +5,7 @@
#include "controllers/events.h"
#include "controllers/models.h"
#include "controllers/process_manager.h"
#include "controllers/server.h"
#include "cortex-common/cortexpythoni.h"
#include "services/model_service.h"
#include "utils/archive_utils.h"
@@ -88,21 +89,25 @@ void RunServer(std::optional<int> port) {

auto event_queue_ptr = std::make_shared<EventQueue>();
cortex::event::EventProcessor event_processor(event_queue_ptr);
auto inference_svc = std::make_shared<services::InferenceService>();

auto download_service = std::make_shared<DownloadService>(event_queue_ptr);
auto engine_service = std::make_shared<EngineService>(download_service);
auto model_service = std::make_shared<ModelService>(download_service);
auto model_service =
std::make_shared<ModelService>(download_service, inference_svc);

// initialize custom controllers
auto engine_ctl = std::make_shared<Engines>(engine_service);
auto model_ctl = std::make_shared<Models>(model_service, engine_service);
auto event_ctl = std::make_shared<Events>(event_queue_ptr);
auto pm_ctl = std::make_shared<ProcessManager>();
auto server_ctl = std::make_shared<inferences::server>(inference_svc);

drogon::app().registerController(engine_ctl);
drogon::app().registerController(model_ctl);
drogon::app().registerController(event_ctl);
drogon::app().registerController(pm_ctl);
drogon::app().registerController(server_ctl);

LOG_INFO << "Server started, listening at: " << config.apiServerHost << ":"
<< config.apiServerPort;
5 changes: 4 additions & 1 deletion engine/services/inference_service.h
@@ -1,8 +1,11 @@
#pragma once

#include <condition_variable>
#include <mutex>
#include <optional>
#include <queue>
#include <unordered_map>
#include <variant>
#include "common/base.h"
#include "cortex-common/EngineI.h"
#include "cortex-common/cortexpythoni.h"
#include "utils/dylib.h"
61 changes: 24 additions & 37 deletions engine/services/model_service.cc
@@ -11,6 +11,7 @@
#include "utils/engine_constants.h"
#include "utils/file_manager_utils.h"
#include "utils/huggingface_utils.h"
#include "utils/json_helper.h"
#include "utils/logging_utils.h"
#include "utils/result.hpp"
#include "utils/string_utils.h"
@@ -611,28 +612,21 @@ cpp::result<bool, std::string> ModelService::StartModel(
json_data["ai_prompt"] = mc.ai_template;
}

auto data_str = json_data.toStyledString();
CTL_INF(data_str);
cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
return true;
} else if (res->status == httplib::StatusCode::Conflict_409) {
CTL_INF("Model '" + model_handle + "' is already loaded");
return true;
} else {
auto root = json_helper::ParseJsonString(res->body);
CTL_ERR("Model failed to load with status code: " << res->status);
return cpp::fail("Model failed to start: " + root["message"].asString());
}
CTL_INF(json_data.toStyledString());
assert(!!inference_svc_);
auto ir =
inference_svc_->LoadModel(std::make_shared<Json::Value>(json_data));
auto status = std::get<0>(ir)["status_code"].asInt();
auto data = std::get<1>(ir);
if (status == httplib::StatusCode::OK_200) {
return true;
} else if (status == httplib::StatusCode::Conflict_409) {
CTL_INF("Model '" + model_handle + "' is already loaded");
return true;
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return cpp::fail("HTTP error: " + httplib::to_string(err));
CTL_ERR("Model failed to start with status code: " << status);
return cpp::fail("Model failed to start: " + data["message"].asString());
}

} catch (const std::exception& e) {
return cpp::fail("Fail to load model with ID '" + model_handle +
"': " + e.what());
@@ -663,25 +657,18 @@ cpp::result<bool, std::string> ModelService::StopModel(
Json::Value json_data;
json_data["model"] = model_handle;
json_data["engine"] = mc.engine;
auto data_str = json_data.toStyledString();
CTL_INF(data_str);
cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Post("/inferences/server/unloadmodel", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
return true;
} else {
CTL_ERR("Model failed to unload with status code: " << res->status);
return cpp::fail("Model failed to unload with status code: " +
std::to_string(res->status));
}
CTL_INF(json_data.toStyledString());
assert(inference_svc_);
auto ir =
inference_svc_->UnloadModel(std::make_shared<Json::Value>(json_data));
auto status = std::get<0>(ir)["status_code"].asInt();
auto data = std::get<1>(ir);
if (status == httplib::StatusCode::OK_200) {
return true;
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return cpp::fail("HTTP error: " + httplib::to_string(err));
CTL_ERR("Model failed to stop with status code: " << status);
return cpp::fail("Model failed to stop: " + data["message"].asString());
}

} catch (const std::exception& e) {
return cpp::fail("Fail to unload model with ID '" + model_handle +
"': " + e.what());
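With this refactor, `StartModel` and `StopModel` no longer round-trip through the local `/inferences/server/loadmodel` and `/inferences/server/unloadmodel` HTTP endpoints; they call the shared `InferenceService` in-process. Reading the code above, the service appears to return a tuple whose first element is a `Json::Value` carrying `status_code` and whose second element is the JSON payload. A minimal sketch of that convention (an assumption read off this diff, with hypothetical model and engine names):

```cpp
// Assumption from this diff: std::get<0>(ir) holds "status_code",
// std::get<1>(ir) holds the response payload (e.g. an error "message").
auto inference_svc = std::make_shared<services::InferenceService>();

Json::Value json_data;
json_data["model"] = "my-model";    // hypothetical model handle
json_data["engine"] = "llama-cpp";  // hypothetical engine name

auto ir = inference_svc->LoadModel(std::make_shared<Json::Value>(json_data));
auto status = std::get<0>(ir)["status_code"].asInt();
if (status != httplib::StatusCode::OK_200 &&
    status != httplib::StatusCode::Conflict_409) {
  CTL_ERR("Model failed to start: " << std::get<1>(ir)["message"].asString());
}
```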
9 changes: 9 additions & 0 deletions engine/services/model_service.h
@@ -5,13 +5,21 @@
#include <string>
#include "config/model_config.h"
#include "services/download_service.h"
#include "services/inference_service.h"

class ModelService {
public:
constexpr auto static kHuggingFaceHost = "huggingface.co";

explicit ModelService(std::shared_ptr<DownloadService> download_service)
: download_service_{download_service} {};

explicit ModelService(
std::shared_ptr<DownloadService> download_service,
std::shared_ptr<services::InferenceService> inference_service)
: download_service_{download_service},
inference_svc_(inference_service) {};

/**
* Return model id if download successfully
*/
@@ -67,4 +75,5 @@ class ModelService {
const std::string& modelName);

std::shared_ptr<DownloadService> download_service_;
std::shared_ptr<services::InferenceService> inference_svc_;
};
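`ModelService` now has two constructors, so `inference_svc_` can be null; the `assert(inference_svc_)` calls in `StartModel`/`StopModel` suggest those in-process paths are only intended for instances built with the two-argument constructor (as `main.cc` does), while the single-argument form remains available for download and metadata work. A minimal sketch of the two wirings, assuming an existing `download_service`:

```cpp
// Download/metadata-only usage: inference_svc_ stays null, so the in-process
// StartModel/StopModel paths (which assert on it) should not be used here.
ModelService metadata_only(download_service);

// Full usage, as wired in main.cc: one InferenceService shared between
// ModelService and the inferences::server controller.
auto inference_svc = std::make_shared<services::InferenceService>();
ModelService full_service(download_service, inference_svc);
```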