Skip to content

Commit

Permalink
feat: uplift pull and run cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Oct 4, 2024
1 parent d3e886b commit 40edd99
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 19 deletions.
2 changes: 1 addition & 1 deletion engine/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ void ModelPullCmd::Exec(const std::string& input) {
auto result = model_service_.DownloadModel(input);
if (result.has_error()) {
CLI_LOG(result.error());
}
}
}
}; // namespace commands
23 changes: 23 additions & 0 deletions engine/database/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,29 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(
}
}

cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
const std::string& identifier) const {
// TODO (namh): add check for alias as well
try {
std::vector<std::string> related_models;
SQLite::Statement query(
db_,
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
"OR model_id LIKE ? OR model_id LIKE ?");
query.bind(1, identifier + ":%");
query.bind(2, "%:" + identifier);
query.bind(3, "%:" + identifier + ":%");
query.bind(4, identifier);

while (query.executeStep()) {
related_models.push_back(query.getColumn(0).getString());
}
return related_models;
} catch (const std::exception& e) {
return cpp::fail(e.what());
}
}

bool Models::HasModel(const std::string& identifier) const {
try {
SQLite::Statement query(
Expand Down
18 changes: 11 additions & 7 deletions engine/database/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Models {
const std::string& model_id,
const std::string& model_alias) const;

cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;
cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;

public:
static const std::string kModelListPath;
Expand All @@ -35,15 +35,19 @@ class Models {
std::string GenerateShortenedAlias(
const std::string& model_id,
const std::vector<ModelEntry>& entries) const;
cpp::result<ModelEntry, std::string> GetModelInfo(const std::string& identifier) const;
cpp::result<ModelEntry, std::string> GetModelInfo(
const std::string& identifier) const;
void PrintModelInfo(const ModelEntry& entry) const;
cpp::result<bool, std::string> AddModelEntry(ModelEntry new_entry,
bool use_short_alias = false);
cpp::result<bool, std::string> UpdateModelEntry(const std::string& identifier,
const ModelEntry& updated_entry);
cpp::result<bool, std::string> DeleteModelEntry(const std::string& identifier);
cpp::result<bool, std::string> UpdateModelAlias(const std::string& model_id,
const std::string& model_alias);
cpp::result<bool, std::string> UpdateModelEntry(
const std::string& identifier, const ModelEntry& updated_entry);
cpp::result<bool, std::string> DeleteModelEntry(
const std::string& identifier);
cpp::result<bool, std::string> UpdateModelAlias(
const std::string& model_id, const std::string& model_alias);
cpp::result<std::vector<std::string>, std::string> FindRelatedModel(
const std::string& identifier) const;
bool HasModel(const std::string& identifier) const;
};
} // namespace cortex::db
41 changes: 35 additions & 6 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,46 @@ cpp::result<std::string, std::string> ModelService::HandleCortexsoModel(
return cpp::fail(branches.error());
}

std::vector<std::string> options{};
auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName);

cortex::db::Models modellist_handler;
auto downloaded_model_ids =
modellist_handler.FindRelatedModel(modelName).value_or(
std::vector<std::string>{});

std::vector<std::string> avai_download_opts{};
for (const auto& branch : branches.value()) {
if (branch.second.name != "main") {
options.emplace_back(branch.second.name);
if (branch.second.name == "main") { // main branch only have metadata. skip
continue;
}
auto model_id = modelName + ":" + branch.second.name;
if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(),
model_id) !=
downloaded_model_ids.end()) { // if downloaded, we skip it
continue;
}
avai_download_opts.emplace_back(model_id);
}
if (options.empty()) {

if (avai_download_opts.empty()) {
// TODO: only with pull, we return
return cpp::fail("No variant available");
}
auto selection = cli_selection_utils::PrintSelection(options);
return DownloadModelFromCortexso(modelName, selection.value());
std::optional<std::string> normalized_def_branch = std::nullopt;
if (default_model_branch.has_value()) {
normalized_def_branch = modelName + ":" + default_model_branch.value();
}
string_utils::SortStrings(downloaded_model_ids);
string_utils::SortStrings(avai_download_opts);
auto selection = cli_selection_utils::PrintModelSelection(
downloaded_model_ids, avai_download_opts, normalized_def_branch);
if (!selection.has_value()) {
return cpp::fail("Invalid selection");
}

CLI_LOG("Selected: " << selection.value());
auto branch_name = selection.value().substr(modelName.size() + 1);
return DownloadModelFromCortexso(modelName, branch_name, false);
}

