Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: run command #1045

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/commands/chat_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void ChatCmd::Exec(std::string msg) {
}
}
// Some instruction for user here
std::cout << "Inorder to exit, type exit()" << std::endl;
std::cout << "Inorder to exit, type `exit()`" << std::endl;
// Model is loaded, start to chat
{
while (true) {
Expand Down
54 changes: 54 additions & 0 deletions engine/commands/cmd_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "cmd_info.h"
#include <vector>
#include "trantor/utils/Logger.h"

namespace commands {
namespace {
// Separator between the model name and the branch in a model id,
// e.g. "tinyllama:gguf".
constexpr const char* kDelimiter = ":";

// Splits `s` on `delimiter` and returns the pieces in order.
// Non-destructive: unlike a find/erase loop, the input string is left
// untouched. An empty delimiter is handled safely (returns {s} instead
// of looping forever).
std::vector<std::string> split(const std::string& s,
                               const std::string& delimiter) {
  std::vector<std::string> tokens;
  if (delimiter.empty()) {
    tokens.push_back(s);
    return tokens;
  }
  size_t start = 0;
  size_t pos = 0;
  while ((pos = s.find(delimiter, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, pos - start));
    start = pos + delimiter.length();
  }
  tokens.push_back(s.substr(start));
  return tokens;
}
}  // namespace

// Builds a CmdInfo by delegating to Parse, which fills in
// engine_name / model_name / branch from the raw model id.
CmdInfo::CmdInfo(std::string model_id) {
Parse(std::move(model_id));
}

// Parses a raw model id of the form "<model>" or "<model>:<branch>".
// Fills model_name and branch, and infers engine_name from keywords in
// the branch name. On a malformed id the fields are left unset and an
// error is logged.
void CmdInfo::Parse(std::string model_id) {
  if (model_id.find(kDelimiter) == std::string::npos) {
    // No branch given: assume the default llamacpp engine on "main".
    engine_name = "cortex.llamacpp";
    model_name = std::move(model_id);
    branch = "main";
  } else {
    auto res = split(model_id, kDelimiter);
    if (res.size() != 2) {
      LOG_ERROR << "model_id is not valid";
      return;
    }
    model_name = std::move(res[0]);
    branch = std::move(res[1]);
    // Infer the engine from the branch name.
    if (branch.find("onnx") != std::string::npos) {
      engine_name = "cortex.onnx";
    } else if (branch.find("tensorrt") != std::string::npos) {
      engine_name = "cortex.tensorrt-llm";
    } else if (branch.find("gguf") != std::string::npos) {
      engine_name = "cortex.llamacpp";
    } else {
      // Unknown branch: engine_name stays empty for the caller to detect.
      LOG_ERROR << "Not a valid branch: " << branch;
    }
  }
}

} // namespace commands
14 changes: 14 additions & 0 deletions engine/commands/cmd_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once
#include <string>
namespace commands {
// Parsed form of a user-supplied model id ("<model>" or "<model>:<branch>").
struct CmdInfo {
explicit CmdInfo(std::string model_id);

// Engine inferred from keywords in the branch name
// (e.g. "cortex.onnx", "cortex.tensorrt-llm", "cortex.llamacpp").
std::string engine_name;
// Model name portion of the id.
std::string model_name;
// Branch portion of the id; "main" when none was given.
std::string branch;

private:
// Splits model_id into the fields above; logs on malformed input.
void Parse(std::string model_id);
};
} // namespace commands
57 changes: 33 additions & 24 deletions engine/commands/engine_init_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ namespace commands {
EngineInitCmd::EngineInitCmd(std::string engineName, std::string version)
: engineName_(std::move(engineName)), version_(std::move(version)) {}

void EngineInitCmd::Exec() const {
bool EngineInitCmd::Exec() const {
if (engineName_.empty()) {
LOG_ERROR << "Engine name is required";
return;
return false;
}

// Check if the architecture and OS are supported
Expand All @@ -26,15 +26,15 @@ void EngineInitCmd::Exec() const {
system_info.os == system_info_utils::kUnsupported) {
LOG_ERROR << "Unsupported OS or architecture: " << system_info.os << ", "
<< system_info.arch;
return;
return false;
}
LOG_INFO << "OS: " << system_info.os << ", Arch: " << system_info.arch;

// check if engine is supported
if (std::find(supportedEngines_.begin(), supportedEngines_.end(),
engineName_) == supportedEngines_.end()) {
LOG_ERROR << "Engine not supported";
return;
return false;
}

constexpr auto gitHubHost = "https://api.github.com";
Expand Down Expand Up @@ -78,7 +78,7 @@ void EngineInitCmd::Exec() const {
LOG_INFO << "Matched variant: " << matched_variant;
if (matched_variant.empty()) {
LOG_ERROR << "No variant found for " << os_arch;
return;
return false;
}

for (auto& asset : assets) {
Expand All @@ -103,36 +103,45 @@ void EngineInitCmd::Exec() const {
.path = path,
}}};

DownloadService().AddDownloadTask(
downloadTask, [](const std::string& absolute_path) {
// try to unzip the downloaded file
std::filesystem::path downloadedEnginePath{absolute_path};
LOG_INFO << "Downloaded engine path: "
<< downloadedEnginePath.string();

archive_utils::ExtractArchive(
downloadedEnginePath.string(),
downloadedEnginePath.parent_path()
.parent_path()
.string());

// remove the downloaded file
std::filesystem::remove(absolute_path);
LOG_INFO << "Finished!";
});

return;
DownloadService().AddDownloadTask(downloadTask, [](const std::string&
absolute_path,
bool unused) {
// try to unzip the downloaded file
std::filesystem::path downloadedEnginePath{absolute_path};
LOG_INFO << "Downloaded engine path: "
<< downloadedEnginePath.string();

archive_utils::ExtractArchive(
downloadedEnginePath.string(),
downloadedEnginePath.parent_path().parent_path().string());

// remove the downloaded file
// TODO(any) Could not delete file on Windows because it is currently hold by httplib(?)
vansangpfiev marked this conversation as resolved.
Show resolved Hide resolved
// Not sure about other platforms
try {
std::filesystem::remove(absolute_path);
} catch (const std::exception& e) {
LOG_ERROR << "Could not delete file: " << e.what();
}
LOG_INFO << "Finished!";
});

return true;
}
}
} catch (const json::parse_error& e) {
std::cerr << "JSON parse error: " << e.what() << std::endl;
return false;
}
} else {
LOG_ERROR << "HTTP error: " << res->status;
return false;
}
} else {
auto err = res.error();
LOG_ERROR << "HTTP error: " << httplib::to_string(err);
return false;
}
return true;
}
}; // namespace commands
2 changes: 1 addition & 1 deletion engine/commands/engine_init_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class EngineInitCmd {
public:
EngineInitCmd(std::string engineName, std::string version);

void Exec() const;
bool Exec() const;

private:
std::string engineName_;
Expand Down
10 changes: 6 additions & 4 deletions engine/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
#include "utils/model_callback_utils.h"

namespace commands {
ModelPullCmd::ModelPullCmd(std::string modelHandle)
: modelHandle_(std::move(modelHandle)) {}
// Captures which model handle to pull and which branch of it to download.
ModelPullCmd::ModelPullCmd(std::string model_handle, std::string branch)
: model_handle_(std::move(model_handle)), branch_(std::move(branch)) {}

void ModelPullCmd::Exec() {
auto downloadTask = cortexso_parser::getDownloadTask(modelHandle_);
bool ModelPullCmd::Exec() {
auto downloadTask = cortexso_parser::getDownloadTask(model_handle_, branch_);
if (downloadTask.has_value()) {
DownloadService downloadService;
downloadService.AddDownloadTask(downloadTask.value(),
model_callback_utils::DownloadModelCb);
std::cout << "Download finished" << std::endl;
return true;
} else {
std::cout << "Model not found" << std::endl;
return false;
}
}

Expand Down
7 changes: 4 additions & 3 deletions engine/commands/model_pull_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ namespace commands {

class ModelPullCmd {
public:
ModelPullCmd(std::string modelHandle);
void Exec();
explicit ModelPullCmd(std::string model_handle, std::string branch);
bool Exec();

private:
std::string modelHandle_;
std::string model_handle_;
std::string branch_;
};
} // namespace commands
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
#include "start_model_cmd.h"
#include "model_start_cmd.h"
#include "httplib.h"
#include "nlohmann/json.hpp"
#include "trantor/utils/Logger.h"

namespace commands {
StartModelCmd::StartModelCmd(std::string host, int port,
ModelStartCmd::ModelStartCmd(std::string host, int port,
const config::ModelConfig& mc)
: host_(std::move(host)), port_(port), mc_(mc) {}

void StartModelCmd::Exec() {
bool ModelStartCmd::Exec() {
httplib::Client cli(host_ + ":" + std::to_string(port_));
nlohmann::json json_data;
if (mc_.files.size() > 0) {
// TODO(sang) support multiple files
json_data["model_path"] = mc_.files[0];
} else {
LOG_WARN << "model_path is empty";
return;
return false;
}
json_data["model"] = mc_.name;
json_data["system_prompt"] = mc_.system_template;
Expand All @@ -27,7 +27,7 @@ void StartModelCmd::Exec() {
json_data["engine"] = mc_.engine;

auto data_str = json_data.dump();

cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
Expand All @@ -37,7 +37,9 @@ void StartModelCmd::Exec() {
} else {
auto err = res.error();
LOG_WARN << "HTTP error: " << httplib::to_string(err);
return false;
}
return true;
}

}; // namespace commands
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

namespace commands {

class StartModelCmd{
class ModelStartCmd{
public:
StartModelCmd(std::string host, int port, const config::ModelConfig& mc);
void Exec();
explicit ModelStartCmd(std::string host, int port, const config::ModelConfig& mc);
bool Exec();

private:
std::string host_;
Expand Down
97 changes: 97 additions & 0 deletions engine/commands/run_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include "run_cmd.h"
#include "chat_cmd.h"
#include "cmd_info.h"
#include "config/yaml_config.h"
#include "engine_init_cmd.h"
#include "httplib.h"
#include "model_pull_cmd.h"
#include "model_start_cmd.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"

namespace commands {

// Stores the server host/port to talk to and the raw model id to run.
RunCmd::RunCmd(std::string host, int port, std::string model_id)
: host_(std::move(host)), port_(port), model_id_(std::move(model_id)) {}

// Runs a model end-to-end: pulls the model if missing, installs the engine
// if missing, starts the model on the server, then enters interactive chat.
// Returns early (the failing step has already logged) if any step fails.
void RunCmd::Exec() {
  CmdInfo ci(model_id_);
  // Model files for non-main branches are stored as "<name>-<branch>".
  std::string model_file =
      ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch;
  // TODO should we clean all resource if something fails?

  // Check if model existed. If not, download it
  if (!IsModelExisted(model_file)) {
    ModelPullCmd model_pull_cmd(ci.model_name, ci.branch);
    if (!model_pull_cmd.Exec()) {
      return;
    }
  }

  // Check if engine existed. If not, download it
  if (!IsEngineExisted(ci.engine_name)) {
    EngineInitCmd eic(ci.engine_name, "");
    if (!eic.Exec()) {
      return;
    }
  }

  // Start model from its yaml config under <cwd>/models/.
  config::YamlHandler yaml_handler;
  yaml_handler.ModelConfigFromFile(cortex_utils::GetCurrentPath() +
                                   "/models/" + model_file + ".yaml");
  ModelStartCmd msc(host_, port_, yaml_handler.GetModelConfig());
  if (!msc.Exec()) {
    return;
  }

  // Chat: empty initial message drops straight into the interactive loop.
  ChatCmd cc(host_, port_, yaml_handler.GetModelConfig());
  cc.Exec("");
}

// Returns true when a yaml file whose stem equals `model_id` exists under
// <cwd>/<models_folder> and parses as a valid model config. Unreadable yaml
// files are logged and skipped.
bool RunCmd::IsModelExisted(const std::string& model_id) {
  auto models_path =
      cortex_utils::GetCurrentPath() + "/" + cortex_utils::models_folder;
  if (!std::filesystem::exists(models_path) ||
      !std::filesystem::is_directory(models_path)) {
    return false;
  }
  // Iterate through directory
  for (const auto& entry : std::filesystem::directory_iterator(models_path)) {
    if (entry.is_regular_file() && entry.path().extension() == ".yaml") {
      try {
        // Parse the yaml first so only valid configs count as existing.
        config::YamlHandler handler;
        handler.ModelConfigFromFile(entry.path().string());
        if (entry.path().stem().string() == model_id) {
          return true;
        }
      } catch (const std::exception& e) {
        LOG_ERROR << "Error reading yaml file '" << entry.path().string()
                  << "': " << e.what();
      }
    }
  }
  return false;
}

// True when <cwd>/engines/<e> already exists on disk.
bool RunCmd::IsEngineExisted(const std::string& e) {
  auto engines_root = cortex_utils::GetCurrentPath() + "/" + "engines";
  return std::filesystem::exists(engines_root) &&
         std::filesystem::exists(engines_root + "/" + e);
}

}; // namespace commands
Loading
Loading