Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(#1512): simplify cortex run #1521

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "command_line_parser.h"
#include <memory>
#include <optional>
#include "commands/chat_cmd.h"
#include "commands/chat_completion_cmd.h"
#include "commands/cortex_upd_cmd.h"
Expand Down Expand Up @@ -82,27 +83,30 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
// Check new update
#ifdef CORTEX_CPP_VERSION
if (cml_data_.check_upd) {
// TODO(sang) find a better way to handle
// This is an extremely ungly way to deal with connection
// hang when network down
std::atomic<bool> done = false;
std::thread t([&]() {
if (auto latest_version =
commands::CheckNewUpdate(commands::kTimeoutCheckUpdate);
latest_version.has_value() && *latest_version != CORTEX_CPP_VERSION) {
CLI_LOG("\nA new release of cortex is available: "
<< CORTEX_CPP_VERSION << " -> " << *latest_version);
CLI_LOG("To upgrade, run: " << commands::GetRole()
<< commands::GetCortexBinary()
<< " update");
if (strcmp(CORTEX_CPP_VERSION, "default_version") != 0) {
// TODO(sang) find a better way to handle
// This is an extremely ugly way to deal with connection
// hang when network down
std::atomic<bool> done = false;
std::thread t([&]() {
if (auto latest_version =
commands::CheckNewUpdate(commands::kTimeoutCheckUpdate);
latest_version.has_value() &&
*latest_version != CORTEX_CPP_VERSION) {
CLI_LOG("\nA new release of cortex is available: "
<< CORTEX_CPP_VERSION << " -> " << *latest_version);
CLI_LOG("To upgrade, run: " << commands::GetRole()
<< commands::GetCortexBinary()
<< " update");
}
done = true;
});
// Do not wait for http connection timeout
t.detach();
int retry = 10;
while (!done && retry--) {
std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10);
}
done = true;
});
// Do not wait for http connection timeout
t.detach();
int retry = 10;
while (!done && retry--) {
std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10);
}
}
#endif
Expand Down Expand Up @@ -143,11 +147,6 @@ void CommandLineParser::SetupCommonCommands() {
run_cmd->callback([this, run_cmd] {
if (std::exchange(executed_, true))
return;
if (cml_data_.model_id.empty()) {
CLI_LOG("[model_id] is required\n");
CLI_LOG(run_cmd->help());
return;
}
commands::RunCmd rc(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort),
cml_data_.model_id, download_service_);
Expand Down Expand Up @@ -247,12 +246,19 @@ void CommandLineParser::SetupModelCommands() {

auto list_models_cmd =
models_cmd->add_subcommand("list", "List all models locally");
list_models_cmd->add_option("filter", cml_data_.filter, "Filter model id");
list_models_cmd->add_flag("-e,--engine", cml_data_.display_engine,
"Display engine");
list_models_cmd->add_flag("-v,--version", cml_data_.display_version,
"Display version");
list_models_cmd->group(kSubcommands);
list_models_cmd->callback([this]() {
if (std::exchange(executed_, true))
return;
commands::ModelListCmd().Exec(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort));
std::stoi(cml_data_.config.apiServerPort),
cml_data_.filter, cml_data_.display_engine,
cml_data_.display_version);
});

auto get_models_cmd =
Expand Down
6 changes: 6 additions & 0 deletions engine/cli/command_line_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class CommandLineParser {
std::string cortex_version;
bool check_upd = true;
bool run_detach = false;

// for model list
bool display_engine = false;
bool display_version = false;
std::string filter = "";

int port;
config_yaml_utils::CortexConfig config;
std::unordered_map<std::string, std::string> model_update_options;
Expand Down
66 changes: 38 additions & 28 deletions engine/cli/commands/model_list_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
#include "model_list_cmd.h"
#include <json/reader.h>
#include <json/value.h>
#include <iostream>

#include <vector>
#include "httplib.h"
#include "json/json.h"
#include "server_start_cmd.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
// clang-format off
#include <tabulate/table.hpp>
// clang-format on

namespace commands {
using namespace tabulate;
using Row_t =
std::vector<variant<std::string, const char*, string_view, Table>>;

void ModelListCmd::Exec(const std::string& host, int port) {
void ModelListCmd::Exec(const std::string& host, int port, std::string filter,
bool display_engine, bool display_version) {
// Start server if server is not started yet
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Starting server ...");
Expand All @@ -21,27 +28,48 @@ void ModelListCmd::Exec(const std::string& host, int port) {
}
}

tabulate::Table table;
Table table;
std::vector<std::string> column_headers{"(Index)", "ID"};
if (display_engine) {
column_headers.push_back("Engine");
}
if (display_version) {
column_headers.push_back("Version");
}

table.add_row({"(Index)", "ID", "model alias", "engine", "version"});
table.format().font_color(tabulate::Color::green);
Row_t header{column_headers.begin(), column_headers.end()};
table.add_row(header);
table.format().font_color(Color::green);
int count = 0;
// Iterate through directory

httplib::Client cli(host + ":" + std::to_string(port));
auto res = cli.Get("/v1/models");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
// CLI_LOG(res->body);
Json::Value body;
Json::Reader reader;
reader.parse(res->body, body);
if (!body["data"].isNull()) {
for (auto const& v : body["data"]) {
auto model_id = v["model"].asString();
if (!filter.empty() &&
!string_utils::StringContainsIgnoreCase(model_id, filter)) {
continue;
}

count += 1;
table.add_row({std::to_string(count), v["model"].asString(),
v["model_alias"].asString(), v["engine"].asString(),
v["version"].asString()});

std::vector<std::string> row = {std::to_string(count),
v["model"].asString()};
if (display_engine) {
row.push_back(v["engine"].asString());
}
if (display_version) {
row.push_back(v["version"].asString());
}

table.add_row({row.begin(), row.end()});
}
}
} else {
Expand All @@ -54,24 +82,6 @@ void ModelListCmd::Exec(const std::string& host, int port) {
return;
}

for (int i = 0; i < 5; i++) {
table[0][i]
.format()
.font_color(tabulate::Color::white) // Set font color
.font_style({tabulate::FontStyle::bold})
.font_align(tabulate::FontAlign::center);
}
for (int i = 1; i <= count; i++) {
table[i][0] //index value
.format()
.font_color(tabulate::Color::white) // Set font color
.font_align(tabulate::FontAlign::center);
table[i][4] //version value
.format()
.font_align(tabulate::FontAlign::center);
}
std::cout << table << std::endl;
}
}

; // namespace commands
}; // namespace commands
4 changes: 3 additions & 1 deletion engine/cli/commands/model_list_cmd.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#pragma once

#include <string>

namespace commands {

class ModelListCmd {
public:
void Exec(const std::string& host, int port);
void Exec(const std::string& host, int port, std::string filter,
bool display_engine = false, bool display_version = false);
};
} // namespace commands
59 changes: 45 additions & 14 deletions engine/cli/commands/run_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,54 @@ void RunCmd::Exec(bool run_detach) {
config::YamlHandler yaml_handler;
auto address = host_ + ":" + std::to_string(port_);

// Download model if it does not exist
{
auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_);
if (related_models_ids.has_error() || related_models_ids.value().empty()) {
auto result = model_service_.DownloadModel(model_handle_);
model_id = result.value();
CTL_INF("model_id: " << model_id.value());
} else if (related_models_ids.value().size() == 1) {
model_id = related_models_ids.value().front();
} else { // multiple models with nearly same name found
auto selection = cli_selection_utils::PrintSelection(
related_models_ids.value(), "Local Models: (press enter to select)");
if (!selection.has_value()) {
if (model_handle_.empty()) {
auto all_local_models = modellist_handler.LoadModelList();
if (all_local_models.has_error() || all_local_models.value().empty()) {
CLI_LOG("No local models available!");
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());

if (all_local_models.value().size() == 1) {
model_id = all_local_models.value().front().model;
} else {
std::vector<std::string> model_id_list{};
for (const auto& model : all_local_models.value()) {
model_id_list.push_back(model.model);
}

auto selection = cli_selection_utils::PrintSelection(
model_id_list, "Please select an option");
if (!selection.has_value()) {
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());
}
} else {
auto related_models_ids =
modellist_handler.FindRelatedModel(model_handle_);
if (related_models_ids.has_error() ||
related_models_ids.value().empty()) {
auto result = model_service_.DownloadModel(model_handle_);
if (result.has_error()) {
CLI_LOG("Model " << model_handle_ << " not found!");
return;
}
model_id = result.value();
CTL_INF("model_id: " << model_id.value());
} else if (related_models_ids.value().size() == 1) {
model_id = related_models_ids.value().front();
} else { // multiple models with nearly same name found
auto selection = cli_selection_utils::PrintSelection(
related_models_ids.value(),
"Local Models: (press enter to select)");
if (!selection.has_value()) {
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());
}
}
}

Expand Down
5 changes: 1 addition & 4 deletions engine/cli/commands/server_start_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#include "server_start_cmd.h"
#include "commands/cortex_upd_cmd.h"
#include "httplib.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/file_manager_utils.h"
#include "utils/logging_utils.h"

namespace commands {

Expand Down Expand Up @@ -124,4 +121,4 @@ bool ServerStartCmd::Exec(const std::string& host, int port) {
return true;
}

}; // namespace commands
}; // namespace commands
3 changes: 2 additions & 1 deletion engine/cli/commands/server_start_cmd.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once

#include <string>
#include "httplib.h"

Expand All @@ -18,4 +19,4 @@ class ServerStartCmd {
ServerStartCmd();
bool Exec(const std::string& host, int port);
};
} // namespace commands
} // namespace commands
10 changes: 2 additions & 8 deletions engine/database/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,11 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(

cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
const std::string& identifier) const {
// TODO (namh): add check for alias as well
try {
std::vector<std::string> related_models;
SQLite::Statement query(
db_,
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
"OR model_id LIKE ? OR model_id LIKE ?");
query.bind(1, identifier + ":%");
query.bind(2, "%:" + identifier);
query.bind(3, "%:" + identifier + ":%");
query.bind(4, identifier);
db_, "SELECT model_id FROM models WHERE model_id LIKE ?");
query.bind(1, "%" + identifier + "%");

while (query.executeStep()) {
related_models.push_back(query.getColumn(0).getString());
Expand Down
Loading
Loading