Skip to content

Commit

Permalink
feat: cortex pull and cortex engines install CLI uses API server (#1550)
Browse files Browse the repository at this point in the history
* fix: add ws and indicators

* fix: more

* fix: pull models info from server

* fix: model_source

* fix: rename

* fix: download cortexso

* fix: pull models

* fix: remove comments

* fix: rename

* fix: change download UI

* fix: comment out

* fix: e2e tests

* fix: run, start

* fix: start server

* fix: e2e

* fix: remove

* fix: abort model

* fix: build

* fix: clean code

* fix: clean more

* fix: normalize engine id

* fix: use auto

* fix: use vcpkg for indicators

* fix: download progress

---------

Co-authored-by: vansangpfiev <[email protected]>
  • Loading branch information
vansangpfiev and sangjanai authored Oct 29, 2024
1 parent 04c5c40 commit 00af979
Show file tree
Hide file tree
Showing 32 changed files with 1,526 additions and 81 deletions.
5 changes: 5 additions & 0 deletions engine/cli/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ find_package(tabulate CONFIG REQUIRED)
find_package(CURL REQUIRED)
find_package(SQLiteCpp REQUIRED)
find_package(Trantor CONFIG REQUIRED)
find_package(indicators CONFIG REQUIRED)


add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc
Expand All @@ -80,6 +82,8 @@ add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc
)

target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)
Expand All @@ -93,6 +97,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE JsonCpp::JsonCpp OpenSSL::SSL OpenS
${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET_NAME} PRIVATE SQLiteCpp)
target_link_libraries(${TARGET_NAME} PRIVATE Trantor::Trantor)
target_link_libraries(${TARGET_NAME} PRIVATE indicators::indicators)

# ##############################################################################

