diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index 4b7605cc63..97b442afe3 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -104,7 +104,7 @@ class handle_t { { std::lock_guard _(mutex_); if (!cublas_initialized_) { - RAFT_CUBLAS_TRY(cublasCreate(&cublas_handle_)); + RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_handle_)); cublas_initialized_ = true; } return cublas_handle_; @@ -114,7 +114,7 @@ class handle_t { { std::lock_guard _(mutex_); if (!cusolver_dn_initialized_) { - RAFT_CUSOLVER_TRY(cusolverDnCreate(&cusolver_dn_handle_)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_dn_handle_)); cusolver_dn_initialized_ = true; } return cusolver_dn_handle_; @@ -124,7 +124,7 @@ class handle_t { { std::lock_guard _(mutex_); if (!cusolver_sp_initialized_) { - RAFT_CUSOLVER_TRY(cusolverSpCreate(&cusolver_sp_handle_)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_sp_handle_)); cusolver_sp_initialized_ = true; } return cusolver_sp_handle_; @@ -134,7 +134,7 @@ class handle_t { { std::lock_guard _(mutex_); if (!cusparse_initialized_) { - RAFT_CUSPARSE_TRY(cusparseCreate(&cusparse_handle_)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_handle_)); cusparse_initialized_ = true; } return cusparse_handle_; @@ -218,7 +218,7 @@ class handle_t { { std::lock_guard _(mutex_); if (!device_prop_initialized_) { - RAFT_CUDA_TRY(cudaGetDeviceProperties(&prop_, dev_id_)); + RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id_)); device_prop_initialized_ = true; } return prop_; @@ -253,11 +253,15 @@ class handle_t { void destroy_resources() { ///@todo: enable *_NO_THROW variants once we have enabled logging - if (cusparse_initialized_) { RAFT_CUSPARSE_TRY(cusparseDestroy(cusparse_handle_)); } - if (cusolver_dn_initialized_) { RAFT_CUSOLVER_TRY(cusolverDnDestroy(cusolver_dn_handle_)); } - if (cusolver_sp_initialized_) { RAFT_CUSOLVER_TRY(cusolverSpDestroy(cusolver_sp_handle_)); } - if (cublas_initialized_) { RAFT_CUBLAS_TRY(cublasDestroy(cublas_handle_)); } - RAFT_CUDA_TRY(cudaEventDestroy(event_)); + if (cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_handle_)); } + if (cusolver_dn_initialized_) { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_dn_handle_)); + } + if (cusolver_sp_initialized_) { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_sp_handle_)); + } + if (cublas_initialized_) { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_handle_)); } + RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_)); } }; // class handle_t diff --git a/cpp/include/raft/linalg/cublas_wrappers.h b/cpp/include/raft/linalg/cublas_wrappers.h index d125aa40dd..024ed4a0e2 100644 --- a/cpp/include/raft/linalg/cublas_wrappers.h +++ b/cpp/include/raft/linalg/cublas_wrappers.h @@ -89,8 +89,31 @@ inline const char* cublas_error_to_string(cublasStatus_t err) #define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call) #endif +// /** +// * @brief check for cuda runtime API errors but log error instead of raising +// * exception. +// */ +#define RAFT_CUBLAS_TRY_NO_THROW(call) \ + do { \ + cublasStatus_t const status = call; \ + if (CUBLAS_STATUS_SUCCESS != status) { \ + printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \ + #call, \ + __FILE__, \ + __LINE__, \ + raft::linalg::detail::cublas_error_to_string(status)); \ + } \ + } while (0) + /** FIXME: remove after cuml rename */ +#ifndef CUBLAS_CHECK #define CUBLAS_CHECK(call) CUBLAS_TRY(call) +#endif + +/** FIXME: remove after cuml rename */ +#ifndef CUBLAS_CHECK_NO_THROW +#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call) +#endif namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/cusolver_wrappers.h b/cpp/include/raft/linalg/cusolver_wrappers.h index 0c94804111..988e7512d5 100644 --- a/cpp/include/raft/linalg/cusolver_wrappers.h +++ b/cpp/include/raft/linalg/cusolver_wrappers.h @@ -88,11 +88,31 @@ inline const char* cusolver_error_to_string(cusolverStatus_t err) #define CUSOLVER_TRY(call) RAFT_CUSOLVER_TRY(call) #endif +// /** +// * @brief check for cuda runtime API errors but log error instead of raising +// * exception. +// */ +#define RAFT_CUSOLVER_TRY_NO_THROW(call) \ + do { \ + cusolverStatus_t const status = call; \ + if (CUSOLVER_STATUS_SUCCESS != status) { \ + printf("CUSOLVER call='%s' at file=%s line=%d failed with %s\n", \ + #call, \ + __FILE__, \ + __LINE__, \ + raft::linalg::detail::cusolver_error_to_string(status)); \ + } \ + } while (0) + // FIXME: remove after cuml rename #ifndef CUSOLVER_CHECK #define CUSOLVER_CHECK(call) CUSOLVER_TRY(call) #endif +#ifndef CUSOLVER_CHECK_NO_THROW +#define CUSOLVER_CHECK_NO_THROW(call) CUSOLVER_TRY_NO_THROW(call) +#endif + namespace raft { namespace linalg {