From 0aad82cbac1e82df46cbe1cae85085fc4cfca6bd Mon Sep 17 00:00:00 2001 From: James Date: Sun, 20 Oct 2024 14:57:40 +0700 Subject: [PATCH] feat(#1512): simplify cortex run --- engine/cli/command_line_parser.cc | 58 ++++++++++-------- engine/cli/command_line_parser.h | 6 ++ engine/cli/commands/model_list_cmd.cc | 66 ++++++++++++--------- engine/cli/commands/model_list_cmd.h | 4 +- engine/cli/commands/run_cmd.cc | 59 +++++++++++++----- engine/cli/commands/server_start_cmd.cc | 5 +- engine/cli/commands/server_start_cmd.h | 3 +- engine/database/models.cc | 10 +--- engine/test/components/test_string_utils.cc | 53 +++++++++++++++++ engine/utils/curl_utils.h | 3 +- engine/utils/huggingface_utils.h | 26 ++++---- engine/utils/string_utils.h | 21 +++++++ 12 files changed, 220 insertions(+), 94 deletions(-) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 51bd121b2..8d0b35d33 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -1,5 +1,6 @@ #include "command_line_parser.h" #include +#include #include "commands/chat_cmd.h" #include "commands/chat_completion_cmd.h" #include "commands/cortex_upd_cmd.h" @@ -82,27 +83,30 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) { // Check new update #ifdef CORTEX_CPP_VERSION if (cml_data_.check_upd) { - // TODO(sang) find a better way to handle - // This is an extremely ungly way to deal with connection - // hang when network down - std::atomic done = false; - std::thread t([&]() { - if (auto latest_version = - commands::CheckNewUpdate(commands::kTimeoutCheckUpdate); - latest_version.has_value() && *latest_version != CORTEX_CPP_VERSION) { - CLI_LOG("\nA new release of cortex is available: " - << CORTEX_CPP_VERSION << " -> " << *latest_version); - CLI_LOG("To upgrade, run: " << commands::GetRole() - << commands::GetCortexBinary() - << " update"); + if (strcmp(CORTEX_CPP_VERSION, "default_version") != 0) { + // TODO(sang) find a better way to handle + // This is an extremely ugly way to deal with connection + // hang when network down + std::atomic done = false; + std::thread t([&]() { + if (auto latest_version = + commands::CheckNewUpdate(commands::kTimeoutCheckUpdate); + latest_version.has_value() && + *latest_version != CORTEX_CPP_VERSION) { + CLI_LOG("\nA new release of cortex is available: " + << CORTEX_CPP_VERSION << " -> " << *latest_version); + CLI_LOG("To upgrade, run: " << commands::GetRole() + << commands::GetCortexBinary() + << " update"); + } + done = true; + }); + // Do not wait for http connection timeout + t.detach(); + int retry = 10; + while (!done && retry--) { + std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10); } - done = true; - }); - // Do not wait for http connection timeout - t.detach(); - int retry = 10; - while (!done && retry--) { - std::this_thread::sleep_for(commands::kTimeoutCheckUpdate / 10); } } #endif @@ -143,11 +147,6 @@ void CommandLineParser::SetupCommonCommands() { run_cmd->callback([this, run_cmd] { if (std::exchange(executed_, true)) return; - if (cml_data_.model_id.empty()) { - CLI_LOG("[model_id] is required\n"); - CLI_LOG(run_cmd->help()); - return; - } commands::RunCmd rc(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, download_service_); @@ -247,12 +246,19 @@ void CommandLineParser::SetupModelCommands() { auto list_models_cmd = models_cmd->add_subcommand("list", "List all models locally"); + list_models_cmd->add_option("filter", cml_data_.filter, "Filter model id"); + list_models_cmd->add_flag("-e,--engine", cml_data_.display_engine, + "Display engine"); + list_models_cmd->add_flag("-v,--version", cml_data_.display_version, + "Display version"); list_models_cmd->group(kSubcommands); list_models_cmd->callback([this]() { if (std::exchange(executed_, true)) return; commands::ModelListCmd().Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort)); + std::stoi(cml_data_.config.apiServerPort), + cml_data_.filter, cml_data_.display_engine, + cml_data_.display_version); }); auto get_models_cmd = diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index 7a9581f1f..f2f00ae95 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -45,6 +45,12 @@ class CommandLineParser { std::string cortex_version; bool check_upd = true; bool run_detach = false; + + // for model list + bool display_engine = false; + bool display_version = false; + std::string filter = ""; + int port; config_yaml_utils::CortexConfig config; std::unordered_map model_update_options; diff --git a/engine/cli/commands/model_list_cmd.cc b/engine/cli/commands/model_list_cmd.cc index c92836456..a6be44d9d 100644 --- a/engine/cli/commands/model_list_cmd.cc +++ b/engine/cli/commands/model_list_cmd.cc @@ -1,17 +1,24 @@ #include "model_list_cmd.h" +#include +#include #include #include #include "httplib.h" -#include "json/json.h" #include "server_start_cmd.h" #include "utils/logging_utils.h" +#include "utils/string_utils.h" // clang-format off #include // clang-format on + namespace commands { +using namespace tabulate; +using Row_t = + std::vector>; -void ModelListCmd::Exec(const std::string& host, int port) { +void ModelListCmd::Exec(const std::string& host, int port, std::string filter, + bool display_engine, bool display_version) { // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); @@ -21,10 +28,18 @@ void ModelListCmd::Exec(const std::string& host, int port) { } } - tabulate::Table table; + Table table; + std::vector column_headers{"(Index)", "ID"}; + if (display_engine) { + column_headers.push_back("Engine"); + } + if (display_version) { + column_headers.push_back("Version"); + } - table.add_row({"(Index)", "ID", "model alias", "engine", "version"}); - table.format().font_color(tabulate::Color::green); + Row_t header{column_headers.begin(), column_headers.end()}; + table.add_row(header); + table.format().font_color(Color::green); int count = 0; // Iterate through directory @@ -32,16 +47,29 @@ void ModelListCmd::Exec(const std::string& host, int port) { auto res = cli.Get("/v1/models"); if (res) { if (res->status == httplib::StatusCode::OK_200) { - // CLI_LOG(res->body); Json::Value body; Json::Reader reader; reader.parse(res->body, body); if (!body["data"].isNull()) { for (auto const& v : body["data"]) { + auto model_id = v["model"].asString(); + if (!filter.empty() && + !string_utils::StringContainsIgnoreCase(model_id, filter)) { + continue; + } + count += 1; - table.add_row({std::to_string(count), v["model"].asString(), - v["model_alias"].asString(), v["engine"].asString(), - v["version"].asString()}); + + std::vector row = {std::to_string(count), + v["model"].asString()}; + if (display_engine) { + row.push_back(v["engine"].asString()); + } + if (display_version) { + row.push_back(v["version"].asString()); + } + + table.add_row({row.begin(), row.end()}); } } } else { @@ -54,24 +82,6 @@ void ModelListCmd::Exec(const std::string& host, int port) { return; } - for (int i = 0; i < 5; i++) { - table[0][i] - .format() - .font_color(tabulate::Color::white) // Set font color - .font_style({tabulate::FontStyle::bold}) - .font_align(tabulate::FontAlign::center); - } - for (int i = 1; i <= count; i++) { - table[i][0] //index value - .format() - .font_color(tabulate::Color::white) // Set font color - .font_align(tabulate::FontAlign::center); - table[i][4] //version value - .format() - .font_align(tabulate::FontAlign::center); - } std::cout << table << std::endl; } -} - -; // namespace commands +}; // namespace commands diff --git a/engine/cli/commands/model_list_cmd.h b/engine/cli/commands/model_list_cmd.h index 2f25cc1cf..4f61c67cc 100644 --- a/engine/cli/commands/model_list_cmd.h +++ b/engine/cli/commands/model_list_cmd.h @@ -1,10 +1,12 @@ #pragma once + #include namespace commands { class ModelListCmd { public: - void Exec(const std::string& host, int port); + void Exec(const std::string& host, int port, std::string filter, + bool display_engine = false, bool display_version = false); }; } // namespace commands diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index 074c12709..3f501fdbb 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -30,23 +30,54 @@ void RunCmd::Exec(bool run_detach) { config::YamlHandler yaml_handler; auto address = host_ + ":" + std::to_string(port_); - // Download model if it does not exist { - auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_); - if (related_models_ids.has_error() || related_models_ids.value().empty()) { - auto result = model_service_.DownloadModel(model_handle_); - model_id = result.value(); - CTL_INF("model_id: " << model_id.value()); - } else if (related_models_ids.value().size() == 1) { - model_id = related_models_ids.value().front(); - } else { // multiple models with nearly same name found - auto selection = cli_selection_utils::PrintSelection( - related_models_ids.value(), "Local Models: (press enter to select)"); - if (!selection.has_value()) { + if (model_handle_.empty()) { + auto all_local_models = modellist_handler.LoadModelList(); + if (all_local_models.has_error() || all_local_models.value().empty()) { + CLI_LOG("No local models available!"); return; } - model_id = selection.value(); - CLI_LOG("Selected: " << selection.value()); + + if (all_local_models.value().size() == 1) { + model_id = all_local_models.value().front().model; + } else { + std::vector model_id_list{}; + for (const auto& model : all_local_models.value()) { + model_id_list.push_back(model.model); + } + + auto selection = cli_selection_utils::PrintSelection( + model_id_list, "Please select an option"); + if (!selection.has_value()) { + return; + } + model_id = selection.value(); + CLI_LOG("Selected: " << selection.value()); + } + } else { + auto related_models_ids = + modellist_handler.FindRelatedModel(model_handle_); + if (related_models_ids.has_error() || + related_models_ids.value().empty()) { + auto result = model_service_.DownloadModel(model_handle_); + if (result.has_error()) { + CLI_LOG("Model " << model_handle_ << " not found!"); + return; + } + model_id = result.value(); + CTL_INF("model_id: " << model_id.value()); + } else if (related_models_ids.value().size() == 1) { + model_id = related_models_ids.value().front(); + } else { // multiple models with nearly same name found + auto selection = cli_selection_utils::PrintSelection( + related_models_ids.value(), + "Local Models: (press enter to select)"); + if (!selection.has_value()) { + return; + } + model_id = selection.value(); + CLI_LOG("Selected: " << selection.value()); + } } } diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index 4c47a4da3..b455f93c3 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -1,10 +1,7 @@ #include "server_start_cmd.h" #include "commands/cortex_upd_cmd.h" -#include "httplib.h" -#include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" #include "utils/file_manager_utils.h" -#include "utils/logging_utils.h" namespace commands { @@ -124,4 +121,4 @@ bool ServerStartCmd::Exec(const std::string& host, int port) { return true; } -}; // namespace commands \ No newline at end of file +}; // namespace commands diff --git a/engine/cli/commands/server_start_cmd.h b/engine/cli/commands/server_start_cmd.h index cb74c5ebc..35bd07717 100644 --- a/engine/cli/commands/server_start_cmd.h +++ b/engine/cli/commands/server_start_cmd.h @@ -1,4 +1,5 @@ #pragma once + #include #include "httplib.h" @@ -18,4 +19,4 @@ class ServerStartCmd { ServerStartCmd(); bool Exec(const std::string& host, int port); }; -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/database/models.cc b/engine/database/models.cc index c08229061..753162328 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -275,17 +275,11 @@ cpp::result Models::DeleteModelEntry( cpp::result, std::string> Models::FindRelatedModel( const std::string& identifier) const { - // TODO (namh): add check for alias as well try { std::vector related_models; SQLite::Statement query( - db_, - "SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? " - "OR model_id LIKE ? OR model_id LIKE ?"); - query.bind(1, identifier + ":%"); - query.bind(2, "%:" + identifier); - query.bind(3, "%:" + identifier + ":%"); - query.bind(4, identifier); + db_, "SELECT model_id FROM models WHERE model_id LIKE ?"); + query.bind(1, "%" + identifier + "%"); while (query.executeStep()) { related_models.push_back(query.getColumn(0).getString()); diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index 1b16858c4..0269f0d4a 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -173,3 +173,56 @@ TEST_F(StringUtilsTestSuite, SpecialCharacters) { EXPECT_TRUE(string_utils::EqualsIgnoreCase("123 ABC", "123 abc")); EXPECT_FALSE(string_utils::EqualsIgnoreCase("Hello!", "Hello")); } + +TEST_F(StringUtilsTestSuite, BasicMatching) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Hello, World!", "world")); + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Hello, World!", "Hello")); + EXPECT_TRUE( + string_utils::StringContainsIgnoreCase("Hello, World!", "lo, wo")); +} + +TEST_F(StringUtilsTestSuite, CaseSensitivity) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("HELLO", "hello")); + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("hello", "HELLO")); + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("HeLLo", "ELL")); +} + +TEST_F(StringUtilsTestSuite, EdgeCases) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("", "")); + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Hello", "")); + EXPECT_FALSE(string_utils::StringContainsIgnoreCase("", "Hello")); +} + +TEST_F(StringUtilsTestSuite, NoMatch) { + EXPECT_FALSE( + string_utils::StringContainsIgnoreCase("Hello, World!", "Goodbye")); + EXPECT_FALSE(string_utils::StringContainsIgnoreCase("Hello", "HelloWorld")); +} + +TEST_F(StringUtilsTestSuite, StringContainsWithSpecialCharacters) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Hello, World!", "o, W")); + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Hello! @#$%", "@#$")); +} + +TEST_F(StringUtilsTestSuite, StringContainsWithModelId) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase( + "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K." + "gguf", + "thebloke")); +} + +TEST_F(StringUtilsTestSuite, RepeatingPatterns) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Mississippi", "ssi")); + EXPECT_TRUE(string_utils::StringContainsIgnoreCase("Mississippi", "ssippi")); +} + +TEST_F(StringUtilsTestSuite, LongStrings) { + EXPECT_TRUE(string_utils::StringContainsIgnoreCase( + "This is a very long string to test our " + "function's performance with larger inputs", + "PERFORMANCE")); + EXPECT_FALSE(string_utils::StringContainsIgnoreCase( + "This is a very long string to test our " + "function's performance with larger inputs", + "not here")); +} diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index b52030726..2c847e17f 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -1,8 +1,9 @@ #include +#include +#include #include #include #include "utils/result.hpp" -#include "yaml-cpp/yaml.h" namespace curl_utils { namespace { diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index ab85948e7..9f78f59d3 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -193,18 +193,22 @@ inline std::string GetDownloadableUrl(const std::string& author, inline std::optional GetDefaultBranch( const std::string& model_name) { - auto default_model_branch = curl_utils::ReadRemoteYaml( - GetMetadataUrl(model_name), CreateCurlHfHeaders()); - - if (default_model_branch.has_error()) { + try { + auto default_model_branch = curl_utils::ReadRemoteYaml( + GetMetadataUrl(model_name), CreateCurlHfHeaders()); + + if (default_model_branch.has_error()) { + return std::nullopt; + } + + auto metadata = default_model_branch.value(); + auto default_branch = metadata["default"]; + if (default_branch.IsDefined()) { + return default_branch.as(); + } + return std::nullopt; + } catch (const std::exception& e) { return std::nullopt; } - - auto metadata = default_model_branch.value(); - auto default_branch = metadata["default"]; - if (default_branch.IsDefined()) { - return default_branch.as(); - } - return std::nullopt; } } // namespace huggingface_utils diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 3af6dda82..99373a3ce 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include #include @@ -23,6 +25,25 @@ inline void Trim(std::string& s) { s.end()); } +inline bool StringContainsIgnoreCase(const std::string& haystack, + const std::string& needle) { + if (needle.empty()) { + return true; + } + + if (haystack.length() < needle.length()) { + return false; + } + + auto it = + std::search(haystack.begin(), haystack.end(), needle.begin(), + needle.end(), [](char ch1, char ch2) { + return std::tolower(static_cast(ch1)) == + std::tolower(static_cast(ch2)); + }); + return it != haystack.end(); +} + inline bool EqualsIgnoreCase(const std::string& a, const std::string& b) { return std::equal(a.begin(), a.end(), b.begin(), b.end(), [](char a, char b) { return tolower(a) == tolower(b); });