std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
Expand Down
56 changes: 53 additions & 3 deletions engine/utils/cli_selection_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,67 @@
#include <optional>
#include <string>
#include <vector>
#include "utils/logging_utils.h"

namespace cli_selection_utils {
inline void PrintMenu(const std::vector<std::string>& options) {
auto index{1};
const std::string indent = std::string(4, ' ');
inline void PrintMenu(
const std::vector<std::string>& options,
const std::optional<std::string> default_option = std::nullopt,
const int start_index = 1) {
auto index{start_index};
for (const auto& option : options) {
std::cout << index << ". " << option << "\n";
bool is_default = false;
if (default_option.has_value() && option == default_option.value()) {
is_default = true;
}
std::string selection{std::to_string(index) + ". " + option +
(is_default ? " (default)" : "") + "\n"};
std::cout << indent << selection;
index++;
}
std::endl(std::cout);
}

inline std::optional<std::string> PrintModelSelection(
const std::vector<std::string>& downloaded,
const std::vector<std::string>& availables,
const std::optional<std::string> default_selection = std::nullopt) {

std::string selection{""};
if (!downloaded.empty()) {
std::cout << "Downloaded models:\n";
for (const auto& option : downloaded) {
std::cout << indent << option << "\n";
}
std::endl(std::cout);
}

if (!availables.empty()) {
std::cout << "Available to download:\n";
PrintMenu(availables, default_selection, 1);
}

std::cout << "Select a model (" << 1 << "-" << availables.size() << "): ";
std::getline(std::cin, selection);

// if selection is empty and default selection is inside availables, return default_selection
if (selection.empty()) {
if (default_selection.has_value() &&
std::find(availables.begin(), availables.end(),
default_selection.value()) != availables.end()) {
return default_selection;
}
return std::nullopt;
}

if (std::stoi(selection) > availables.size() || std::stoi(selection) < 1) {
return std::nullopt;
}

return availables[std::stoi(selection) - 1];
}

inline std::optional<std::string> PrintSelection(
const std::vector<std::string>& options,
const std::string& title = "Select an option") {
Expand Down
3 changes: 1 addition & 2 deletions engine/utils/curl_utils.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <string>
#include "utils/logging_utils.h"
#include "utils/result.hpp"
#include "yaml-cpp/yaml.h"

Expand Down Expand Up @@ -74,4 +73,4 @@ inline cpp::result<nlohmann::json, std::string> SimpleGetJson(
" parsing error: " + std::string(e.what()));
}
}
} // namespace curl_utils
} // namespace curl_utils
26 changes: 26 additions & 0 deletions engine/utils/huggingface_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ GetHuggingFaceModelRepoInfo(const std::string& author,
return model_repo_info;
}

inline std::string GetMetadataUrl(const std::string& model_id) {
auto url_obj = url_parser::Url{
.protocol = "https",
.host = kHuggingfaceHost,
.pathParams = {"cortexso", model_id, "resolve", "main", "metadata.yml"}};

return url_obj.ToFullPath();
}

inline std::string GetDownloadableUrl(const std::string& author,
const std::string& modelName,
const std::string& fileName,
Expand All @@ -151,4 +160,21 @@ inline std::string GetDownloadableUrl(const std::string& author,
};
return url_parser::FromUrl(url_obj);
}

inline std::optional<std::string> GetDefaultBranch(
const std::string& model_name) {
auto default_model_branch =
curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name));

if (default_model_branch.has_error()) {
return std::nullopt;
}

auto metadata = default_model_branch.value();
auto default_branch = metadata["default"];
if (default_branch.IsDefined()) {
return default_branch.as<std::string>();
}
return std::nullopt;
}
} // namespace huggingface_utils
4 changes: 4 additions & 0 deletions engine/utils/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ inline bool StartsWith(const std::string& str, const std::string& prefix) {
return str.rfind(prefix, 0) == 0;
}

inline void SortStrings(std::vector<std::string>& strings) {
std::sort(strings.begin(), strings.end());
}

inline bool EndsWith(const std::string& str, const std::string& suffix) {
if (str.length() >= suffix.length()) {
return (0 == str.compare(str.length() - suffix.length(), suffix.length(),
Expand Down

0 comments on commit 40edd99

Please sign in to comment.