Skip to content

Commit

Permalink
Merge pull request #1053 from janhq/fix/dynamically-get-cuda-dependen…
Browse files Browse the repository at this point in the history
…cy-version

fix: dynamically get cuda toolkit version
  • Loading branch information
namchuai authored Aug 30, 2024
2 parents a5a30c2 + 65de91a commit fa72355
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 45 deletions.
97 changes: 77 additions & 20 deletions engine/commands/engine_init_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// clang-format on
#include "utils/cuda_toolkit_utils.h"
#include "utils/engine_matcher_utils.h"
#if defined(_WIN32) || defined(__linux__)
#include "utils/file_manager_utils.h"
#endif

namespace commands {

Expand Down Expand Up @@ -60,21 +63,22 @@ bool EngineInitCmd::Exec() const {
variants.push_back(asset_name);
}

auto cuda_version = system_info_utils::GetCudaVersion();
LOG_INFO << "engineName_: " << engineName_;
LOG_INFO << "CUDA version: " << cuda_version;
std::string matched_variant = "";
auto cuda_driver_version = system_info_utils::GetCudaVersion();
LOG_INFO << "Engine: " << engineName_
<< ", CUDA driver version: " << cuda_driver_version;

std::string matched_variant{""};
if (engineName_ == "cortex.tensorrt-llm") {
matched_variant = engine_matcher_utils::ValidateTensorrtLlm(
variants, system_info.os, cuda_version);
variants, system_info.os, cuda_driver_version);
} else if (engineName_ == "cortex.onnx") {
matched_variant = engine_matcher_utils::ValidateOnnx(
variants, system_info.os, system_info.arch);
} else if (engineName_ == "cortex.llamacpp") {
auto suitable_avx = engine_matcher_utils::GetSuitableAvxVariant();
matched_variant = engine_matcher_utils::Validate(
variants, system_info.os, system_info.arch, suitable_avx,
cuda_version);
cuda_driver_version);
}
LOG_INFO << "Matched variant: " << matched_variant;
if (matched_variant.empty()) {
Expand Down Expand Up @@ -105,17 +109,46 @@ bool EngineInitCmd::Exec() const {
}}};

DownloadService download_service;
download_service.AddDownloadTask(downloadTask, [](const std::string&
absolute_path,
bool unused) {
download_service.AddDownloadTask(downloadTask, [this](
const std::string&
absolute_path,
bool unused) {
// try to unzip the downloaded file
std::filesystem::path downloadedEnginePath{absolute_path};
LOG_INFO << "Downloaded engine path: "
<< downloadedEnginePath.string();

archive_utils::ExtractArchive(
downloadedEnginePath.string(),
downloadedEnginePath.parent_path().parent_path().string());
std::filesystem::path extract_path =
downloadedEnginePath.parent_path().parent_path();

archive_utils::ExtractArchive(downloadedEnginePath.string(),
extract_path.string());
#if defined(_WIN32) || defined(__linux__)
// FIXME: hacky try to copy the file. Remove this when we are able to set the library path
auto engine_path = extract_path / engineName_;
LOG_INFO << "Source path: " << engine_path.string();
auto executable_path =
file_manager_utils::GetExecutableFolderContainerPath();
for (const auto& entry :
std::filesystem::recursive_directory_iterator(engine_path)) {
if (entry.is_regular_file() &&
entry.path().extension() != ".gz") {
std::filesystem::path relative_path =
std::filesystem::relative(entry.path(), engine_path);
std::filesystem::path destFile =
executable_path / relative_path;

std::filesystem::create_directories(destFile.parent_path());
std::filesystem::copy_file(
entry.path(), destFile,
std::filesystem::copy_options::overwrite_existing);

std::cout << "Copied: " << entry.path().filename().string()
<< " to " << destFile.string() << std::endl;
}
}
std::cout << "DLL copying completed successfully." << std::endl;
#endif

// remove the downloaded file
// TODO(any) Could not delete file on Windows because it is currently hold by httplib(?)
Expand All @@ -128,23 +161,47 @@ bool EngineInitCmd::Exec() const {
LOG_INFO << "Finished!";
});
if (system_info.os == "mac" || engineName_ == "cortex.onnx") {
return false;
// mac and onnx engine does not require cuda toolkit
return true;
}

// download cuda toolkit
const std::string jan_host = "https://catalog.jan.ai";
const std::string cuda_toolkit_file_name = "cuda.tar.gz";
const std::string download_id = "cuda";

auto gpu_driver_version = system_info_utils::GetDriverVersion();
// TODO: we don't have API to retrieve list of cuda toolkit dependencies atm because we hosting it at jan
// will have better logic after https://github.com/janhq/cortex/issues/1046 finished
// for now, assume that we have only 11.7 and 12.4
auto suitable_toolkit_version = "";
if (engineName_ == "cortex.tensorrt-llm") {
// for tensorrt-llm, we need to download cuda toolkit v12.4
suitable_toolkit_version = "12.4";
} else {
// llamacpp
auto cuda_driver_semver =
semantic_version_utils::SplitVersion(cuda_driver_version);
if (cuda_driver_semver.major == 11) {
suitable_toolkit_version = "11.7";
} else if (cuda_driver_semver.major == 12) {
suitable_toolkit_version = "12.4";
}
}

auto cuda_runtime_version =
cuda_toolkit_utils::GetCompatibleCudaToolkitVersion(
gpu_driver_version, system_info.os, engineName_);
// compare cuda driver version with cuda toolkit version
// cuda driver version should be greater than toolkit version to ensure compatibility
if (semantic_version_utils::CompareSemanticVersion(
cuda_driver_version, suitable_toolkit_version) < 0) {
LOG_ERROR << "Your Cuda driver version " << cuda_driver_version
<< " is not compatible with cuda toolkit version "
<< suitable_toolkit_version;
return false;
}

std::ostringstream cuda_toolkit_path;
cuda_toolkit_path << "dist/cuda-dependencies/" << 11.7 << "/"
<< system_info.os << "/"
<< cuda_toolkit_file_name;
cuda_toolkit_path << "dist/cuda-dependencies/"
<< cuda_driver_version << "/" << system_info.os
<< "/" << cuda_toolkit_file_name;

LOG_DEBUG << "Cuda toolkit download url: " << jan_host
<< cuda_toolkit_path.str();
Expand Down
20 changes: 15 additions & 5 deletions engine/utils/engine_matcher_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <trantor/utils/Logger.h>
#include <algorithm>
#include <iostream>
#include <iterator>
#include <regex>
#include <string>
Expand Down Expand Up @@ -93,9 +93,19 @@ inline std::string GetSuitableCudaVariant(
bestMatchMinor = variantMinor;
}
}
} else if (cuda_version.empty() && selectedVariant.empty()) {
// If no CUDA version is provided, select the variant without any CUDA in the name
selectedVariant = variant;
}
}

