From 3c4c1a9b8b8038cf69694a29c6dde05d89ed421e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 22 Feb 2022 11:28:02 -0500 Subject: [PATCH] Separating cublas/cusolver macros from wrappers --- cpp/include/raft/linalg/cublas_macros.h | 116 ++++++++++++++++++ cpp/include/raft/linalg/cusolver_macros.h | 112 +++++++++++++++++ .../raft/linalg/detail/cublas_wrappers.hpp | 96 +-------------- .../raft/linalg/detail/cusolver_wrappers.hpp | 91 +------------- 4 files changed, 230 insertions(+), 185 deletions(-) create mode 100644 cpp/include/raft/linalg/cublas_macros.h create mode 100644 cpp/include/raft/linalg/cusolver_macros.h diff --git a/cpp/include/raft/linalg/cublas_macros.h b/cpp/include/raft/linalg/cublas_macros.h new file mode 100644 index 0000000000..3979802f2c --- /dev/null +++ b/cpp/include/raft/linalg/cublas_macros.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +///@todo: enable this once we have logger enabled +//#include + +#include + +#define _CUBLAS_ERR_TO_STR(err) \ + case err: return #err + +namespace raft { + +/** + * @brief Exception thrown when a cuBLAS error is encountered. + */ + struct cublas_error : public raft::exception { + explicit cublas_error(char const* const message) : raft::exception(message) {} + explicit cublas_error(std::string const& message) : raft::exception(message) {} + }; + + namespace linalg { + namespace detail { + + inline const char* cublas_error_to_string(cublasStatus_t err) + { + switch (err) { + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_SUCCESS); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ALLOC_FAILED); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INVALID_VALUE); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_MAPPING_ERROR); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED); + _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_LICENSE_ERROR); + default: return "CUBLAS_STATUS_UNKNOWN"; + }; + } + + } // namespace detail + } // namespace linalg +} // namespace raft + +#undef _CUBLAS_ERR_TO_STR + +/** + * @brief Error checking macro for cuBLAS runtime API functions. + * + * Invokes a cuBLAS runtime API function call, if the call does not return + * CUBLAS_STATUS_SUCCESS, throws an exception detailing the cuBLAS error that occurred + */ +#define RAFT_CUBLAS_TRY(call) \ + do { \ + cublasStatus_t const status = (call); \ + if (CUBLAS_STATUS_SUCCESS != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "cuBLAS error encountered at: ", \ + "call='%s', Reason=%d:%s", \ + #call, \ + status, \ + raft::linalg::detail::cublas_error_to_string(status)); \ + throw raft::cublas_error(msg); \ + } \ + } while (0) + +// FIXME: Remove after consumers rename +#ifndef CUBLAS_TRY +#define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call) +#endif + +// /** +// * @brief check for cuda runtime API errors but log error instead of raising +// * exception. +// */ +#define RAFT_CUBLAS_TRY_NO_THROW(call) \ + do { \ + cublasStatus_t const status = call; \ + if (CUBLAS_STATUS_SUCCESS != status) { \ + printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \ + #call, \ + __FILE__, \ + __LINE__, \ + raft::linalg::detail::cublas_error_to_string(status)); \ + } \ + } while (0) + +/** FIXME: remove after cuml rename */ +#ifndef CUBLAS_CHECK +#define CUBLAS_CHECK(call) CUBLAS_TRY(call) +#endif + +/** FIXME: remove after cuml rename */ +#ifndef CUBLAS_CHECK_NO_THROW +#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call) +#endif diff --git a/cpp/include/raft/linalg/cusolver_macros.h b/cpp/include/raft/linalg/cusolver_macros.h new file mode 100644 index 0000000000..4899a67eea --- /dev/null +++ b/cpp/include/raft/linalg/cusolver_macros.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +///@todo: enable this once logging is enabled +//#include +#include +#include + +#define _CUSOLVER_ERR_TO_STR(err) \ + case err: return #err; + +namespace raft { + +/** + * @brief Exception thrown when a cuSOLVER error is encountered. + */ + struct cusolver_error : public raft::exception { + explicit cusolver_error(char const* const message) : raft::exception(message) {} + explicit cusolver_error(std::string const& message) : raft::exception(message) {} + }; + + namespace linalg { + + inline const char* cusolver_error_to_string(cusolverStatus_t err) + { + switch (err) { + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_SUCCESS); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_INITIALIZED); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ALLOC_FAILED); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INVALID_VALUE); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ARCH_MISMATCH); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_EXECUTION_FAILED); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INTERNAL_ERROR); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ZERO_PIVOT); + _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_SUPPORTED); + default: return "CUSOLVER_STATUS_UNKNOWN"; + }; + } + + } // namespace linalg +} // namespace raft + +#undef _CUSOLVER_ERR_TO_STR + +/** + * @brief Error checking macro for cuSOLVER runtime API functions. + * + * Invokes a cuSOLVER runtime API function call, if the call does not return + * CUSolver_STATUS_SUCCESS, throws an exception detailing the cuSOLVER error that occurred + */ +#define RAFT_CUSOLVER_TRY(call) \ + do { \ + cusolverStatus_t const status = (call); \ + if (CUSOLVER_STATUS_SUCCESS != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "cuSOLVER error encountered at: ", \ + "call='%s', Reason=%d:%s", \ + #call, \ + status, \ + raft::linalg::detail::cusolver_error_to_string(status)); \ + throw raft::cusolver_error(msg); \ + } \ + } while (0) + +// FIXME: remove after consumer rename +#ifndef CUSOLVER_TRY +#define CUSOLVER_TRY(call) RAFT_CUSOLVER_TRY(call) +#endif + +// /** +// * @brief check for cuda runtime API errors but log error instead of raising +// * exception. +// */ +#define RAFT_CUSOLVER_TRY_NO_THROW(call) \ + do { \ + cusolverStatus_t const status = call; \ + if (CUSOLVER_STATUS_SUCCESS != status) { \ + printf("CUSOLVER call='%s' at file=%s line=%d failed with %s\n", \ + #call, \ + __FILE__, \ + __LINE__, \ + raft::linalg::detail::cusolver_error_to_string(status)); \ + } \ + } while (0) + +// FIXME: remove after cuml rename +#ifndef CUSOLVER_CHECK +#define CUSOLVER_CHECK(call) CUSOLVER_TRY(call) +#endif + +#ifndef CUSOLVER_CHECK_NO_THROW +#define CUSOLVER_CHECK_NO_THROW(call) CUSOLVER_TRY_NO_THROW(call) +#endif diff --git a/cpp/include/raft/linalg/detail/cublas_wrappers.hpp b/cpp/include/raft/linalg/detail/cublas_wrappers.hpp index 752235d246..a926246c08 100644 --- a/cpp/include/raft/linalg/detail/cublas_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cublas_wrappers.hpp @@ -17,103 +17,9 @@ #pragma once #include - +#include #include -///@todo: enable this once we have logger enabled -//#include - -#include - -#define _CUBLAS_ERR_TO_STR(err) \ - case err: return #err - -namespace raft { - -/** - * @brief Exception thrown when a cuBLAS error is encountered. - */ -struct cublas_error : public raft::exception { - explicit cublas_error(char const* const message) : raft::exception(message) {} - explicit cublas_error(std::string const& message) : raft::exception(message) {} -}; - -namespace linalg { -namespace detail { -inline const char* cublas_error_to_string(cublasStatus_t err) -{ - switch (err) { - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_SUCCESS); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ALLOC_FAILED); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INVALID_VALUE); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_MAPPING_ERROR); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED); - _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_LICENSE_ERROR); - default: return "CUBLAS_STATUS_UNKNOWN"; - }; -} - -} // namespace detail -} // namespace linalg -} // namespace raft - -#undef _CUBLAS_ERR_TO_STR - -/** - * @brief Error checking macro for cuBLAS runtime API functions. - * - * Invokes a cuBLAS runtime API function call, if the call does not return - * CUBLAS_STATUS_SUCCESS, throws an exception detailing the cuBLAS error that occurred - */ -#define RAFT_CUBLAS_TRY(call) \ - do { \ - cublasStatus_t const status = (call); \ - if (CUBLAS_STATUS_SUCCESS != status) { \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "cuBLAS error encountered at: ", \ - "call='%s', Reason=%d:%s", \ - #call, \ - status, \ - raft::linalg::detail::cublas_error_to_string(status)); \ - throw raft::cublas_error(msg); \ - } \ - } while (0) - -// FIXME: Remove after consumers rename -#ifndef CUBLAS_TRY -#define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call) -#endif - -// /** -// * @brief check for cuda runtime API errors but log error instead of raising -// * exception. -// */ -#define RAFT_CUBLAS_TRY_NO_THROW(call) \ - do { \ - cublasStatus_t const status = call; \ - if (CUBLAS_STATUS_SUCCESS != status) { \ - printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \ - #call, \ - __FILE__, \ - __LINE__, \ - raft::linalg::detail::cublas_error_to_string(status)); \ - } \ - } while (0) - -/** FIXME: remove after cuml rename */ -#ifndef CUBLAS_CHECK -#define CUBLAS_CHECK(call) CUBLAS_TRY(call) -#endif - -/** FIXME: remove after cuml rename */ -#ifndef CUBLAS_CHECK_NO_THROW -#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call) -#endif namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp index 34ec6cb673..c6e294de5f 100644 --- a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp @@ -18,101 +18,12 @@ #include #include +#include ///@todo: enable this once logging is enabled //#include #include #include -#define _CUSOLVER_ERR_TO_STR(err) \ - case err: return #err; - -namespace raft { - -/** - * @brief Exception thrown when a cuSOLVER error is encountered. - */ -struct cusolver_error : public raft::exception { - explicit cusolver_error(char const* const message) : raft::exception(message) {} - explicit cusolver_error(std::string const& message) : raft::exception(message) {} -}; - -namespace linalg { -namespace detail { - -inline const char* cusolver_error_to_string(cusolverStatus_t err) -{ - switch (err) { - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_SUCCESS); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_INITIALIZED); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ALLOC_FAILED); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INVALID_VALUE); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ARCH_MISMATCH); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_EXECUTION_FAILED); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INTERNAL_ERROR); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ZERO_PIVOT); - _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_SUPPORTED); - default: return "CUSOLVER_STATUS_UNKNOWN"; - }; -} - -} // namespace detail -} // namespace linalg -} // namespace raft - -#undef _CUSOLVER_ERR_TO_STR - -/** - * @brief Error checking macro for cuSOLVER runtime API functions. - * - * Invokes a cuSOLVER runtime API function call, if the call does not return - * CUSolver_STATUS_SUCCESS, throws an exception detailing the cuSOLVER error that occurred - */ -#define RAFT_CUSOLVER_TRY(call) \ - do { \ - cusolverStatus_t const status = (call); \ - if (CUSOLVER_STATUS_SUCCESS != status) { \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "cuSOLVER error encountered at: ", \ - "call='%s', Reason=%d:%s", \ - #call, \ - status, \ - raft::linalg::detail::cusolver_error_to_string(status)); \ - throw raft::cusolver_error(msg); \ - } \ - } while (0) - -// FIXME: remove after consumer rename -#ifndef CUSOLVER_TRY -#define CUSOLVER_TRY(call) RAFT_CUSOLVER_TRY(call) -#endif - -// /** -// * @brief check for cuda runtime API errors but log error instead of raising -// * exception. -// */ -#define RAFT_CUSOLVER_TRY_NO_THROW(call) \ - do { \ - cusolverStatus_t const status = call; \ - if (CUSOLVER_STATUS_SUCCESS != status) { \ - printf("CUSOLVER call='%s' at file=%s line=%d failed with %s\n", \ - #call, \ - __FILE__, \ - __LINE__, \ - raft::linalg::detail::cusolver_error_to_string(status)); \ - } \ - } while (0) - -// FIXME: remove after cuml rename -#ifndef CUSOLVER_CHECK -#define CUSOLVER_CHECK(call) CUSOLVER_TRY(call) -#endif - -#ifndef CUSOLVER_CHECK_NO_THROW -#define CUSOLVER_CHECK_NO_THROW(call) CUSOLVER_TRY_NO_THROW(call) -#endif - namespace raft { namespace linalg { namespace detail {