Skip to content

Commit

Permalink
finalizing
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Oct 31, 2024
1 parent 2a4ed50 commit 8a31f11
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 126 deletions.
4 changes: 3 additions & 1 deletion engine/cli/commands/engine_release_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ cpp::result<void, std::string> EngineReleaseCmd::Exec(
github_release_utils::GitHubAsset selected_asset;
for (const auto& asset : selected_release["assets"]) {
if (asset["name"] == variant_selection) {
selected_asset = github_release_utils::GitHubAsset::FromJson(asset);
auto version = string_utils::RemoveSubstring(selection.value(), "v");
selected_asset =
github_release_utils::GitHubAsset::FromJson(asset, version);
break;
}
}
Expand Down
54 changes: 5 additions & 49 deletions engine/controllers/engines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,6 @@ std::string NormalizeEngine(const std::string& engine) {
};
} // namespace

void Engines::InstallEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) {

if (engine.empty()) {
Json::Value res;
res["message"] = "Engine name is required";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k409Conflict);
callback(resp);
LOG_WARN << "No engine field in path param";
return;
}

std::string version = "latest";
if (auto o = req->getJsonObject(); o) {
version = (*o).get("version", "latest").asString();
}

auto result = engine_service_->InstallEngineAsync(engine, version);
if (result.has_error()) {
Json::Value res;
res["message"] = result.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k400BadRequest);
callback(resp);
} else {
Json::Value res;
res["message"] = "Engine " + engine + " starts installing!";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k200OK);
callback(resp);
}
}