// If no CUDA version is provided, select the variant without any CUDA in the name
if (selectedVariant.empty()) {
LOG_WARN
<< "No suitable CUDA variant found, selecting a variant without CUDA";
for (const auto& variant : variants) {
if (variant.find("cuda") == std::string::npos) {
selectedVariant = variant;
LOG_INFO << "Found variant without CUDA: " << selectedVariant << "\n";
break;
}
}
}

Expand Down Expand Up @@ -177,4 +187,4 @@ inline std::string Validate(const std::vector<std::string>& variants,

return cuda_compatible;
}
} // namespace engine_matcher_utils
} // namespace engine_matcher_utils
67 changes: 47 additions & 20 deletions engine/utils/semantic_version_utils.h
Original file line number Diff line number Diff line change
@@ -1,34 +1,61 @@
#include <trantor/utils/Logger.h>
#include <sstream>
#include <vector>

namespace semantic_version_utils {
inline std::vector<int> SplitVersion(const std::string& version) {
std::vector<int> parts;
std::stringstream ss(version);
std::string part;
struct SemVer {
int major;
int minor;
int patch;
};

while (std::getline(ss, part, '.')) {
parts.push_back(std::stoi(part));
inline SemVer SplitVersion(const std::string& version) {
if (version.empty()) {
LOG_WARN << "Passed in version is empty!";
}
SemVer semVer = {0, 0, 0}; // default value
std::stringstream ss(version);
std::string part;

while (parts.size() < 3) {
parts.push_back(0);
int index = 0;
while (std::getline(ss, part, '.') && index < 3) {
int value = std::stoi(part);
switch (index) {
case 0:
semVer.major = value;
break;
case 1:
semVer.minor = value;
break;
case 2:
semVer.patch = value;
break;
}
++index;
}

return parts;
return semVer;
}

inline int CompareSemanticVersion(const std::string& version1,
const std::string& version2) {
std::vector<int> v1 = SplitVersion(version1);
std::vector<int> v2 = SplitVersion(version2);

for (size_t i = 0; i < 3; ++i) {
if (v1[i] < v2[i])
return -1;
if (v1[i] > v2[i])
return 1;
}
SemVer v1 = SplitVersion(version1);
SemVer v2 = SplitVersion(version2);

if (v1.major < v2.major)
return -1;
if (v1.major > v2.major)
return 1;

if (v1.minor < v2.minor)
return -1;
if (v1.minor > v2.minor)
return 1;

if (v1.patch < v2.patch)
return -1;
if (v1.patch > v2.patch)
return 1;

return 0;
}
} // namespace semantic_version_utils
} // namespace semantic_version_utils

0 comments on commit fa72355

Please sign in to comment.