Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: uplift pull and run cmd #1430

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ void ModelPullCmd::Exec(const std::string& input) {
auto result = model_service_.DownloadModel(input);
if (result.has_error()) {
CLI_LOG(result.error());
}
}
}
}; // namespace commands
21 changes: 14 additions & 7 deletions engine/commands/run_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include "run_cmd.h"
#include "chat_completion_cmd.h"
#include "config/yaml_config.h"
#include "cortex_upd_cmd.h"
#include "database/models.h"
#include "model_start_cmd.h"
#include "model_status_cmd.h"
#include "server_start_cmd.h"
#include "utils/cli_selection_utils.h"
#include "utils/logging_utils.h"

#include "cortex_upd_cmd.h"

namespace commands {

namespace {
Expand All @@ -33,14 +33,21 @@ void RunCmd::Exec(bool chat_flag) {

// Download model if it does not exist
{
if (!modellist_handler.HasModel(model_handle_)) {
auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_);
if (related_models_ids.has_error() || related_models_ids.value().empty()) {
auto result = model_service_.DownloadModel(model_handle_);
if (result.has_error()) {
CTL_ERR("Error: " << result.error());
return;
}
model_id = result.value();
CTL_INF("model_id: " << model_id.value());
} else if (related_models_ids.value().size() == 1) {
model_id = related_models_ids.value().front();
} else { // multiple models with nearly same name found
auto selection = cli_selection_utils::PrintSelection(
related_models_ids.value(), "Local Models: (press enter to select)");
if (!selection.has_value()) {
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());
}
}

Expand Down
2 changes: 1 addition & 1 deletion engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void Models::PullModel(const HttpRequestPtr& req,
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
return model_service_.HandleUrl(model_handle, true);
} else if (model_handle.find(":") == std::string::npos) {
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
return model_service_.DownloadModelFromCortexso(
model_and_branch[0], model_and_branch[1], true);
Expand Down
23 changes: 23 additions & 0 deletions engine/database/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,29 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(
}
}

cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
const std::string& identifier) const {
// TODO (namh): add check for alias as well
try {
std::vector<std::string> related_models;
SQLite::Statement query(
db_,
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
"OR model_id LIKE ? OR model_id LIKE ?");
query.bind(1, identifier + ":%");
query.bind(2, "%:" + identifier);
query.bind(3, "%:" + identifier + ":%");
query.bind(4, identifier);

while (query.executeStep()) {
related_models.push_back(query.getColumn(0).getString());
}
return related_models;
} catch (const std::exception& e) {
return cpp::fail(e.what());
}
}

bool Models::HasModel(const std::string& identifier) const {
try {
SQLite::Statement query(
Expand Down
18 changes: 11 additions & 7 deletions engine/database/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Models {
const std::string& model_id,
const std::string& model_alias) const;

cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;
cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;

public:
static const std::string kModelListPath;
Expand All @@ -35,15 +35,19 @@ class Models {
std::string GenerateShortenedAlias(
const std::string& model_id,
const std::vector<ModelEntry>& entries) const;
cpp::result<ModelEntry, std::string> GetModelInfo(const std::string& identifier) const;
cpp::result<ModelEntry, std::string> GetModelInfo(
const std::string& identifier) const;
void PrintModelInfo(const ModelEntry& entry) const;
cpp::result<bool, std::string> AddModelEntry(ModelEntry new_entry,
bool use_short_alias = false);
cpp::result<bool, std::string> UpdateModelEntry(const std::string& identifier,
const ModelEntry& updated_entry);
cpp::result<bool, std::string> DeleteModelEntry(const std::string& identifier);
cpp::result<bool, std::string> UpdateModelAlias(const std::string& model_id,
const std::string& model_alias);
cpp::result<bool, std::string> UpdateModelEntry(
const std::string& identifier, const ModelEntry& updated_entry);
cpp::result<bool, std::string> DeleteModelEntry(
const std::string& identifier);
cpp::result<bool, std::string> UpdateModelAlias(
const std::string& model_id, const std::string& model_alias);
cpp::result<std::vector<std::string>, std::string> FindRelatedModel(
const std::string& identifier) const;
bool HasModel(const std::string& identifier) const;
};
} // namespace cortex::db
41 changes: 35 additions & 6 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,46 @@ cpp::result<std::string, std::string> ModelService::HandleCortexsoModel(
return cpp::fail(branches.error());
}

std::vector<std::string> options{};
auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName);

cortex::db::Models modellist_handler;
auto downloaded_model_ids =
modellist_handler.FindRelatedModel(modelName).value_or(
std::vector<std::string>{});

std::vector<std::string> avai_download_opts{};
for (const auto& branch : branches.value()) {
if (branch.second.name != "main") {
options.emplace_back(branch.second.name);
if (branch.second.name == "main") { // main branch only have metadata. skip
continue;
}
auto model_id = modelName + ":" + branch.second.name;
if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(),
model_id) !=
downloaded_model_ids.end()) { // if downloaded, we skip it
continue;
}
avai_download_opts.emplace_back(model_id);
}
if (options.empty()) {

if (avai_download_opts.empty()) {
// TODO: only with pull, we return
return cpp::fail("No variant available");
}
auto selection = cli_selection_utils::PrintSelection(options);
return DownloadModelFromCortexso(modelName, selection.value());
std::optional<std::string> normalized_def_branch = std::nullopt;
if (default_model_branch.has_value()) {
normalized_def_branch = modelName + ":" + default_model_branch.value();
}
string_utils::SortStrings(downloaded_model_ids);
string_utils::SortStrings(avai_download_opts);
auto selection = cli_selection_utils::PrintModelSelection(
downloaded_model_ids, avai_download_opts, normalized_def_branch);
if (!selection.has_value()) {
return cpp::fail("Invalid selection");
}

CLI_LOG("Selected: " << selection.value());
auto branch_name = selection.value().substr(modelName.size() + 1);
return DownloadModelFromCortexso(modelName, branch_name, false);
}

std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
Expand Down
56 changes: 53 additions & 3 deletions engine/utils/cli_selection_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,67 @@
#include <optional>
#include <string>
#include <vector>
#include "utils/logging_utils.h"

namespace cli_selection_utils {
inline void PrintMenu(const std::vector<std::string>& options) {
auto index{1};
const std::string indent = std::string(4, ' ');
inline void PrintMenu(
const std::vector<std::string>& options,
const std::optional<std::string> default_option = std::nullopt,
const int start_index = 1) {
auto index{start_index};
for (const auto& option : options) {
std::cout << index << ". " << option << "\n";
bool is_default = false;
if (default_option.has_value() && option == default_option.value()) {
is_default = true;
}
std::string selection{std::to_string(index) + ". " + option +
(is_default ? " (default)" : "") + "\n"};
std::cout << indent << selection;
index++;
}
std::endl(std::cout);
}

inline std::optional<std::string> PrintModelSelection(
const std::vector<std::string>& downloaded,
const std::vector<std::string>& availables,
const std::optional<std::string> default_selection = std::nullopt) {

std::string selection{""};
if (!downloaded.empty()) {
std::cout << "Downloaded models:\n";
for (const auto& option : downloaded) {
std::cout << indent << option << "\n";
}
std::endl(std::cout);
}

if (!availables.empty()) {
std::cout << "Available to download:\n";
PrintMenu(availables, default_selection, 1);
}

std::cout << "Select a model (" << 1 << "-" << availables.size() << "): ";
std::getline(std::cin, selection);

// if selection is empty and default selection is inside availables, return default_selection
if (selection.empty()) {
if (default_selection.has_value() &&
std::find(availables.begin(), availables.end(),
default_selection.value()) != availables.end()) {
return default_selection;
}
return std::nullopt;
}

if (std::stoi(selection) > availables.size() || std::stoi(selection) < 1) {
return std::nullopt;
}

return availables[std::stoi(selection) - 1];
}

inline std::optional<std::string> PrintSelection(
const std::vector<std::string>& options,
const std::string& title = "Select an option") {
Expand Down
3 changes: 1 addition & 2 deletions engine/utils/curl_utils.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <string>
#include "utils/logging_utils.h"
#include "utils/result.hpp"
#include "yaml-cpp/yaml.h"

Expand Down Expand Up @@ -74,4 +73,4 @@ inline cpp::result<nlohmann::json, std::string> SimpleGetJson(
" parsing error: " + std::string(e.what()));
}
}
} // namespace curl_utils
} // namespace curl_utils
26 changes: 26 additions & 0 deletions engine/utils/huggingface_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ GetHuggingFaceModelRepoInfo(const std::string& author,
return model_repo_info;
}

inline std::string GetMetadataUrl(const std::string& model_id) {
auto url_obj = url_parser::Url{
.protocol = "https",
.host = kHuggingfaceHost,
.pathParams = {"cortexso", model_id, "resolve", "main", "metadata.yml"}};

return url_obj.ToFullPath();
}

inline std::string GetDownloadableUrl(const std::string& author,
const std::string& modelName,
const std::string& fileName,
Expand All @@ -151,4 +160,21 @@ inline std::string GetDownloadableUrl(const std::string& author,
};
return url_parser::FromUrl(url_obj);
}

inline std::optional<std::string> GetDefaultBranch(
const std::string& model_name) {
auto default_model_branch =
curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name));

if (default_model_branch.has_error()) {
return std::nullopt;
}

auto metadata = default_model_branch.value();
auto default_branch = metadata["default"];
if (default_branch.IsDefined()) {
return default_branch.as<std::string>();
}
return std::nullopt;
}
} // namespace huggingface_utils
4 changes: 4 additions & 0 deletions engine/utils/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ inline bool StartsWith(const std::string& str, const std::string& prefix) {
return str.rfind(prefix, 0) == 0;
}

inline void SortStrings(std::vector<std::string>& strings) {
std::sort(strings.begin(), strings.end());
}

inline bool EndsWith(const std::string& str, const std::string& suffix) {
if (str.length() >= suffix.length()) {
return (0 == str.compare(str.length() - suffix.length(), suffix.length(),
Expand Down
Loading