diff --git a/CHANGELOG.md b/CHANGELOG.md index 56f676702c..35986c2181 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## New Features - PR #7: Migrating cuml comms -> raft comms_t +- PR #15: add exception based error handling macros ## Improvements - PR #13: Add RMM_INCLUDE and RMM_LIBRARY options to allow linking to non-conda RMM diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 1ba7552f9c..3528c148df 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -44,18 +44,42 @@ #include #include +#include -#define NCCL_CHECK(call) \ - do { \ - ncclResult_t status = call; \ - ASSERT(ncclSuccess == status, "ERROR: NCCL call='%s'. Reason:%s\n", #call, \ - ncclGetErrorString(status)); \ - } while (0) +namespace raft { + +/** + * @brief Exception thrown when a NCCL error is encountered. + */ +struct nccl_error : public raft::exception { + explicit nccl_error(char const *const message) : raft::exception(message) {} + explicit nccl_error(std::string const &message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for NCCL runtime API functions. + * + * Invokes a NCCL runtime API function call, if the call does not return ncclSuccess, throws an + * exception detailing the NCCL error that occurred + */ +#define NCCL_TRY(call) \ + do { \ + ncclResult_t const status = (call); \ + if (ncclSuccess != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "NCCL error encountered at: ", "call='%s', Reason=%d:%s", \ + #call, status, ncclGetErrorString(status)); \ + throw raft::nccl_error(msg); \ + } \ + } while (0) #define NCCL_CHECK_NO_THROW(call) \ do { \ ncclResult_t status = call; \ - if (status != ncclSuccess) { \ + if (ncclSuccess != status) { \ printf("NCCL call='%s' failed. 
Reason:%s\n", #call, \ ncclGetErrorString(status)); \ } \ @@ -65,8 +89,6 @@ namespace raft { namespace comms { static size_t get_datatype_size(const datatype_t datatype) { - size_t ret = -1; - switch (datatype) { case datatype_t::CHAR: return sizeof(char); @@ -85,7 +107,7 @@ static size_t get_datatype_size(const datatype_t datatype) { case datatype_t::FLOAT64: return sizeof(double); default: - throw "Unsupported"; + RAFT_FAIL("Unsupported datatype."); } } @@ -145,13 +167,13 @@ class std_comms : public comms_iface { const std::shared_ptr device_allocator, cudaStream_t stream) : nccl_comm_(nccl_comm), - ucp_worker_(ucp_worker), - ucp_eps_(eps), + stream_(stream), num_ranks_(num_ranks), rank_(rank), - device_allocator_(device_allocator), - stream_(stream), - next_request_id_(0) { + ucp_worker_(ucp_worker), + ucp_eps_(eps), + next_request_id_(0), + device_allocator_(device_allocator) { initialize(); }; @@ -165,10 +187,10 @@ class std_comms : public comms_iface { const std::shared_ptr device_allocator, cudaStream_t stream) : nccl_comm_(nccl_comm), + stream_(stream), num_ranks_(num_ranks), rank_(rank), - device_allocator_(device_allocator), - stream_(stream) { + device_allocator_(device_allocator) { initialize(); }; @@ -324,29 +346,28 @@ class std_comms : public comms_iface { void allreduce(const void *sendbuff, void *recvbuff, size_t count, datatype_t datatype, op_t op, cudaStream_t stream) const { - NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, count, - get_nccl_datatype(datatype), get_nccl_op(op), - nccl_comm_, stream)); + NCCL_TRY(ncclAllReduce(sendbuff, recvbuff, count, + get_nccl_datatype(datatype), get_nccl_op(op), + nccl_comm_, stream)); } void bcast(void *buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const { - NCCL_CHECK(ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), - root, nccl_comm_, stream)); + NCCL_TRY(ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), root, + nccl_comm_, stream)); } void reduce(const void 
*sendbuff, void *recvbuff, size_t count, datatype_t datatype, op_t op, int root, cudaStream_t stream) const { - NCCL_CHECK(ncclReduce(sendbuff, recvbuff, count, - get_nccl_datatype(datatype), get_nccl_op(op), root, - nccl_comm_, stream)); + NCCL_TRY(ncclReduce(sendbuff, recvbuff, count, get_nccl_datatype(datatype), + get_nccl_op(op), root, nccl_comm_, stream)); } void allgather(const void *sendbuff, void *recvbuff, size_t sendcount, datatype_t datatype, cudaStream_t stream) const { - NCCL_CHECK(ncclAllGather(sendbuff, recvbuff, sendcount, - get_nccl_datatype(datatype), nccl_comm_, stream)); + NCCL_TRY(ncclAllGather(sendbuff, recvbuff, sendcount, + get_nccl_datatype(datatype), nccl_comm_, stream)); } void allgatherv(const void *sendbuf, void *recvbuf, const size_t recvcounts[], @@ -356,7 +377,7 @@ class std_comms : public comms_iface { //Listing 1 on page 4. for (int root = 0; root < num_ranks_; ++root) { size_t dtype_size = get_datatype_size(datatype); - NCCL_CHECK(ncclBroadcast( + NCCL_TRY(ncclBroadcast( sendbuf, static_cast(recvbuf) + displs[root] * dtype_size, recvcounts[root], get_nccl_datatype(datatype), root, nccl_comm_, stream)); @@ -365,9 +386,9 @@ class std_comms : public comms_iface { void reducescatter(const void *sendbuff, void *recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const { - NCCL_CHECK(ncclReduceScatter(sendbuff, recvbuff, recvcount, - get_nccl_datatype(datatype), get_nccl_op(op), - nccl_comm_, stream)); + NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount, + get_nccl_datatype(datatype), get_nccl_op(op), + nccl_comm_, stream)); } status_t sync_stream(cudaStream_t stream) const { diff --git a/cpp/include/raft/cudart_utils.h b/cpp/include/raft/cudart_utils.h index 47e76ab916..2ca23ba539 100644 --- a/cpp/include/raft/cudart_utils.h +++ b/cpp/include/raft/cudart_utils.h @@ -16,96 +16,70 @@ #pragma once +#include + #include + #include -#include #include -#include -#include -#include -#include -#include + 
///@todo: enable once logging has been enabled in raft //#include "logger.hpp" namespace raft { -/** base exception class for the whole of raft */ -class exception : public std::exception { - public: - /** default ctor */ - explicit exception() noexcept : std::exception(), msg_() {} - - /** copy ctor */ - exception(const exception& src) noexcept - : std::exception(), msg_(src.what()) { - collect_call_stack(); - } - - /** ctor from an input message */ - explicit exception(const std::string _msg) noexcept - : std::exception(), msg_(std::move(_msg)) { - collect_call_stack(); - } - - /** get the message associated with this exception */ - const char* what() const noexcept override { return msg_.c_str(); } - - private: - /** message associated with this exception */ - std::string msg_; - - /** append call stack info to this exception's message for ease of debug */ - // Courtesy: https://www.gnu.org/software/libc/manual/html_node/Backtraces.html - void collect_call_stack() noexcept { -#ifdef __GNUC__ - constexpr int kMaxStackDepth = 64; - void* stack[kMaxStackDepth]; // NOLINT - auto depth = backtrace(stack, kMaxStackDepth); - std::ostringstream oss; - oss << std::endl << "Obtained " << depth << " stack frames" << std::endl; - char** strings = backtrace_symbols(stack, depth); - if (strings == nullptr) { - oss << "But no stack trace could be found!" << std::endl; - msg_ += oss.str(); - return; - } - ///@todo: support for demangling of C++ symbol names - for (int i = 0; i < depth; ++i) { - oss << "#" << i << " in " << strings[i] << std::endl; - } - free(strings); - msg_ += oss.str(); -#endif // __GNUC__ - } +/** + * @brief Exception thrown when a CUDA error is encountered. + */ +struct cuda_error : public raft::exception { + explicit cuda_error(char const* const message) : raft::exception(message) {} + explicit cuda_error(std::string const& message) : raft::exception(message) {} }; -/** macro to throw a runtime error */ -#define THROW(fmt, ...) 
\ - do { \ - std::string msg; \ - char errMsg[2048]; /* NOLINT */ \ - std::snprintf(errMsg, sizeof(errMsg), \ - "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \ - msg += errMsg; \ - std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__); \ - msg += errMsg; \ - throw raft::exception(msg); \ - } while (0) +} // namespace raft -/** macro to check for a conditional and assert on failure */ -#define ASSERT(check, fmt, ...) \ - do { \ - if (!(check)) THROW(fmt, ##__VA_ARGS__); \ +/** + * @brief Error checking macro for CUDA runtime API functions. + * + * Invokes a CUDA runtime API function call, if the call does not return + * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an + * exception detailing the CUDA error that occurred + * + */ +#define CUDA_TRY(call) \ + do { \ + cudaError_t const status = call; \ + if (status != cudaSuccess) { \ + cudaGetLastError(); \ + std::string msg{}; \ + SET_ERROR_MSG( \ + msg, "CUDA error encountered at: ", "call='%s', Reason=%s:%s", #call, \ + cudaGetErrorName(status), cudaGetErrorString(status)); \ + throw raft::cuda_error(msg); \ + } \ } while (0) -/** check for cuda runtime API errors and assert accordingly */ -#define CUDA_CHECK(call) \ - do { \ - cudaError_t status = call; \ - ASSERT(status == cudaSuccess, "FAIL: call='%s'. Reason:%s", #call, \ - cudaGetErrorString(status)); \ - } while (0) +/** + * @brief Debug macro to check for CUDA errors + * + * In a non-release build, this macro will synchronize the specified stream + * before error checking. In both release and non-release builds, this macro + * checks for any pending CUDA errors from previous calls. If an error is + * reported, an exception is thrown detailing the CUDA error that occurred. + * + * The intent of this macro is to provide a mechanism for synchronous and + * deterministic execution for debugging asynchronous CUDA execution. 
It should + * be used after any asynchronous CUDA call, e.g., cudaMemcpyAsync, or an + * asynchronous kernel launch. + */ +#ifndef NDEBUG +#define CHECK_CUDA(stream) CUDA_TRY(cudaStreamSynchronize(stream)); +#else +#define CHECK_CUDA(stream) CUDA_TRY(cudaPeekAtLastError()); +#endif + +/** FIXME: temporary alias for cuML compatibility */ +#define CUDA_CHECK(call) CUDA_TRY(call) ///@todo: enable this only after we have added logging support in raft // /** @@ -114,13 +88,15 @@ class exception : public std::exception { // */ #define CUDA_CHECK_NO_THROW(call) \ do { \ - cudaError_t status = call; \ - if (status != cudaSuccess) { \ + cudaError_t const status = call; \ + if (cudaSuccess != status) { \ printf("CUDA call='%s' at file=%s line=%d failed with %s\n", #call, \ __FILE__, __LINE__, cudaGetErrorString(status)); \ } \ } while (0) +namespace raft { + /** helper method to get max usable shared mem per block parameter */ inline int get_shared_memory_per_block() { int dev_id; @@ -211,4 +187,4 @@ void print_device_vector(const char* variable_name, const T* devMem, } /** @} */ -}; // namespace raft +} // namespace raft diff --git a/cpp/include/raft/error.hpp b/cpp/include/raft/error.hpp new file mode 100644 index 0000000000..0b001b01b2 --- /dev/null +++ b/cpp/include/raft/error.hpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace raft { + +/** base exception class for the whole of raft */ +class exception : public std::exception { + public: + /** default ctor */ + explicit exception() noexcept : std::exception(), msg_() {} + + /** copy ctor */ + exception(exception const& src) noexcept + : std::exception(), msg_(src.what()) { + collect_call_stack(); + } + + /** ctor from an input message */ + explicit exception(std::string const msg) noexcept + : std::exception(), msg_(std::move(msg)) { + collect_call_stack(); + } + + /** get the message associated with this exception */ + char const* what() const noexcept override { return msg_.c_str(); } + + private: + /** message associated with this exception */ + std::string msg_; + + /** append call stack info to this exception's message for ease of debug */ + // Courtesy: https://www.gnu.org/software/libc/manual/html_node/Backtraces.html + void collect_call_stack() noexcept { +#ifdef __GNUC__ + constexpr int kMaxStackDepth = 64; + void* stack[kMaxStackDepth]; // NOLINT + auto depth = backtrace(stack, kMaxStackDepth); + std::ostringstream oss; + oss << std::endl << "Obtained " << depth << " stack frames" << std::endl; + char** strings = backtrace_symbols(stack, depth); + if (strings == nullptr) { + oss << "But no stack trace could be found!" << std::endl; + msg_ += oss.str(); + return; + } + ///@todo: support for demangling of C++ symbol names + for (int i = 0; i < depth; ++i) { + oss << "#" << i << " in " << strings[i] << std::endl; + } + free(strings); + msg_ += oss.str(); +#endif // __GNUC__ + } +}; + +/** + * @brief Exception thrown when logical precondition is violated. + * + * This exception should not be thrown directly and is instead thrown by the + * RAFT_EXPECTS and RAFT_FAIL macros. 
+ * + */ +struct logic_error : public raft::exception { + explicit logic_error(char const* const message) : raft::exception(message) {} + explicit logic_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +// FIXME: Need to be replaced with RAFT_FAIL +/** macro to throw a runtime error */ +#define THROW(fmt, ...) \ + do { \ + std::string msg; \ + char errMsg[2048]; /* NOLINT */ \ + std::snprintf(errMsg, sizeof(errMsg), \ + "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \ + msg += errMsg; \ + std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__); \ + msg += errMsg; \ + throw raft::exception(msg); \ + } while (0) + +// FIXME: Need to be replaced with RAFT_EXPECTS +/** macro to check for a conditional and assert on failure */ +#define ASSERT(check, fmt, ...) \ + do { \ + if (!(check)) THROW(fmt, ##__VA_ARGS__); \ + } while (0) + +#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \ + do { \ + char err_msg[2048]; /* NOLINT */ \ + std::snprintf(err_msg, sizeof(err_msg), "%s", location_prefix); \ + msg += err_msg; \ + std::snprintf(err_msg, sizeof(err_msg), "file=%s line=%d: ", __FILE__, \ + __LINE__); \ + msg += err_msg; \ + std::snprintf(err_msg, sizeof(err_msg), fmt, ##__VA_ARGS__); \ + msg += err_msg; \ + } while (0) + +/** + * @brief Macro for checking (pre-)conditions that throws an exception when a condition is false + * + * @param[in] cond Expression that evaluates to true or false + * @param[in] fmt String literal description of the reason that cond is expected to be true with + * optional format tags + * @throw raft::logic_error if the condition evaluates to false. + */ +#define RAFT_EXPECTS(cond, fmt, ...) \ + do { \ + if (!(cond)) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, "RAFT failure at ", fmt, ##__VA_ARGS__); \ + throw raft::logic_error(msg); \ + } \ + } while (0) + +/** + * @brief Indicates that an erroneous code path has been taken. 
+ * + * @param[in] fmt String literal description of the reason that this code path is erroneous with + * optional format tags + * @throw always throws raft::logic_error + */ +#define RAFT_FAIL(fmt, ...) \ + do { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, "RAFT failure at ", fmt, ##__VA_ARGS__); \ + throw raft::logic_error(msg); \ + } while (0) diff --git a/cpp/include/raft/linalg/cublas_wrappers.h b/cpp/include/raft/linalg/cublas_wrappers.h index cd8a508a84..7e8a52196a 100644 --- a/cpp/include/raft/linalg/cublas_wrappers.h +++ b/cpp/include/raft/linalg/cublas_wrappers.h @@ -16,18 +16,32 @@ #pragma once +#include + #include ///@todo: enable this once we have logger enabled //#include -#include -#include -namespace raft { -namespace linalg { +#include #define _CUBLAS_ERR_TO_STR(err) \ case err: \ return #err + +namespace raft { + +/** + * @brief Exception thrown when a cuBLAS error is encountered. + */ +struct cublas_error : public raft::exception { + explicit cublas_error(char const *const message) : raft::exception(message) {} + explicit cublas_error(std::string const &message) + : raft::exception(message) {} +}; + +namespace linalg { +namespace detail { + inline const char *cublas_error_to_string(cublasStatus_t err) { switch (err) { _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_SUCCESS); @@ -44,27 +58,49 @@ inline const char *cublas_error_to_string(cublasStatus_t err) { return "CUBLAS_STATUS_UNKNOWN"; }; } + +} // namespace detail +} // namespace linalg +} // namespace raft + #undef _CUBLAS_ERR_TO_STR -/** check for cublas runtime API errors and assert accordingly */ -#define CUBLAS_CHECK(call) \ - do { \ - cublasStatus_t err = call; \ - ASSERT(err == CUBLAS_STATUS_SUCCESS, \ - "CUBLAS call='%s' got errorcode=%d err=%s", #call, err, \ - raft::linalg::cublas_error_to_string(err)); \ +/** + * @brief Error checking macro for cuBLAS runtime API functions. 
+ * + * Invokes a cuBLAS runtime API function call, if the call does not return + * CUBLAS_STATUS_SUCCESS, throws an exception detailing the cuBLAS error that occurred + */ +#define CUBLAS_TRY(call) \ + do { \ + cublasStatus_t const status = (call); \ + if (CUBLAS_STATUS_SUCCESS != status) { \ + std::string msg{}; \ + SET_ERROR_MSG( \ + msg, "cuBLAS error encountered at: ", "call='%s', Reason=%d:%s", \ + #call, status, raft::linalg::detail::cublas_error_to_string(status)); \ + throw raft::cublas_error(msg); \ + } \ } while (0) +/** FIXME: temporary alias for cuML compatibility */ +#define CUBLAS_CHECK(call) CUBLAS_TRY(call) + ///@todo: enable this once we have logging enabled -// /** check for cublas runtime API errors but do not assert */ -// #define CUBLAS_CHECK_NO_THROW(call) \ -// do { \ -// cublasStatus_t err = call; \ -// if (err != CUBLAS_STATUS_SUCCESS) { \ -// CUML_LOG_ERROR("CUBLAS call='%s' got errorcode=%d err=%s", #call, err, \ -// raft::linalg::cublas_error_to_string(err)); \ -// } \ -// } while (0) +#if 0 +/** check for cublas runtime API errors but do not assert */ +#define CUBLAS_CHECK_NO_THROW(call) \ + do { \ + cublasStatus_t err = call; \ + if (err != CUBLAS_STATUS_SUCCESS) { \ + CUML_LOG_ERROR("CUBLAS call='%s' got errorcode=%d err=%s", #call, err, \ + raft::linalg::detail::cublas_error_to_string(err)); \ + } \ + } while (0) +#endif + +namespace raft { +namespace linalg { /** * @defgroup Axpy cublas ax+y operations @@ -542,5 +578,5 @@ inline cublasStatus_t cublasdot(cublasHandle_t handle, int n, const double *x, } /** @} */ -}; // namespace linalg -}; // namespace raft +} // namespace linalg +} // namespace raft diff --git a/cpp/include/raft/linalg/cusolver_wrappers.h b/cpp/include/raft/linalg/cusolver_wrappers.h index 92ba1a2194..a65042a2fd 100644 --- a/cpp/include/raft/linalg/cusolver_wrappers.h +++ b/cpp/include/raft/linalg/cusolver_wrappers.h @@ -22,12 +22,25 @@ //#include #include -namespace raft { -namespace linalg { - #define 
_CUSOLVER_ERR_TO_STR(err) \ case err: \ return #err; + +namespace raft { + +/** + * @brief Exception thrown when a cuSOLVER error is encountered. + */ +struct cusolver_error : public raft::exception { + explicit cusolver_error(char const *const message) + : raft::exception(message) {} + explicit cusolver_error(std::string const &message) + : raft::exception(message) {} +}; + +namespace linalg { +namespace detail { + inline const char *cusolver_error_to_string(cusolverStatus_t err) { switch (err) { _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_SUCCESS); @@ -44,27 +57,49 @@ inline const char *cusolver_error_to_string(cusolverStatus_t err) { return "CUSOLVER_STATUS_UNKNOWN"; }; } + +} // namespace detail +} // namespace linalg +} // namespace raft + #undef _CUSOLVER_ERR_TO_STR -/** check for cusolver runtime API errors and assert accordingly */ -#define CUSOLVER_CHECK(call) \ - do { \ - cusolverStatus_t err = call; \ - ASSERT(err == CUSOLVER_STATUS_SUCCESS, \ - "CUSOLVER call='%s' got errorcode=%d err=%s", #call, err, \ - raft::linalg::cusolver_error_to_string(err)); \ +/** + * @brief Error checking macro for cuSOLVER runtime API functions. 
+ * + * Invokes a cuSOLVER runtime API function call, if the call does not return + * CUSOLVER_STATUS_SUCCESS, throws an exception detailing the cuSOLVER error that occurred + */ +#define CUSOLVER_TRY(call) \ + do { \ + cusolverStatus_t const status = (call); \ + if (CUSOLVER_STATUS_SUCCESS != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, "cuSOLVER error encountered at: ", \ + "call='%s', Reason=%d:%s", #call, status, \ + raft::linalg::detail::cusolver_error_to_string(status)); \ + throw raft::cusolver_error(msg); \ + } \ } while (0) -///@todo: enable this once logging is enabled -// /** check for cusolver runtime API errors but do not assert */ -// #define CUSOLVER_CHECK_NO_THROW(call) \ -// do { \ -// cusolverStatus_t err = call; \ -// if (err != CUSOLVER_STATUS_SUCCESS) { \ -// CUML_LOG_ERROR("CUSOLVER call='%s' got errorcode=%d err=%s", #call, err, \ -// raft::linalg::cusolver_error_to_string(err)); \ -// } \ -// } while (0) +/** FIXME: temporary alias for cuML compatibility */ +#define CUSOLVER_CHECK(call) CUSOLVER_TRY(call) + +//@todo: enable this once logging is enabled +#if 0 +/** check for cusolver runtime API errors but do not assert */ +#define CUSOLVER_CHECK_NO_THROW(call) \ + do { \ + cusolverStatus_t err = call; \ + if (err != CUSOLVER_STATUS_SUCCESS) { \ + CUML_LOG_ERROR("CUSOLVER call='%s' got errorcode=%d err=%s", #call, err, \ + raft::linalg::detail::cusolver_error_to_string(err)); \ + } \ + } while (0) +#endif + +namespace raft { +namespace linalg { /** * @defgroup Getrf cusolver getrf operations @@ -683,5 +718,5 @@ inline cusolverStatus_t cusolverSpcsrqrsvBatched( // NOLINT } /** @} */ -}; // namespace linalg -}; // namespace raft +} // namespace linalg +} // namespace raft diff --git a/cpp/include/raft/sparse/cusparse_wrappers.h b/cpp/include/raft/sparse/cusparse_wrappers.h index 1c63d2348b..9de242ea10 100644 --- a/cpp/include/raft/sparse/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/cusparse_wrappers.h @@ -16,17 +16,31 @@ #pragma 
once +#include + #include ///@todo: enable this once logging is enabled //#include -#include - -namespace raft { -namespace sparse { #define _CUSPARSE_ERR_TO_STR(err) \ case err: \ return #err; + +namespace raft { + +/** + * @brief Exception thrown when a cuSparse error is encountered. + */ +struct cusparse_error : public raft::exception { + explicit cusparse_error(char const* const message) + : raft::exception(message) {} + explicit cusparse_error(std::string const& message) + : raft::exception(message) {} +}; + +namespace sparse { +namespace detail { + inline const char* cusparse_error_to_string(cusparseStatus_t err) { #if defined(CUDART_VERSION) && CUDART_VERSION >= 10100 return cusparseGetErrorString(status); @@ -45,27 +59,49 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) { }; #endif // CUDART_VERSION } + +} // namespace detail +} // namespace sparse +} // namespace raft + #undef _CUSPARSE_ERR_TO_STR -/** check for cusparse runtime API errors and assert accordingly */ -#define CUSPARSE_CHECK(call) \ - do { \ - cusparseStatus_t err = call; \ - ASSERT(err == CUSPARSE_STATUS_SUCCESS, \ - "CUSPARSE call='%s' got errorcode=%d err=%s", #call, err, \ - raft::sparse::cusparse_error_to_string(err)); \ +/** + * @brief Error checking macro for cuSparse runtime API functions. 
+ * + * Invokes a cuSparse runtime API function call, if the call does not return + * CUSPARSE_STATUS_SUCCESS, throws an exception detailing the cuSparse error that occurred + */ +#define CUSPARSE_TRY(call) \ + do { \ + cusparseStatus_t const status = (call); \ + if (CUSPARSE_STATUS_SUCCESS != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, "cuSparse error encountered at: ", \ + "call='%s', Reason=%d:%s", #call, status, \ + raft::sparse::detail::cusparse_error_to_string(status)); \ + throw raft::cusparse_error(msg); \ + } \ } while (0) -///@todo: enable this once logging is enabled -// /** check for cusparse runtime API errors but do not assert */ -// #define CUSPARSE_CHECK_NO_THROW(call) \ -// do { \ -// cusparseStatus_t err = call; \ -// if (err != CUSPARSE_STATUS_SUCCESS) { \ -// CUML_LOG_ERROR("CUSPARSE call='%s' got errorcode=%d err=%s", #call, err, \ -// raft::sparse::cusparse_error_to_string(err)); \ -// } \ -// } while (0) +/** FIXME: temporary alias for cuML compatibility */ +#define CUSPARSE_CHECK(call) CUSPARSE_TRY(call) + +//@todo: enable this once logging is enabled +#if 0 +/** check for cusparse runtime API errors but do not assert */ +#define CUSPARSE_CHECK_NO_THROW(call) \ + do { \ + cusparseStatus_t err = call; \ + if (err != CUSPARSE_STATUS_SUCCESS) { \ + CUML_LOG_ERROR("CUSPARSE call='%s' got errorcode=%d err=%s", #call, err, \ + raft::sparse::detail::cusparse_error_to_string(err)); \ + } \ + } while (0) +#endif + +namespace raft { +namespace sparse { /** * @defgroup gthr cusparse gather methods @@ -162,5 +198,5 @@ inline cusparseStatus_t cusparsegemmi( } /** @} */ -}; // namespace sparse -}; // namespace raft +} // namespace sparse +} // namespace raft