-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ba6816f
commit 4800004
Showing
18 changed files
with
282 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#include "cmd_info.h" | ||
#include <vector> | ||
#include "trantor/utils/Logger.h" | ||
|
||
namespace commands { | ||
namespace { | ||
// Separator between the model name and the branch in a model id
// ("<name>:<branch>").
constexpr auto kDelimiter = ":";
|
||
/// Splits `s` on every occurrence of `delimiter`.
///
/// The input string is left untouched. Consecutive, leading or trailing
/// delimiters yield empty tokens, so the result always holds at least one
/// element (the whole string when no delimiter is found).
///
/// @param s          Text to split.
/// @param delimiter  Separator to split on. When empty, `s` is returned as a
///                   single token.
/// @return The tokens in the order they appear in `s`.
std::vector<std::string> split(const std::string& s,
                               const std::string& delimiter) {
  std::vector<std::string> tokens;

  // An empty delimiter matches at every position without consuming any input;
  // bail out early instead of looping forever.
  if (delimiter.empty()) {
    tokens.push_back(s);
    return tokens;
  }

  std::string::size_type start = 0;
  std::string::size_type pos = 0;
  while ((pos = s.find(delimiter, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, pos - start));
    start = pos + delimiter.length();
  }
  // Whatever follows the last delimiter (possibly empty) is the final token.
  tokens.push_back(s.substr(start));

  return tokens;
}
} // namespace | ||
|
||
/// Builds a CmdInfo by handing `model_id` to Parse(), which fills in
/// `engine`, `name` and `branch`.
CmdInfo::CmdInfo(std::string model_id) {
  this->Parse(std::move(model_id));
}
|
||
void CmdInfo::Parse(std::string model_id) { | ||
if (model_id.find(kDelimiter) == std::string::npos) { | ||
engine = "cortex.llamacpp"; | ||
name = std::move(model_id); | ||
branch = "main"; | ||
} else { | ||
auto res = split(model_id, kDelimiter); | ||
if (res.size() != 2) { | ||
LOG_ERROR << "model_id does not valid"; | ||
return; | ||
} else { | ||
name = std::move(res[0]); | ||
branch = std::move(res[1]); | ||
if (branch.find("onnx") != std::string::npos) { | ||
engine = "cortex.onnx"; | ||
} else if (branch.find("tensorrt") != std::string::npos) { | ||
engine = "cortex.tensorrt-llm"; | ||
} else if (branch.find("gguf") != std::string::npos) { | ||
engine = "cortex.llamacpp"; | ||
} else { | ||
LOG_ERROR << "Not a valid branch name " << branch; | ||
} | ||
} | ||
} | ||
} | ||
|
||
} // namespace commands |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#pragma once | ||
#include <string> | ||
namespace commands { | ||
/// Components of a model identifier given on the command line.
///
/// The fields are filled by Parse() during construction.
struct CmdInfo {
  /// Parses `model_id` into the fields below.
  explicit CmdInfo(std::string model_id);

  /// Engine identifier, e.g. "cortex.llamacpp".
  std::string engine;
  /// Model name (the part of the id before the ':' delimiter, if any).
  std::string name;
  /// Branch name; "main" when the id carries no delimiter.
  std::string branch;

 private:
  /// Splits `model_id` and assigns `engine`, `name` and `branch`.
  void Parse(std::string model_id);
};
} // namespace commands |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#include "run_cmd.h" | ||
#include "chat_cmd.h" | ||
#include "cmd_info.h" | ||
#include "config/yaml_config.h" | ||
#include "engine_init_cmd.h" | ||
#include "httplib.h" | ||
#include "model_pull_cmd.h" | ||
#include "model_start_cmd.h" | ||
#include "trantor/utils/Logger.h" | ||
#include "utils/cortex_utils.h" | ||
|
||
namespace commands { | ||
|
||
RunCmd::RunCmd(std::string host, int port, std::string model_id) | ||
: host_(std::move(host)), port_(port), model_id_(std::move(model_id)) {} | ||
|
||
void RunCmd::Exec() { | ||
auto address = host_ + ":" + std::to_string(port_); | ||
CmdInfo ci(model_id_); | ||
std::string model_file = | ||
ci.branch == "main" ? ci.name : ci.name + "-" + ci.branch; | ||
// TODO should we clean all resource if something fails? | ||
// Check if model existed. If not, download it | ||
{ | ||
if (!IsModelExisted(model_file)) { | ||
ModelPullCmd model_pull_cmd(ci.name, ci.branch); | ||
if (!model_pull_cmd.Exec()) { | ||
return; | ||
} | ||
} | ||
} | ||
|
||
// Check if engine existed. If not, download it | ||
{ | ||
if (!IsEngineExisted(ci.engine)) { | ||
EngineInitCmd eic(ci.engine, ""); | ||
if (!eic.Exec()) | ||
return; | ||
} | ||
} | ||
|
||
// Start model | ||
config::YamlHandler yaml_handler; | ||
yaml_handler.ModelConfigFromFile(cortex_utils::GetCurrentPath() + "/models/" + | ||
model_file + ".yaml"); | ||
{ | ||
ModelStartCmd msc(host_, port_, yaml_handler.GetModelConfig()); | ||
if (!msc.Exec()) { | ||
return; | ||
} | ||
} | ||
|
||
// Chat | ||
{ | ||
ChatCmd cc(host_, port_, yaml_handler.GetModelConfig()); | ||
cc.Exec(""); | ||
} | ||
} | ||
|
||
bool RunCmd::IsModelExisted(const std::string& model_id) { | ||
if (std::filesystem::exists(cortex_utils::GetCurrentPath() + "/" + | ||
cortex_utils::models_folder) && | ||
std::filesystem::is_directory(cortex_utils::GetCurrentPath() + "/" + | ||
cortex_utils::models_folder)) { | ||
// Iterate through directory | ||
for (const auto& entry : std::filesystem::directory_iterator( | ||
cortex_utils::GetCurrentPath() + "/" + | ||
cortex_utils::models_folder)) { | ||
if (entry.is_regular_file() && entry.path().extension() == ".yaml") { | ||
try { | ||
config::YamlHandler handler; | ||
handler.ModelConfigFromFile(entry.path().string()); | ||
std::cout << entry.path().stem().string() << std::endl; | ||
if (entry.path().stem().string() == model_id) { | ||
return true; | ||
} | ||
} catch (const std::exception& e) { | ||
LOG_ERROR << "Error reading yaml file '" << entry.path().string() | ||
<< "': " << e.what(); | ||
} | ||
} | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
bool RunCmd::IsEngineExisted(const std::string& e) { | ||
if (std::filesystem::exists(cortex_utils::GetCurrentPath() + "/" + | ||
"engines") && | ||
std::filesystem::exists(cortex_utils::GetCurrentPath() + "/" + | ||
"engines/" + e)) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
}; // namespace commands |
Oops, something went wrong.