diff --git a/cpp/include/kvikio/shim/cufile.hpp b/cpp/include/kvikio/shim/cufile.hpp index a9976fb915..d0cd56fbcd 100644 --- a/cpp/include/kvikio/shim/cufile.hpp +++ b/cpp/include/kvikio/shim/cufile.hpp @@ -49,7 +49,7 @@ class cuFileAPI { private: cuFileAPI() { - void* lib = load_library("libcufile.so.1"); + void* lib = load_library({"libcufile.so.1", "libcufile.so.0", "libcufile.so"}); get_symbol(HandleRegister, lib, KVIKIO_STRINGIFY(cuFileHandleRegister)); get_symbol(HandleDeregister, lib, KVIKIO_STRINGIFY(cuFileHandleDeregister)); get_symbol(Read, lib, KVIKIO_STRINGIFY(cuFileRead)); diff --git a/cpp/include/kvikio/shim/utils.hpp b/cpp/include/kvikio/shim/utils.hpp index 4dda66ff21..2a500db20c 100644 --- a/cpp/include/kvikio/shim/utils.hpp +++ b/cpp/include/kvikio/shim/utils.hpp @@ -18,6 +18,8 @@ #include #include #include +#include +#include namespace kvikio { @@ -38,6 +40,26 @@ inline void* load_library(const char* name, int mode = RTLD_LAZY | RTLD_LOCAL | return ret; } +/** + * @brief Load shared library + * + * @param names Vector of names to try when loading shared library. + * @return The library handle. + */ +inline void* load_library(const std::vector& names, + int mode = RTLD_LAZY | RTLD_LOCAL | RTLD_NODELETE) +{ + std::stringstream ss; + for (const char* name : names) { + ss << name << " "; + try { + return load_library(name); + } catch (const std::runtime_error&) { + } + } + throw std::runtime_error("cannot open shared object file, tried: " + ss.str()); +} + /** * @brief Get symbol using `dlsym` *