Skip to content

Commit

Permalink
feat: run command
Browse files Browse the repository at this point in the history
  • Loading branch information
vansangpfiev committed Aug 29, 2024
1 parent ba6816f commit 4800004
Show file tree
Hide file tree
Showing 18 changed files with 282 additions and 66 deletions.
2 changes: 1 addition & 1 deletion engine/commands/chat_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void ChatCmd::Exec(std::string msg) {
}
}
// Some instruction for user here
std::cout << "Inorder to exit, type exit()" << std::endl;
std::cout << "Inorder to exit, type `exit()`" << std::endl;
// Model is loaded, start to chat
{
while (true) {
Expand Down
54 changes: 54 additions & 0 deletions engine/commands/cmd_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "cmd_info.h"
#include <vector>
#include "trantor/utils/Logger.h"

namespace commands {
namespace {
constexpr const char* kDelimiter = ":";

std::vector<std::string> split(std::string& s, const std::string& delimiter) {
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while ((pos = s.find(delimiter)) != std::string::npos) {
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);

return tokens;
}
} // namespace

CmdInfo::CmdInfo(std::string model_id) {
Parse(std::move(model_id));
}

void CmdInfo::Parse(std::string model_id) {
if (model_id.find(kDelimiter) == std::string::npos) {
engine = "cortex.llamacpp";
name = std::move(model_id);
branch = "main";
} else {
auto res = split(model_id, kDelimiter);
if (res.size() != 2) {
LOG_ERROR << "model_id does not valid";
return;
} else {
name = std::move(res[0]);
branch = std::move(res[1]);
if (branch.find("onnx") != std::string::npos) {
engine = "cortex.onnx";
} else if (branch.find("tensorrt") != std::string::npos) {
engine = "cortex.tensorrt-llm";
} else if (branch.find("gguf") != std::string::npos) {
engine = "cortex.llamacpp";
} else {
LOG_ERROR << "Not a valid branch name " << branch;
}
}
}
}

} // namespace commands
14 changes: 14 additions & 0 deletions engine/commands/cmd_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once
#include <string>
namespace commands {
struct CmdInfo {
explicit CmdInfo(std::string model_id);

std::string engine;
std::string name;
std::string branch;

private:
void Parse(std::string model_id);
};
} // namespace commands
57 changes: 33 additions & 24 deletions engine/commands/engine_init_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ namespace commands {
EngineInitCmd::EngineInitCmd(std::string engineName, std::string version)
: engineName_(std::move(engineName)), version_(std::move(version)) {}

void EngineInitCmd::Exec() const {
bool EngineInitCmd::Exec() const {
if (engineName_.empty()) {
LOG_ERROR << "Engine name is required";
return;
return false;
}

// Check if the architecture and OS are supported
Expand All @@ -26,15 +26,15 @@ void EngineInitCmd::Exec() const {
system_info.os == system_info_utils::kUnsupported) {
LOG_ERROR << "Unsupported OS or architecture: " << system_info.os << ", "
<< system_info.arch;
return;
return false;
}
LOG_INFO << "OS: " << system_info.os << ", Arch: " << system_info.arch;

// check if engine is supported
if (std::find(supportedEngines_.begin(), supportedEngines_.end(),
engineName_) == supportedEngines_.end()) {
LOG_ERROR << "Engine not supported";
return;
return false;
}

constexpr auto gitHubHost = "https://api.github.com";
Expand Down Expand Up @@ -78,7 +78,7 @@ void EngineInitCmd::Exec() const {
LOG_INFO << "Matched variant: " << matched_variant;
if (matched_variant.empty()) {
LOG_ERROR << "No variant found for " << os_arch;
return;
return false;
}

for (auto& asset : assets) {
Expand All @@ -103,36 +103,45 @@ void EngineInitCmd::Exec() const {
.path = path,
}}};

DownloadService().AddDownloadTask(
downloadTask, [](const std::string& absolute_path) {
// try to unzip the downloaded file
std::filesystem::path downloadedEnginePath{absolute_path};
LOG_INFO << "Downloaded engine path: "
<< downloadedEnginePath.string();

archive_utils::ExtractArchive(
downloadedEnginePath.string(),
downloadedEnginePath.parent_path()
.parent_path()
.string());

// remove the downloaded file
std::filesystem::remove(absolute_path);
LOG_INFO << "Finished!";
});

return;
DownloadService().AddDownloadTask(downloadTask, [](const std::string&
absolute_path,
bool unused) {
// try to unzip the downloaded file
std::filesystem::path downloadedEnginePath{absolute_path};
LOG_INFO << "Downloaded engine path: "
<< downloadedEnginePath.string();

archive_utils::ExtractArchive(
downloadedEnginePath.string(),
downloadedEnginePath.parent_path().parent_path().string());

// remove the downloaded file
// TODO(any) Could not delete file on Windows because it is currently hold by httplib(?)
// Not sure about other platforms
try {
std::filesystem::remove(absolute_path);
} catch (const std::exception& e) {
LOG_ERROR << "Could not delete file: " << e.what();
}
LOG_INFO << "Finished!";
});

return false;
}
}
} catch (const json::parse_error& e) {
std::cerr << "JSON parse error: " << e.what() << std::endl;
return false;
}
} else {
LOG_ERROR << "HTTP error: " << res->status;
return false;
}
} else {
auto err = res.error();
LOG_ERROR << "HTTP error: " << httplib::to_string(err);
return false;
}
return true;
}
}; // namespace commands
2 changes: 1 addition & 1 deletion engine/commands/engine_init_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class EngineInitCmd {
public:
EngineInitCmd(std::string engineName, std::string version);

void Exec() const;
bool Exec() const;

private:
std::string engineName_;
Expand Down
10 changes: 6 additions & 4 deletions engine/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
#include "utils/model_callback_utils.h"

namespace commands {
ModelPullCmd::ModelPullCmd(std::string modelHandle)
: modelHandle_(std::move(modelHandle)) {}
ModelPullCmd::ModelPullCmd(std::string model_handle, std::string branch)
: model_handle_(std::move(model_handle)), branch_(std::move(branch)) {}

void ModelPullCmd::Exec() {
auto downloadTask = cortexso_parser::getDownloadTask(modelHandle_);
bool ModelPullCmd::Exec() {
auto downloadTask = cortexso_parser::getDownloadTask(model_handle_, branch_);
if (downloadTask.has_value()) {
DownloadService downloadService;
downloadService.AddDownloadTask(downloadTask.value(),
model_callback_utils::DownloadModelCb);
std::cout << "Download finished" << std::endl;
return true;
} else {
std::cout << "Model not found" << std::endl;
return false;
}
}

Expand Down
7 changes: 4 additions & 3 deletions engine/commands/model_pull_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ namespace commands {

class ModelPullCmd {
public:
ModelPullCmd(std::string modelHandle);
void Exec();
explicit ModelPullCmd(std::string model_handle, std::string branch);
bool Exec();

private:
std::string modelHandle_;
std::string model_handle_;
std::string branch_;
};
} // namespace commands
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
#include "start_model_cmd.h"
#include "model_start_cmd.h"
#include "httplib.h"
#include "nlohmann/json.hpp"
#include "trantor/utils/Logger.h"

namespace commands {
StartModelCmd::StartModelCmd(std::string host, int port,
ModelStartCmd::ModelStartCmd(std::string host, int port,
const config::ModelConfig& mc)
: host_(std::move(host)), port_(port), mc_(mc) {}

void StartModelCmd::Exec() {
bool ModelStartCmd::Exec() {
httplib::Client cli(host_ + ":" + std::to_string(port_));
nlohmann::json json_data;
if (mc_.files.size() > 0) {
// TODO(sang) support multiple files
json_data["model_path"] = mc_.files[0];
} else {
LOG_WARN << "model_path is empty";
return;
return false;
}
json_data["model"] = mc_.name;
json_data["system_prompt"] = mc_.system_template;
Expand All @@ -27,7 +27,7 @@ void StartModelCmd::Exec() {
json_data["engine"] = mc_.engine;

auto data_str = json_data.dump();

cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
Expand All @@ -37,7 +37,9 @@ void StartModelCmd::Exec() {
} else {
auto err = res.error();
LOG_WARN << "HTTP error: " << httplib::to_string(err);
return false;
}
return true;
}

}; // namespace commands
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

namespace commands {

class StartModelCmd{
class ModelStartCmd{
public:
StartModelCmd(std::string host, int port, const config::ModelConfig& mc);
void Exec();
explicit ModelStartCmd(std::string host, int port, const config::ModelConfig& mc);
bool Exec();

private:
std::string host_;
Expand Down
97 changes: 97 additions & 0 deletions engine/commands/run_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include "run_cmd.h"
#include "chat_cmd.h"
#include "cmd_info.h"
#include "config/yaml_config.h"
#include "engine_init_cmd.h"
#include "httplib.h"
#include "model_pull_cmd.h"
#include "model_start_cmd.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"

namespace commands {

RunCmd::RunCmd(std::string host, int port, std::string model_id)
: host_(std::move(host)), port_(port), model_id_(std::move(model_id)) {}

void RunCmd::Exec() {
auto address = host_ + ":" + std::to_string(port_);
CmdInfo ci(model_id_);
std::string model_file =
ci.branch == "main" ? ci.name : ci.name + "-" + ci.branch;
// TODO should we clean all resource if something fails?
// Check if model existed. If not, download it
{
if (!IsModelExisted(model_file)) {
ModelPullCmd model_pull_cmd(ci.name, ci.branch);
if (!model_pull_cmd.Exec()) {
return;
}
}
}

// Check if engine existed. If not, download it
{
if (!IsEngineExisted(ci.engine)) {
EngineInitCmd eic(ci.engine, "");
if (!eic.Exec())
return;
}
}

// Start model
config::YamlHandler yaml_handler;
yaml_handler.ModelConfigFromFile(cortex_utils::GetCurrentPath() + "/models/" +
model_file + ".yaml");
{
ModelStartCmd msc(host_, port_, yaml_handler.GetModelConfig());
if (!msc.Exec()) {
return;
}
}

// Chat
{
ChatCmd cc(host_, port_, yaml_handler.GetModelConfig());
cc.Exec("");
}
}

bool RunCmd::IsModelExisted(const std::string& model_id) {
if (std::filesystem::exists(cortex_utils::GetCurrentPath() + "/" +
cortex_utils::models_folder) &&
std::filesystem::is_directory(cortex_utils::GetCurrentPath() + "/" +
cortex_utils::models_folder)) {
// Iterate through directory
for (const auto& entry : std::filesystem::directory_iterator(
cortex_utils::GetCurrentPath() + "/" +
cortex_utils::models_folder)) {
if (entry.is_regular_file() && entry.path().extension() == ".yaml") {
try {
config::YamlHandler handler;
handler.ModelConfigFromFile(entry.path().string());
std::cout << entry.path().stem().string() << std::endl;
if (entry.path().stem().string() == model_id) {
return true;
}
} catch (const std::exception& e) {
LOG_ERROR << "Error reading yaml file '" << entry.path().string()
<< "': " << e.what();
}
}
}
}
return false;
}

bool RunCmd::IsEngineExisted(const std::string& e) {
if (std::filesystem::exists(cortex_utils::GetCurrentPath() + "/" +
"engines") &&
std::filesystem::exists(cortex_utils::GetCurrentPath() + "/" +
"engines/" + e)) {
return true;
}
return false;
}

}; // namespace commands
Loading

0 comments on commit 4800004

Please sign in to comment.