void Engines::ListEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
Expand Down Expand Up @@ -167,20 +131,12 @@ void Engines::GetEngineVariants(
void Engines::InstallEngineVariant(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine, const std::string& version,
const std::string& variant_name) {
const std::string& engine, const std::optional<std::string> version,
const std::optional<std::string> variant_name) {
auto normalized_version = version.value_or("latest");

if (version.empty() || variant_name.empty()) {
Json::Value ret;
ret["result"] = "Bad Request";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
return;
}

auto result =
engine_service_->InstallEngineAsyncV2(engine, version, variant_name);
auto result = engine_service_->InstallEngineAsyncV2(
engine, normalized_version, variant_name);
if (result.has_error()) {
Json::Value res;
res["message"] = result.error();
Expand Down
13 changes: 4 additions & 9 deletions engine/controllers/engines.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class Engines : public drogon::HttpController<Engines, false> {
public:
METHOD_LIST_BEGIN

// TODO: update this API
// METHOD_ADD(Engines::InstallEngine, "/install/{1}", Post);
METHOD_ADD(Engines::InstallEngineVariant, "/{1}?version={2}&variant={3}",
Post);
METHOD_ADD(Engines::UninstallEngine, "/{1}/{2}/{3}", Delete);
METHOD_ADD(Engines::ListEngine, "", Get);

Expand All @@ -30,18 +30,13 @@ class Engines : public drogon::HttpController<Engines, false> {
METHOD_ADD(Engines::LoadEngine, "/{1}/load", Post);
METHOD_ADD(Engines::UnloadEngine, "/{1}/load", Delete);

ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/install/{1}", Post);
ADD_METHOD_TO(Engines::UninstallEngine, "/v1/engines/{1}/{2}/{3}", Delete);

METHOD_LIST_END

explicit Engines(std::shared_ptr<EngineService> engine_service)
: engine_service_{engine_service} {}

void InstallEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine);

void ListEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;

Expand All @@ -62,8 +57,8 @@ class Engines : public drogon::HttpController<Engines, false> {
void InstallEngineVariant(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine, const std::string& version,
const std::string& variant_name);
const std::string& engine, const std::optional<std::string> version,
const std::optional<std::string> variant_name);

void GetEnginesInstalledVariants(
const HttpRequestPtr& req,
Expand Down
115 changes: 59 additions & 56 deletions engine/services/engine_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ std::string GetEnginePath(std::string_view e) {

cpp::result<void, std::string> EngineService::InstallEngineAsyncV2(
const std::string& engine, const std::string& version,
const std::string& variant_name) {
const std::optional<std::string> variant_name) {
auto ne = NormalizeEngine(engine);
CTL_INF("InstallEngineAsyncV2: " << ne << ", " << version << ", "
<< variant_name);
<< variant_name.value_or(""));

auto result = DownloadEngineV2(ne, version, variant_name, true /*async*/);
auto result = DownloadEngineV2(ne, version, variant_name);
if (result.has_error()) {
return result;
}
Expand Down Expand Up @@ -191,21 +191,11 @@ cpp::result<bool, std::string> EngineService::UninstallEngineVariant(

cpp::result<void, std::string> EngineService::DownloadEngineV2(
const std::string& engine, const std::string& version,
const std::string& variant_name, bool async) {
const std::optional<std::string> variant_name) {
auto normalized_version = version == "latest"
? "latest"
: string_utils::RemoveSubstring(version, "v");

// check if engine variant is installed
bool is_installed = false;
if (is_installed) {
// set default
// TODO: namh implement this
return {};
}

// TODO: namh add back the github_token

auto normalized_version = string_utils::RemoveSubstring(version, "v");
auto merged_variant_name =
engine + "-" + normalized_version + "-" + variant_name + ".tar.gz";
auto res = GetEngineVariants(engine, version);
if (res.has_error()) {
return cpp::fail("Failed to fetch engine releases: " + res.error());
Expand All @@ -216,33 +206,70 @@ cpp::result<void, std::string> EngineService::DownloadEngineV2(
}

std::optional<EngineVariant> selected_variant = std::nullopt;
for (const auto& asset : res.value()) {
if (asset.name == merged_variant_name) {
selected_variant = asset;
break;

if (variant_name.has_value()) {
auto merged_variant_name = engine + "-" + normalized_version + "-" +
variant_name.value() + ".tar.gz";

for (const auto& asset : res.value()) {
if (asset.name == merged_variant_name) {
selected_variant = asset;
break;
}
}
} else {
std::vector<std::string> variants;
for (const auto& asset : res.value()) {
variants.push_back(asset.name);
}

auto matched_variant_name = GetMatchedVariant(engine, variants);
for (const auto& v : res.value()) {
if (v.name == matched_variant_name) {
selected_variant = v;
break;
}
}
}

if (selected_variant == std::nullopt) {
return cpp::fail("Not found variant: " + variant_name);
return cpp::fail("Failed to find a suitable variant for " + engine);
}
auto normalize_version = "v" + selected_variant->version;

auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion(
selected_variant->name, engine, selected_variant->version);

auto engine_folder_path =
file_manager_utils::GetEnginesContainerPath() / engine;
auto variant_folder_path = engine_folder_path / variant_name / version;
auto variant_path = variant_folder_path / merged_variant_name;
auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() /
engine / variant_folder_name.value() /
normalize_version;

auto variant_path = variant_folder_path / selected_variant->name;
std::filesystem::create_directories(variant_folder_path);
CLI_LOG("variant_folder_path: " + variant_folder_path.string());

auto on_finished = [](const DownloadTask& finishedTask) {
auto on_finished = [this, engine, selected_variant,
normalize_version](const DownloadTask& finishedTask) {
// try to unzip the downloaded file
CLI_LOG("Engine zip path: " << finishedTask.items[0].localPath.string());
CLI_LOG("Version: " + normalize_version);

auto extract_path = finishedTask.items[0].localPath.parent_path();

archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(),
extract_path.string(), true);

auto variant = engine_matcher_utils::GetVariantFromNameAndVersion(
selected_variant->name, engine, normalize_version);
CLI_LOG("Extracted variant: " + variant.value());
// set as default
auto res =
SetDefaultEngineVariant(engine, normalize_version, variant.value());
if (res.has_error()) {
CTL_ERR("Failed to set default engine variant: " << res.error());
} else {
CTL_INF("Set default engine variant: " << res.value().variant);
}

// remove the downloaded file
try {
std::filesystem::remove(finishedTask.items[0].localPath);
Expand Down Expand Up @@ -270,33 +297,6 @@ cpp::result<void, std::string> EngineService::DownloadEngineV2(

cpp::result<bool, std::string> EngineService::DownloadEngine(
const std::string& engine, const std::string& version, bool async) {

// auto get_params = [&engine, &version]() -> std::vector<std::string> {
// if (version == "latest") {
// return {"repos", "janhq", engine, "releases", version};
// } else {
// return {"repos", "janhq", engine, "releases"};
// }
// };
//
// auto url_obj = url_parser::Url{
// .protocol = "https",
// .host = "api.github.com",
// .pathParams = get_params(),
// };

// std::unordered_map<std::string, std::string> headers;

// Check if GITHUB_TOKEN env exist
// const char* github_token = std::getenv("GITHUB_TOKEN");
// if (github_token) {
// std::string auth_header = "token " + std::string(github_token);
// headers.insert({"Authorization", auth_header});
// CTL_INF("Using authentication with GitHub token.");
// } else {
// CTL_INF("No GitHub token found. Sending request without authentication.");
// }

auto res = GetEngineVariants(engine, version);
if (res.has_error()) {
return cpp::fail("Failed to fetch engine releases: " + res.error());
Expand Down Expand Up @@ -335,7 +335,6 @@ cpp::result<bool, std::string> EngineService::DownloadEngine(
CTL_INF("Creating " << engine_folder_path.string());
std::filesystem::create_directories(engine_folder_path);
}

CTL_INF("Engine folder path: " << engine_folder_path.string() << "\n");
auto local_path = engine_folder_path / asset.name;
auto downloadTask{
Expand Down Expand Up @@ -549,7 +548,11 @@ cpp::result<bool, std::string> EngineService::IsEngineVariantReady(
auto normalized_version = string_utils::RemoveSubstring(version, "v");
auto installed_engines = GetInstalledEngineVariants(ne);

CLI_LOG("IsEngineVariantReady: " << ne << ", " << normalized_version << ", "
<< variant);
for (const auto& installed_engine : installed_engines) {
CLI_LOG("Installed: name: " + installed_engine.name +
", version: " + installed_engine.version);
if (installed_engine.name == variant &&
installed_engine.version == normalized_version) {
return true;
Expand Down
12 changes: 4 additions & 8 deletions engine/services/engine_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ class EngineService {
std::unordered_map<std::string, EngineInfo> engines_{};

public:
constexpr static auto kIncompatible = "Incompatible";
constexpr static auto kReady = "Ready";
constexpr static auto kNotInstalled = "Not Installed";

const std::vector<std::string_view> kSupportEngines = {
kLlamaEngine, kOnnxEngine, kTrtLlmEngine};

Expand All @@ -112,8 +108,8 @@ class EngineService {
* If no variant provided, automatically pick the best variant.
*/
cpp::result<void, std::string> InstallEngineAsyncV2(
const std::string& engine, const std::string& version = "latest",
const std::string& variant_name = "");
const std::string& engine, const std::string& version,
const std::optional<std::string> variant_name);

cpp::result<bool, std::string> UninstallEngineVariant(
const std::string& engine, const std::string& variant,
Expand Down Expand Up @@ -162,8 +158,8 @@ class EngineService {
bool async = false);

cpp::result<void, std::string> DownloadEngineV2(
const std::string& engine, const std::string& variant,
const std::string& version = "latest", bool async = false);
const std::string& engine, const std::string& version = "latest",
const std::optional<std::string> variant_name = std::nullopt);

cpp::result<bool, std::string> DownloadCuda(const std::string& engine,
bool async = false);
Expand Down
3 changes: 2 additions & 1 deletion engine/utils/engine_matcher_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ inline cpp::result<std::string, std::string> GetVariantFromNameAndVersion(
if (engine.empty()) {
return cpp::fail("Engine name is empty");
}
auto nv = string_utils::RemoveSubstring(version, "v");
using namespace string_utils;
auto removed_extension = RemoveSubstring(engine_file_name, ".tar.gz");
auto version_and_variant = RemoveSubstring(removed_extension, engine + "-");

auto variant = RemoveSubstring(version_and_variant, version + "-");
auto variant = RemoveSubstring(version_and_variant, nv + "-");
return variant;
}

Expand Down
8 changes: 6 additions & 2 deletions engine/utils/github_release_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ struct GitHubAsset {

std::string updated_at;
std::string browser_download_url;
std::string version;

static GitHubAsset FromJson(const Json::Value& json) {
static GitHubAsset FromJson(const Json::Value& json,
const std::string& version) {
return GitHubAsset{
.url = json["url"].asString(),
.id = json["id"].asInt(),
Expand All @@ -38,6 +40,7 @@ struct GitHubAsset {
.created_at = json["created_at"].asString(),
.updated_at = json["updated_at"].asString(),
.browser_download_url = json["browser_download_url"].asString(),
.version = version,
};
}

Expand All @@ -55,6 +58,7 @@ struct GitHubAsset {
root["created_at"] = created_at;
root["updated_at"] = updated_at;
root["browser_download_url"] = browser_download_url;
root["version"] = version;
return root;
}

Expand Down Expand Up @@ -93,7 +97,7 @@ struct GitHubRelease {
std::vector<GitHubAsset> assets = {};
if (json["assets"].isArray()) {
for (const auto& asset : json["assets"]) {
assets.push_back(GitHubAsset::FromJson(asset));
assets.push_back(GitHubAsset::FromJson(asset, json["name"].asString()));
}
}

Expand Down

0 comments on commit 8a31f11

Please sign in to comment.