diff --git a/engine/commands/model_pull_cmd.cc b/engine/commands/model_pull_cmd.cc index b275f8beb..4ec5344bb 100644 --- a/engine/commands/model_pull_cmd.cc +++ b/engine/commands/model_pull_cmd.cc @@ -6,6 +6,6 @@ void ModelPullCmd::Exec(const std::string& input) { auto result = model_service_.DownloadModel(input); if (result.has_error()) { CLI_LOG(result.error()); - } + } } }; // namespace commands diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc index 9ae71d85c..0df795615 100644 --- a/engine/commands/run_cmd.cc +++ b/engine/commands/run_cmd.cc @@ -1,14 +1,14 @@ #include "run_cmd.h" #include "chat_completion_cmd.h" #include "config/yaml_config.h" +#include "cortex_upd_cmd.h" #include "database/models.h" #include "model_start_cmd.h" #include "model_status_cmd.h" #include "server_start_cmd.h" +#include "utils/cli_selection_utils.h" #include "utils/logging_utils.h" -#include "cortex_upd_cmd.h" - namespace commands { namespace { @@ -33,14 +33,21 @@ void RunCmd::Exec(bool chat_flag) { // Download model if it does not exist { - if (!modellist_handler.HasModel(model_handle_)) { + auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_); + if (related_models_ids.has_error() || related_models_ids.value().empty()) { auto result = model_service_.DownloadModel(model_handle_); - if (result.has_error()) { - CTL_ERR("Error: " << result.error()); - return; - } model_id = result.value(); CTL_INF("model_id: " << model_id.value()); + } else if (related_models_ids.value().size() == 1) { + model_id = related_models_ids.value().front(); + } else { // multiple models with nearly same name found + auto selection = cli_selection_utils::PrintSelection( + related_models_ids.value(), "Local Models: (press enter to select)"); + if (!selection.has_value()) { + return; + } + model_id = selection.value(); + CLI_LOG("Selected: " << selection.value()); } } diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 739f614c9..493957aca 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -31,7 +31,7 @@ void Models::PullModel(const HttpRequestPtr& req, CTL_INF("Handle model input, model handle: " + model_handle); if (string_utils::StartsWith(model_handle, "https")) { return model_service_.HandleUrl(model_handle, true); - } else if (model_handle.find(":") == std::string::npos) { + } else if (model_handle.find(":") != std::string::npos) { auto model_and_branch = string_utils::SplitBy(model_handle, ":"); return model_service_.DownloadModelFromCortexso( model_and_branch[0], model_and_branch[1], true); diff --git a/engine/database/models.cc b/engine/database/models.cc index 62ee4f7df..c08229061 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -273,6 +273,29 @@ cpp::result Models::DeleteModelEntry( } } +cpp::result, std::string> Models::FindRelatedModel( + const std::string& identifier) const { + // TODO (namh): add check for alias as well + try { + std::vector related_models; + SQLite::Statement query( + db_, + "SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? " + "OR model_id LIKE ? OR model_id LIKE ?"); + query.bind(1, identifier + ":%"); + query.bind(2, "%:" + identifier); + query.bind(3, "%:" + identifier + ":%"); + query.bind(4, identifier); + + while (query.executeStep()) { + related_models.push_back(query.getColumn(0).getString()); + } + return related_models; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + bool Models::HasModel(const std::string& identifier) const { try { SQLite::Statement query( diff --git a/engine/database/models.h b/engine/database/models.h index f3ff99faa..3248da788 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -24,7 +24,7 @@ class Models { const std::string& model_id, const std::string& model_alias) const; - cpp::result, std::string> LoadModelListNoLock() const; + cpp::result, std::string> LoadModelListNoLock() const; public: static const std::string kModelListPath; @@ -35,15 +35,19 @@ class Models { std::string GenerateShortenedAlias( const std::string& model_id, const std::vector& entries) const; - cpp::result GetModelInfo(const std::string& identifier) const; + cpp::result GetModelInfo( + const std::string& identifier) const; void PrintModelInfo(const ModelEntry& entry) const; cpp::result AddModelEntry(ModelEntry new_entry, bool use_short_alias = false); - cpp::result UpdateModelEntry(const std::string& identifier, - const ModelEntry& updated_entry); - cpp::result DeleteModelEntry(const std::string& identifier); - cpp::result UpdateModelAlias(const std::string& model_id, - const std::string& model_alias); + cpp::result UpdateModelEntry( + const std::string& identifier, const ModelEntry& updated_entry); + cpp::result DeleteModelEntry( + const std::string& identifier); + cpp::result UpdateModelAlias( + const std::string& model_id, const std::string& model_alias); + cpp::result, std::string> FindRelatedModel( + const std::string& identifier) const; bool HasModel(const std::string& identifier) const; }; } // namespace cortex::db diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 61ce7bbc3..d9c2aa48f 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -151,17 +151,46 @@ cpp::result ModelService::HandleCortexsoModel( return cpp::fail(branches.error()); } - std::vector options{}; + auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName); + + cortex::db::Models modellist_handler; + auto downloaded_model_ids = + modellist_handler.FindRelatedModel(modelName).value_or( + std::vector{}); + + std::vector avai_download_opts{}; for (const auto& branch : branches.value()) { - if (branch.second.name != "main") { - options.emplace_back(branch.second.name); + if (branch.second.name == "main") { // main branch only have metadata. skip + continue; + } + auto model_id = modelName + ":" + branch.second.name; + if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), + model_id) != + downloaded_model_ids.end()) { // if downloaded, we skip it + continue; } + avai_download_opts.emplace_back(model_id); } - if (options.empty()) { + + if (avai_download_opts.empty()) { + // TODO: only with pull, we return return cpp::fail("No variant available"); } - auto selection = cli_selection_utils::PrintSelection(options); - return DownloadModelFromCortexso(modelName, selection.value()); + std::optional normalized_def_branch = std::nullopt; + if (default_model_branch.has_value()) { + normalized_def_branch = modelName + ":" + default_model_branch.value(); + } + string_utils::SortStrings(downloaded_model_ids); + string_utils::SortStrings(avai_download_opts); + auto selection = cli_selection_utils::PrintModelSelection( + downloaded_model_ids, avai_download_opts, normalized_def_branch); + if (!selection.has_value()) { + return cpp::fail("Invalid selection"); + } + + CLI_LOG("Selected: " << selection.value()); + auto branch_name = selection.value().substr(modelName.size() + 1); + return DownloadModelFromCortexso(modelName, branch_name, false); } std::optional ModelService::GetDownloadedModel( diff --git a/engine/utils/cli_selection_utils.h b/engine/utils/cli_selection_utils.h index 0c2453478..0b18bdc9d 100644 --- a/engine/utils/cli_selection_utils.h +++ b/engine/utils/cli_selection_utils.h @@ -2,17 +2,67 @@ #include #include #include +#include "utils/logging_utils.h" namespace cli_selection_utils { -inline void PrintMenu(const std::vector& options) { - auto index{1}; +const std::string indent = std::string(4, ' '); +inline void PrintMenu( + const std::vector& options, + const std::optional default_option = std::nullopt, + const int start_index = 1) { + auto index{start_index}; for (const auto& option : options) { - std::cout << index << ". " << option << "\n"; + bool is_default = false; + if (default_option.has_value() && option == default_option.value()) { + is_default = true; + } + std::string selection{std::to_string(index) + ". " + option + + (is_default ? " (default)" : "") + "\n"}; + std::cout << indent << selection; index++; } std::endl(std::cout); } +inline std::optional PrintModelSelection( + const std::vector& downloaded, + const std::vector& availables, + const std::optional default_selection = std::nullopt) { + + std::string selection{""}; + if (!downloaded.empty()) { + std::cout << "Downloaded models:\n"; + for (const auto& option : downloaded) { + std::cout << indent << option << "\n"; + } + std::endl(std::cout); + } + + if (!availables.empty()) { + std::cout << "Available to download:\n"; + PrintMenu(availables, default_selection, 1); + } + + std::cout << "Select a model (" << 1 << "-" << availables.size() << "): "; + std::getline(std::cin, selection); + + // if selection is empty and default selection is inside availables, return default_selection + if (selection.empty()) { + if (default_selection.has_value() && + std::find(availables.begin(), availables.end(), + default_selection.value()) != availables.end()) { + return default_selection; + } + return std::nullopt; + } + + if (std::stoi(selection) > availables.size() || std::stoi(selection) < 1) { + return std::nullopt; + } + + return availables[std::stoi(selection) - 1]; +} + inline std::optional PrintSelection( const std::vector& options, const std::string& title = "Select an option") { diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 90dc2fd2d..2640bdc9b 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -1,7 +1,6 @@ #include #include #include -#include "utils/logging_utils.h" #include "utils/result.hpp" #include "yaml-cpp/yaml.h" @@ -74,4 +73,4 @@ inline cpp::result SimpleGetJson( " parsing error: " + std::string(e.what())); } } -} // namespace curl_utils \ No newline at end of file +} // namespace curl_utils diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index c208c1e7d..2c06afc2c 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -140,6 +140,15 @@ GetHuggingFaceModelRepoInfo(const std::string& author, return model_repo_info; } +inline std::string GetMetadataUrl(const std::string& model_id) { + auto url_obj = url_parser::Url{ + .protocol = "https", + .host = kHuggingfaceHost, + .pathParams = {"cortexso", model_id, "resolve", "main", "metadata.yml"}}; + + return url_obj.ToFullPath(); +} + inline std::string GetDownloadableUrl(const std::string& author, const std::string& modelName, const std::string& fileName, @@ -151,4 +160,21 @@ inline std::string GetDownloadableUrl(const std::string& author, }; return url_parser::FromUrl(url_obj); } + +inline std::optional GetDefaultBranch( + const std::string& model_name) { + auto default_model_branch = + curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name)); + + if (default_model_branch.has_error()) { + return std::nullopt; + } + + auto metadata = default_model_branch.value(); + auto default_branch = metadata["default"]; + if (default_branch.IsDefined()) { + return default_branch.as(); + } + return std::nullopt; +} } // namespace huggingface_utils diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 2bb005fc9..cc4430d64 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -13,6 +13,10 @@ inline bool StartsWith(const std::string& str, const std::string& prefix) { return str.rfind(prefix, 0) == 0; } +inline void SortStrings(std::vector& strings) { + std::sort(strings.begin(), strings.end()); +} + inline bool EndsWith(const std::string& str, const std::string& suffix) { if (str.length() >= suffix.length()) { return (0 == str.compare(str.length() - suffix.length(), suffix.length(),