Expand Down
8 changes: 6 additions & 2 deletions engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ void CommandLineParser::SetupCommonCommands() {
return;
}
try {
commands::ModelPullCmd(download_service_).Exec(cml_data_.model_id);
commands::ModelPullCmd(download_service_)
.Exec(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id);
} catch (const std::exception& e) {
CLI_LOG(e.what());
}
Expand Down Expand Up @@ -462,7 +464,9 @@ void CommandLineParser::EngineInstall(CLI::App* parent,
if (std::exchange(executed_, true))
return;
try {
commands::EngineInstallCmd(download_service_)
commands::EngineInstallCmd(download_service_,
cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort))
.Exec(engine_name, version, src);
} catch (const std::exception& e) {
CTL_ERR(e.what());
Expand Down
66 changes: 59 additions & 7 deletions engine/cli/commands/engine_install_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,16 +1,68 @@
#include "engine_install_cmd.h"
#include "server_start_cmd.h"
#include "utils/download_progress.h"
#include "utils/engine_constants.h"
#include "utils/json_helper.h"
#include "utils/logging_utils.h"

namespace commands {

void EngineInstallCmd::Exec(const std::string& engine,
bool EngineInstallCmd::Exec(const std::string& engine,
const std::string& version,
const std::string& src) {
auto result = engine_service_.InstallEngine(engine, version, src);
if (result.has_error()) {
CLI_LOG(result.error());
} else if(result && result.value()){
CLI_LOG("Engine " << engine << " installed successfully!");
// Handle local install, if fails, fallback to remote install
if (!src.empty()) {
auto res = engine_service_.UnzipEngine(engine, version, src);
if (res.has_error()) {
CLI_LOG(res.error());
return false;
}
if (res.value()) {
CLI_LOG("Engine " << engine << " installed successfully!");
return true;
}
}

// Start server if server is not started yet
if (!commands::IsServerAlive(host_, port_)) {
CLI_LOG("Starting server ...");
commands::ServerStartCmd ssc;
if (!ssc.Exec(host_, port_)) {
return false;
}
}

httplib::Client cli(host_ + ":" + std::to_string(port_));
Json::Value json_data;
auto data_str = json_data.toStyledString();
cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Post("/v1/engines/install/" + engine, httplib::Headers(),
data_str.data(), data_str.size(), "application/json");

if (res) {
if (res->status != httplib::StatusCode::OK_200) {
auto root = json_helper::ParseJsonString(res->body);
CLI_LOG(root["message"].asString());
return false;
}
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return false;
}

CLI_LOG("Start downloading ...")
DownloadProgress dp;
dp.Connect(host_, port_);
if (!dp.Handle(engine))
return false;

bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (check_cuda_download) {
if (!dp.Handle("cuda"))
return false;
}

CLI_LOG("Engine " << engine << " downloaded successfully!")
return true;
}
}; // namespace commands
8 changes: 5 additions & 3 deletions engine/cli/commands/engine_install_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ namespace commands {

class EngineInstallCmd {
public:
explicit EngineInstallCmd(std::shared_ptr<DownloadService> download_service)
: engine_service_{EngineService(download_service)} {};
explicit EngineInstallCmd(std::shared_ptr<DownloadService> download_service, const std::string& host, int port)
: engine_service_{EngineService(download_service)}, host_(host), port_(port) {};

void Exec(const std::string& engine, const std::string& version = "latest",
bool Exec(const std::string& engine, const std::string& version = "latest",
const std::string& src = "");

private:
EngineService engine_service_;
std::string host_;
int port_;
};
} // namespace commands
178 changes: 174 additions & 4 deletions engine/cli/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,181 @@
#include "model_pull_cmd.h"
#include <memory>
#include "common/event.h"
#include "database/models.h"
#include "server_start_cmd.h"
#include "utils/cli_selection_utils.h"
#include "utils/download_progress.h"
#include "utils/format_utils.h"
#include "utils/huggingface_utils.h"
#include "utils/json_helper.h"
#include "utils/logging_utils.h"
#include "utils/scope_exit.h"
#include "utils/string_utils.h"
#if defined(_WIN32)
#include <signal.h>
#endif

namespace commands {
void ModelPullCmd::Exec(const std::string& input) {
auto result = model_service_.DownloadModel(input);
if (result.has_error()) {
CLI_LOG(result.error());
std::function<void(int)> shutdown_handler;
inline void signal_handler(int signal) {
if (shutdown_handler) {
shutdown_handler(signal);
}
}
std::optional<std::string> ModelPullCmd::Exec(const std::string& host, int port,
const std::string& input) {

// model_id: use to check the download progress
// model: use as a parameter for pull API
auto model_id = input;
auto model = input;

// Start server if server is not started yet
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Starting server ...");
commands::ServerStartCmd ssc;
if (!ssc.Exec(host, port)) {
return std::nullopt;
}
}

// Get model info from Server
httplib::Client cli(host + ":" + std::to_string(port));
cli.set_read_timeout(std::chrono::seconds(60));
Json::Value j_data;
j_data["model"] = input;
auto d_str = j_data.toStyledString();
auto res = cli.Post("/models/pull/info", httplib::Headers(), d_str.data(),
d_str.size(), "application/json");

if (res) {
if (res->status == httplib::StatusCode::OK_200) {
// CLI_LOG(res->body);
auto root = json_helper::ParseJsonString(res->body);
auto id = root["id"].asString();
bool is_cortexso = root["modelSource"].asString() == "cortexso";
auto default_branch = root["defaultBranch"].asString();
std::vector<std::string> downloaded;
for (auto const& v : root["downloadedModels"]) {
downloaded.push_back(v.asString());
}
std::vector<std::string> avails;
for (auto const& v : root["availableModels"]) {
avails.push_back(v.asString());
}
auto download_url = root["downloadUrl"].asString();

if (downloaded.empty() && avails.empty()) {
model_id = id;
model = download_url;
} else {
if (is_cortexso) {
auto selection = cli_selection_utils::PrintModelSelection(
downloaded, avails,
default_branch.empty()
? std::nullopt
: std::optional<std::string>(default_branch));

if (!selection.has_value()) {
CLI_LOG("Invalid selection");
return std::nullopt;
}
model_id = selection.value();
model = model_id;
} else {
auto selection = cli_selection_utils::PrintSelection(avails);
CLI_LOG("Selected: " << selection.value());
model_id = id + ":" + selection.value();
model = download_url + selection.value();
}
}
} else {
auto root = json_helper::ParseJsonString(res->body);
CLI_LOG(root["message"].asString());
return std::nullopt;
}
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return std::nullopt;
}

// Send request download model to server
Json::Value json_data;
json_data["model"] = model;
auto data_str = json_data.toStyledString();
cli.set_read_timeout(std::chrono::seconds(60));
res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(),
data_str.size(), "application/json");

if (res) {
if (res->status != httplib::StatusCode::OK_200) {
auto root = json_helper::ParseJsonString(res->body);
CLI_LOG(root["message"].asString());
return std::nullopt;
}
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return std::nullopt;
}

CLI_LOG("Start downloading ...")
DownloadProgress dp;
bool force_stop = false;

shutdown_handler = [this, &dp, &host, &port, &model_id, &force_stop](int) {
force_stop = true;
AbortModelPull(host, port, model_id);
dp.ForceStop();
};

utils::ScopeExit se([]() { shutdown_handler = {}; });
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset(&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined(_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(
reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
dp.Connect(host, port);
if (!dp.Handle(model_id))
return std::nullopt;
if (force_stop)
return std::nullopt;
CLI_LOG("Model " << model_id << " downloaded successfully!")
return model_id;
}

bool ModelPullCmd::AbortModelPull(const std::string& host, int port,
const std::string& task_id) {
Json::Value json_data;
json_data["taskId"] = task_id;
auto data_str = json_data.toStyledString();
httplib::Client cli(host + ":" + std::to_string(port));
cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Delete("/v1/models/pull", httplib::Headers(), data_str.data(),
data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
CTL_INF("Abort model pull successfully: " << task_id);
return true;
} else {
auto root = json_helper::ParseJsonString(res->body);
CLI_LOG(root["message"].asString());
return false;
}
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return false;
}
}
}; // namespace commands
9 changes: 8 additions & 1 deletion engine/cli/commands/model_pull_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@ class ModelPullCmd {
public:
explicit ModelPullCmd(std::shared_ptr<DownloadService> download_service)
: model_service_{ModelService(download_service)} {};
void Exec(const std::string& input);
explicit ModelPullCmd(const ModelService& model_service)
: model_service_{model_service} {};
std::optional<std::string> Exec(const std::string& host, int port,
const std::string& input);

private:
bool AbortModelPull(const std::string& host, int port,
const std::string& task_id);

private:
ModelService model_service_;
Expand Down
2 changes: 1 addition & 1 deletion engine/cli/commands/model_start_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ bool ModelStartCmd::Exec(const std::string& host, int port,
const std::string& model_handle,
bool print_success_log) {
std::optional<std::string> model_id =
SelectLocalModel(model_service_, model_handle);
SelectLocalModel(host, port, model_service_, model_handle);

if (!model_id.has_value()) {
return false;
Expand Down
Loading

0 comments on commit 00af979

Please sign in to comment.