diff --git a/taichi/backends/cuda/cuda_context.cpp b/taichi/backends/cuda/cuda_context.cpp index a138c3d555395..f1cb1ba5bba3c 100644 --- a/taichi/backends/cuda/cuda_context.cpp +++ b/taichi/backends/cuda/cuda_context.cpp @@ -13,7 +13,10 @@ TLANG_NAMESPACE_BEGIN CUDAContext::CUDAContext() - : profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) { + : profiler_(nullptr), + driver_(CUDADriver::get_instance_without_context()), + cusparse_driver_(CUSPARSEDriver::get_instance()), + cusolver_driver_(CUSOLVERDriver::get_instance()) { // CUDA initialization dev_count_ = 0; driver_.init(0); diff --git a/taichi/backends/cuda/cuda_context.h b/taichi/backends/cuda/cuda_context.h index 69a02adf6f082..0834de1120d74 100644 --- a/taichi/backends/cuda/cuda_context.h +++ b/taichi/backends/cuda/cuda_context.h @@ -15,6 +15,8 @@ TLANG_NAMESPACE_BEGIN // cases such as unit testing where many Taichi programs are created/destroyed. class CUDADriver; +class CUSPARSEDriver; +class CUSOLVERDriver; class CUDAContext { private: @@ -26,6 +28,8 @@ class CUDAContext { std::mutex lock_; KernelProfilerBase *profiler_; CUDADriver &driver_; + CUSPARSEDriver &cusparse_driver_; + CUSOLVERDriver &cusolver_driver_; bool debug_; public: diff --git a/taichi/backends/cuda/cuda_driver.cpp b/taichi/backends/cuda/cuda_driver.cpp index e01e1c0fe5bf2..16370be1543ad 100644 --- a/taichi/backends/cuda/cuda_driver.cpp +++ b/taichi/backends/cuda/cuda_driver.cpp @@ -20,24 +20,7 @@ bool CUDADriver::detected() { } CUDADriver::CUDADriver() { - disabled_by_env_ = (get_environ_config("TI_ENABLE_CUDA", 1) == 0); - if (disabled_by_env_) { - TI_TRACE("CUDA driver disabled by enviroment variable \"TI_ENABLE_CUDA\"."); - return; - } - -#if defined(TI_PLATFORM_LINUX) - loader_ = std::make_unique("libcuda.so"); -#elif defined(TI_PLATFORM_WINDOWS) - loader_ = std::make_unique("nvcuda.dll"); -#else - static_assert(false, "Taichi CUDA driver supports only Windows and Linux."); -#endif - - if (!loader_->loaded()) { - TI_WARN("CUDA driver not found."); - return; - } + load_lib("libcuda.so", "nvcuda.dll"); loader_->load_function("cuGetErrorName", get_error_name); loader_->load_function("cuGetErrorString", get_error_string); @@ -78,4 +61,48 @@ CUDADriver &CUDADriver::get_instance() { return get_instance_without_context(); } +CUDADriverBase::CUDADriverBase() { + disabled_by_env_ = (get_environ_config("TI_ENABLE_CUDA", 1) == 0); + if (disabled_by_env_) { + TI_TRACE("CUDA driver disabled by enviroment variable \"TI_ENABLE_CUDA\"."); + return; + } +} + +void CUDADriverBase::load_lib(std::string lib_linux, std::string lib_windows) { +#if defined(TI_PLATFORM_LINUX) + auto lib_name = lib_linux; +#elif defined(TI_PLATFORM_WINDOWS) + auto lib_name = lib_windows; +#else + static_assert(false, "Taichi CUDA driver supports only Windows and Linux."); +#endif + + loader_ = std::make_unique(lib_name); + if (!loader_->loaded()) { + TI_WARN("{} lib not found.", lib_name); + return; + } else { + TI_TRACE("{} loaded!", lib_name); + } +} + +CUSPARSEDriver::CUSPARSEDriver() { + load_lib("libcusparse.so", "cusparse.dll"); +} + +CUSPARSEDriver &CUSPARSEDriver::get_instance() { + static CUSPARSEDriver *instance = new CUSPARSEDriver(); + return *instance; +} + +CUSOLVERDriver::CUSOLVERDriver() { + load_lib("libcusolver.so", "cusolver.dll"); +} + +CUSOLVERDriver &CUSOLVERDriver::get_instance() { + static CUSOLVERDriver *instance = new CUSOLVERDriver(); + return *instance; +} + TLANG_NAMESPACE_END diff --git a/taichi/backends/cuda/cuda_driver.h b/taichi/backends/cuda/cuda_driver.h index 40e3eff765c63..c0569d05402e2 100644 --- a/taichi/backends/cuda/cuda_driver.h +++ b/taichi/backends/cuda/cuda_driver.h @@ -95,7 +95,20 @@ class CUDADriverFunction { std::mutex *driver_lock_{nullptr}; }; -class CUDADriver { +class CUDADriverBase { + public: + ~CUDADriverBase() = default; + + protected: + std::unique_ptr loader_; + CUDADriverBase(); + + void load_lib(std::string lib_linux, std::string lib_windows); + + bool disabled_by_env_{false}; +}; + +class CUDADriver : protected CUDADriverBase { public: #define PER_CUDA_FUNCTION(name, symbol_name, ...) \ CUDADriverFunction<__VA_ARGS__> name; @@ -110,8 +123,6 @@ class CUDADriver { bool detected(); - ~CUDADriver() = default; - static CUDADriver &get_instance(); static CUDADriver &get_instance_without_context(); @@ -119,12 +130,27 @@ class CUDADriver { private: CUDADriver(); - std::unique_ptr loader_; - std::mutex lock_; - bool disabled_by_env_{false}; bool cuda_version_valid_{false}; }; +class CUSPARSEDriver : protected CUDADriverBase { + public: + // TODO: Add cusparse function APIs + static CUSPARSEDriver &get_instance(); + + private: + CUSPARSEDriver(); +}; + +class CUSOLVERDriver : protected CUDADriverBase { + public: + // TODO: Add cusolver function APIs + static CUSOLVERDriver &get_instance(); + + private: + CUSOLVERDriver(); +}; + TLANG_NAMESPACE_END