Skip to content

Commit

Permalink
feat(#1512): simplify cortex run
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Oct 21, 2024
1 parent a6b64f8 commit 0aad82c
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 94 deletions.
58 changes: 32 additions & 26 deletions engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "command_line_parser.h"
#include <memory>
#include <optional>
#include "commands/chat_cmd.h"
#include "commands/chat_completion_cmd.h"
#include "commands/cortex_upd_cmd.h"
Expand Down Expand Up @@ -82,27 +83,30 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
// Check new update
#ifdef CORTEX_CPP_VERSION
if (cml_data_.check_upd) {
// TODO(sang) find a better way to handle
// This is an extremely ungly way to deal with connection
// hang when network down
std::atomic<bool> done = false;
std::thread t([&]() {
if (auto latest_version =
commands::CheckNewUpdate(commands::kTimeoutCheckUpdate);
latest_version.has_value() && *latest_version != CORTEX_CPP_VERSION) {
CLI_LOG("\nA new release of cortex is available: "
<< CORTEX_CPP_VERSION << " -> " << *latest_version);
CLI_LOG("To upgrade, run: " << commands::GetRole()
<< commands::GetCortexBinary()
<< " update");
if (strcmp(CORTEX_CPP_VERSION, "default_version") != 0) {
// TODO(sang) find a better way to handle
// This is an extremely ugly way to deal with connection
// hang when network down
std::atomic<bool> done = false;
std::thread t([&]() {
if (auto latest_version =
commands::CheckNewUpdate(commands::kTimeoutCheckUpdate);
latest_version.has_value() &&
*latest_version != CORTEX_CPP_VERSION) {
CLI_LOG("\nA new release of cortex is available: "
<< CORTEX_CPP_VERSION << " -> " << *latest_version);
CLI_LOG("To upgrade, run: " << commands::GetRole()
<< commands::GetCortexBinary()
<< " update");
}
done = true;
});
// Do not wait for http connection timeout
t.detach();
int retry = 10;
while (!done && retry--) {
std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10);
}
done = true;
});
// Do not wait for http connection timeout
t.detach();
int retry = 10;
while (!done && retry--) {
std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10);
}
}
#endif
Expand Down Expand Up @@ -143,11 +147,6 @@ void CommandLineParser::SetupCommonCommands() {
run_cmd->callback([this, run_cmd] {
if (std::exchange(executed_, true))
return;
if (cml_data_.model_id.empty()) {
CLI_LOG("[model_id] is required\n");
CLI_LOG(run_cmd->help());
return;
}
commands::RunCmd rc(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort),
cml_data_.model_id, download_service_);
Expand Down Expand Up @@ -247,12 +246,19 @@ void CommandLineParser::SetupModelCommands() {

auto list_models_cmd =
models_cmd->add_subcommand("list", "List all models locally");
list_models_cmd->add_option("filter", cml_data_.filter, "Filter model id");
list_models_cmd->add_flag("-e,--engine", cml_data_.display_engine,
"Display engine");
list_models_cmd->add_flag("-v,--version", cml_data_.display_version,
"Display version");
list_models_cmd->group(kSubcommands);
list_models_cmd->callback([this]() {
if (std::exchange(executed_, true))
return;
commands::ModelListCmd().Exec(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort));
std::stoi(cml_data_.config.apiServerPort),
cml_data_.filter, cml_data_.display_engine,
cml_data_.display_version);
});

auto get_models_cmd =
Expand Down
6 changes: 6 additions & 0 deletions engine/cli/command_line_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class CommandLineParser {
std::string cortex_version;
bool check_upd = true;
bool run_detach = false;

// for model list
bool display_engine = false;
bool display_version = false;
std::string filter = "";

int port;
config_yaml_utils::CortexConfig config;
std::unordered_map<std::string, std::string> model_update_options;
Expand Down
66 changes: 38 additions & 28 deletions engine/cli/commands/model_list_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
#include "model_list_cmd.h"
#include <json/reader.h>
#include <json/value.h>
#include <iostream>

#include <vector>
#include "httplib.h"
#include "json/json.h"
#include "server_start_cmd.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
// clang-format off
#include <tabulate/table.hpp>
// clang-format on

namespace commands {
using namespace tabulate;
using Row_t =
std::vector<variant<std::string, const char*, string_view, Table>>;

void ModelListCmd::Exec(const std::string& host, int port) {
void ModelListCmd::Exec(const std::string& host, int port, std::string filter,
bool display_engine, bool display_version) {
// Start server if server is not started yet
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Starting server ...");
Expand All @@ -21,27 +28,48 @@ void ModelListCmd::Exec(const std::string& host, int port) {
}
}

tabulate::Table table;
Table table;
std::vector<std::string> column_headers{"(Index)", "ID"};
if (display_engine) {
column_headers.push_back("Engine");
}
if (display_version) {
column_headers.push_back("Version");
}

table.add_row({"(Index)", "ID", "model alias", "engine", "version"});
table.format().font_color(tabulate::Color::green);
Row_t header{column_headers.begin(), column_headers.end()};
table.add_row(header);
table.format().font_color(Color::green);
int count = 0;
// Iterate through directory

httplib::Client cli(host + ":" + std::to_string(port));
auto res = cli.Get("/v1/models");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
// CLI_LOG(res->body);
Json::Value body;
Json::Reader reader;
reader.parse(res->body, body);
if (!body["data"].isNull()) {
for (auto const& v : body["data"]) {
auto model_id = v["model"].asString();
if (!filter.empty() &&
!string_utils::StringContainsIgnoreCase(model_id, filter)) {
continue;
}

count += 1;
table.add_row({std::to_string(count), v["model"].asString(),
v["model_alias"].asString(), v["engine"].asString(),
v["version"].asString()});

std::vector<std::string> row = {std::to_string(count),
v["model"].asString()};
if (display_engine) {
row.push_back(v["engine"].asString());
}
if (display_version) {
row.push_back(v["version"].asString());
}

table.add_row({row.begin(), row.end()});
}
}
} else {
Expand All @@ -54,24 +82,6 @@ void ModelListCmd::Exec(const std::string& host, int port) {
return;
}

for (int i = 0; i < 5; i++) {
table[0][i]
.format()
.font_color(tabulate::Color::white) // Set font color
.font_style({tabulate::FontStyle::bold})
.font_align(tabulate::FontAlign::center);
}
for (int i = 1; i <= count; i++) {
table[i][0] //index value
.format()
.font_color(tabulate::Color::white) // Set font color
.font_align(tabulate::FontAlign::center);
table[i][4] //version value
.format()
.font_align(tabulate::FontAlign::center);
}
std::cout << table << std::endl;
}
}

; // namespace commands
}; // namespace commands
4 changes: 3 additions & 1 deletion engine/cli/commands/model_list_cmd.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#pragma once

#include <string>

namespace commands {

class ModelListCmd {
public:
void Exec(const std::string& host, int port);
void Exec(const std::string& host, int port, std::string filter,
bool display_engine = false, bool display_version = false);
};
} // namespace commands
59 changes: 45 additions & 14 deletions engine/cli/commands/run_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,54 @@ void RunCmd::Exec(bool run_detach) {
config::YamlHandler yaml_handler;
auto address = host_ + ":" + std::to_string(port_);

// Download model if it does not exist
{
auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_);
if (related_models_ids.has_error() || related_models_ids.value().empty()) {
auto result = model_service_.DownloadModel(model_handle_);
model_id = result.value();
CTL_INF("model_id: " << model_id.value());
} else if (related_models_ids.value().size() == 1) {
model_id = related_models_ids.value().front();
} else { // multiple models with nearly same name found
auto selection = cli_selection_utils::PrintSelection(
related_models_ids.value(), "Local Models: (press enter to select)");
if (!selection.has_value()) {
if (model_handle_.empty()) {
auto all_local_models = modellist_handler.LoadModelList();
if (all_local_models.has_error() || all_local_models.value().empty()) {
CLI_LOG("No local models available!");
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());

if (all_local_models.value().size() == 1) {
model_id = all_local_models.value().front().model;
} else {
std::vector<std::string> model_id_list{};
for (const auto& model : all_local_models.value()) {
model_id_list.push_back(model.model);
}

auto selection = cli_selection_utils::PrintSelection(
model_id_list, "Please select an option");
if (!selection.has_value()) {
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());
}
} else {
auto related_models_ids =
modellist_handler.FindRelatedModel(model_handle_);
if (related_models_ids.has_error() ||
related_models_ids.value().empty()) {
auto result = model_service_.DownloadModel(model_handle_);
if (result.has_error()) {
CLI_LOG("Model " << model_handle_ << " not found!");
return;
}
model_id = result.value();
CTL_INF("model_id: " << model_id.value());
} else if (related_models_ids.value().size() == 1) {
model_id = related_models_ids.value().front();
} else { // multiple models with nearly same name found
auto selection = cli_selection_utils::PrintSelection(
related_models_ids.value(),
"Local Models: (press enter to select)");
if (!selection.has_value()) {
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());
}
}
}

Expand Down
5 changes: 1 addition & 4 deletions engine/cli/commands/server_start_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#include "server_start_cmd.h"
#include "commands/cortex_upd_cmd.h"
#include "httplib.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/file_manager_utils.h"
#include "utils/logging_utils.h"

namespace commands {

Expand Down Expand Up @@ -124,4 +121,4 @@ bool ServerStartCmd::Exec(const std::string& host, int port) {
return true;
}

}; // namespace commands
}; // namespace commands
3 changes: 2 additions & 1 deletion engine/cli/commands/server_start_cmd.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once

#include <string>
#include "httplib.h"

Expand All @@ -18,4 +19,4 @@ class ServerStartCmd {
ServerStartCmd();
bool Exec(const std::string& host, int port);
};
} // namespace commands
} // namespace commands
10 changes: 2 additions & 8 deletions engine/database/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,11 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(

cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
const std::string& identifier) const {
// TODO (namh): add check for alias as well
try {
std::vector<std::string> related_models;
SQLite::Statement query(
db_,
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
"OR model_id LIKE ? OR model_id LIKE ?");
query.bind(1, identifier + ":%");
query.bind(2, "%:" + identifier);
query.bind(3, "%:" + identifier + ":%");
query.bind(4, identifier);
db_, "SELECT model_id FROM models WHERE model_id LIKE ?");
query.bind(1, "%" + identifier + "%");

while (query.executeStep()) {
related_models.push_back(query.getColumn(0).getString());
Expand Down
Loading

0 comments on commit 0aad82c

Please sign in to comment.