Skip to content

Commit

Permalink
Add cusparseSpMV_preprocess to cusparse wrapper (#2384)
Browse files Browse the repository at this point in the history
Add cusparseSpMV_preprocess to cusparse wrapper

Authors:
  - Nicolas Blin (https://github.com/Kh4ster)

Approvers:
  - Micka (https://github.com/lowener)

URL: #2384
  • Loading branch information
Kh4ster authored Jul 24, 2024
1 parent ffceee2 commit c30fc23
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/core/cusparse_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
//
// (i.e., before including this header)
//
#define CUDA_VER_10_1_UP (CUDART_VERSION >= 10100)
#define CUDA_VER_10_1_UP (CUDART_VERSION >= 10010)
#define CUDA_VER_12_4_UP (CUDART_VERSION >= 12040)

namespace raft {

Expand All @@ -59,7 +60,7 @@ namespace detail {

inline const char* cusparse_error_to_string(cusparseStatus_t err)
{
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10100
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10010
return cusparseGetErrorString(err);
#else // CUDART_VERSION
switch (err) {
Expand Down
28 changes: 28 additions & 0 deletions cpp/include/raft/sparse/detail/cusparse_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,34 @@ inline cusparseStatus_t cusparsespmv(cusparseHandle_t handle,
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSpMV(handle, opA, alpha, matA, vecX, beta, vecY, CUDA_R_64F, alg, externalBuffer);
}
// cusparseSpMV_preprocess is only available starting CUDA 12.4
#if CUDA_VER_12_4_UP
template <
typename T,
typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>* = nullptr>
cusparseStatus_t cusparsespmv_preprocess(cusparseHandle_t handle,
cusparseOperation_t opA,
const T* alpha,
const cusparseSpMatDescr_t matA,
const cusparseDnVecDescr_t vecX,
const T* beta,
const cusparseDnVecDescr_t vecY,
cusparseSpMVAlg_t alg,
T* externalBuffer,
cudaStream_t stream)
{
auto constexpr float_type = []() constexpr {
if constexpr (std::is_same_v<T, float>) {
return CUDA_R_32F;
} else if constexpr (std::is_same_v<T, double>) {
return CUDA_R_64F;
}
}();
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSpMV_preprocess(
handle, opA, alpha, matA, vecX, beta, vecY, float_type, alg, externalBuffer);
}
#endif
/** @} */
#else
/**
Expand Down

0 comments on commit c30fc23

Please sign in to comment.