diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 041991521b..c8e4229203 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -17,47 +17,59 @@ #pragma once #include +#include #include -#include +#include namespace raft { namespace sparse { namespace detail { /** - * @defgroup gthr cusparse gather methods + * @defgroup gather cusparse gather methods * @{ */ -template -cusparseStatus_t cusparsegthr( - cusparseHandle_t handle, int nnz, const T* vals, T* vals_sorted, int* d_P, cudaStream_t stream); -template <> -inline cusparseStatus_t cusparsegthr(cusparseHandle_t handle, - int nnz, - const double* vals, - double* vals_sorted, - int* d_P, - cudaStream_t stream) +inline cusparseStatus_t cusparsegather(cusparseHandle_t handle, + cusparseDnVecDescr_t vecY, + cusparseSpVecDescr_t vecX, + cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseDgthr(handle, nnz, vals, vals_sorted, d_P, CUSPARSE_INDEX_BASE_ZERO); -#pragma GCC diagnostic pop + return cusparseGather(handle, vecY, vecX); } -template <> -inline cusparseStatus_t cusparsegthr(cusparseHandle_t handle, - int nnz, - const float* vals, - float* vals_sorted, - int* d_P, - cudaStream_t stream) + +template < + typename T, + typename std::enable_if_t || std::is_same_v>* = nullptr> +cusparseStatus_t cusparsegthr( + cusparseHandle_t handle, int nnz, const T* vals, T* vals_sorted, int* d_P, cudaStream_t stream) { + auto constexpr float_type = []() constexpr + { + if constexpr (std::is_same_v) { + return CUDA_R_32F; + } else if constexpr (std::is_same_v) { + return CUDA_R_64F; + } + } + (); CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseSgthr(handle, nnz, vals, vals_sorted, d_P, CUSPARSE_INDEX_BASE_ZERO); -#pragma GCC diagnostic pop + auto dense_vector_descr = cusparseDnVecDescr_t{}; + auto sparse_vector_descr = cusparseSpVecDescr_t{}; + CUSPARSE_CHECK(cusparseCreateDnVec( + &dense_vector_descr, nnz, static_cast(const_cast(vals)), float_type)); + CUSPARSE_CHECK(cusparseCreateSpVec(&sparse_vector_descr, + nnz, + nnz, + static_cast(d_P), + static_cast(vals_sorted), + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, + float_type)); + auto return_value = cusparseGather(handle, dense_vector_descr, sparse_vector_descr); + CUSPARSE_CHECK(cusparseDestroyDnVec(dense_vector_descr)); + CUSPARSE_CHECK(cusparseDestroySpVec(sparse_vector_descr)); + return return_value; } /** @} */ @@ -138,77 +150,6 @@ inline void cusparsecoosortByRow( // NOLINT } /** @} */ -/** - * @defgroup Gemmi cusparse gemmi operations - * @{ - */ -template -cusparseStatus_t cusparsegemmi( // NOLINT - cusparseHandle_t handle, - int m, - int n, - int k, - int nnz, - const T* alpha, - const T* A, - int lda, - const T* cscValB, - const int* cscColPtrB, - const int* cscRowIndB, - const T* beta, - T* C, - int ldc, - cudaStream_t stream); -template <> -inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle, - int m, - int n, - int k, - int nnz, - const float* alpha, - const float* A, - int lda, - const float* cscValB, - const int* cscColPtrB, - const int* cscRowIndB, - const float* beta, - float* C, - int ldc, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseSgemmi( - handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc); -#pragma GCC diagnostic pop -} -template <> -inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle, - int m, - int n, - int k, - int nnz, - const double* alpha, - const double* A, - int lda, - const double* cscValB, - const int* cscColPtrB, - const int* cscRowIndB, - const double* beta, - double* C, - int ldc, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseDgemmi( - handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc); -#pragma GCC diagnostic pop -} -/** @} */ - #if not defined CUDA_ENFORCE_LOWER and CUDA_VER_10_1_UP /** * @defgroup cusparse Create CSR operations @@ -593,8 +534,17 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - return cusparseSpMM( - handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, externalBuffer); + return cusparseSpMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_32F, + alg, + static_cast(externalBuffer)); } template <> inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, @@ -610,8 +560,17 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - return cusparseSpMM( - handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, externalBuffer); + return cusparseSpMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_64F, + alg, + static_cast(externalBuffer)); } /** @} */ #else @@ -687,6 +646,73 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle, /** @} */ #endif +/** + * @defgroup Gemmi cusparse gemmi operations + * @{ + */ +template +cusparseStatus_t cusparsegemmi( // NOLINT + cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const T* alpha, + const T* A, + int lda, + const T* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const T* beta, + T* C, + int ldc, + cudaStream_t stream) +{ + static_assert(std::is_same_v || std::is_same_v, "Unsupported data type"); + + cusparseDnMatDescr_t matA; + cusparseSpMatDescr_t matB; + cusparseDnMatDescr_t matC; + + auto math_type = std::is_same_v ? CUDA_R_32F : CUDA_R_64F; + // Create sparse matrix B + CUSPARSE_CHECK(cusparseCreateCsc(&matB, + k, + n, + nnz, + static_cast(const_cast(cscColPtrB)), + static_cast(const_cast(cscRowIndB)), + static_cast(const_cast(cscValB)), + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, + math_type)); + // Create dense matrices + CUSPARSE_CHECK(cusparseCreateDnMat( + &matA, m, k, lda, static_cast(const_cast(A)), math_type, CUSPARSE_ORDER_ROW)); + CUSPARSE_CHECK(cusparseCreateDnMat( + &matC, m, n, ldc, static_cast(const_cast(C)), math_type, CUSPARSE_ORDER_ROW)); + + cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE; + cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG2; + size_t buffer_size = 0; + + CUSPARSE_CHECK(cusparsespmm_bufferSize( + handle, opA, opB, alpha, matB, matA, beta, matC, alg, &buffer_size, stream)); + buffer_size = buffer_size / sizeof(T); + rmm::device_uvector external_buffer(buffer_size, stream); + auto return_value = cusparsespmm( + handle, opA, opB, alpha, matB, matA, beta, matC, alg, external_buffer.data(), stream); + + // destroy matrix/vector descriptors + CUSPARSE_CHECK(cusparseDestroyDnMat(matA)); + CUSPARSE_CHECK(cusparseDestroySpMat(matB)); + CUSPARSE_CHECK(cusparseDestroyDnMat(matC)); + return return_value; +} +/** @} */ + /** * @defgroup csr2coo cusparse CSR to COO converter methods * @{ @@ -733,336 +759,6 @@ inline cusparseStatus_t cusparsesetpointermode(cusparseHandle_t handle, } /** @} */ -/** - * @defgroup CsrmvEx cusparse csrmvex operations - * @{ - */ -template -cusparseStatus_t cusparsecsrmvex_bufferSize(cusparseHandle_t handle, - cusparseAlgMode_t alg, - cusparseOperation_t transA, - int m, - int n, - int nnz, - const T* alpha, - const cusparseMatDescr_t descrA, - const T* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const T* x, - const T* beta, - T* y, - size_t* bufferSizeInBytes, - cudaStream_t stream); -template <> -inline cusparseStatus_t cusparsecsrmvex_bufferSize(cusparseHandle_t handle, - cusparseAlgMode_t alg, - cusparseOperation_t transA, - int m, - int n, - int nnz, - const float* alpha, - const cusparseMatDescr_t descrA, - const float* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const float* x, - const float* beta, - float* y, - size_t* bufferSizeInBytes, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - -#if CUDART_VERSION >= 11020 - cusparseSpMatDescr_t matA; - cusparsecreatecsr(&matA, - m, - n, - nnz, - const_cast(csrRowPtrA), - const_cast(csrColIndA), - const_cast(csrValA)); - - cusparseDnVecDescr_t vecX; - cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); - - cusparseDnVecDescr_t vecY; - cusparsecreatednvec(&vecY, static_cast(n), y); - - cusparseStatus_t result = cusparseSpMV_bufferSize(handle, - transA, - alpha, - matA, - vecX, - beta, - vecY, - CUDA_R_32F, - CUSPARSE_SPMV_ALG_DEFAULT, - bufferSizeInBytes); - - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); - return result; - -#else - - return cusparseCsrmvEx_bufferSize(handle, - alg, - transA, - m, - n, - nnz, - alpha, - CUDA_R_32F, - descrA, - csrValA, - CUDA_R_32F, - csrRowPtrA, - csrColIndA, - x, - CUDA_R_32F, - beta, - CUDA_R_32F, - y, - CUDA_R_32F, - CUDA_R_32F, - bufferSizeInBytes); -#endif -} -template <> -inline cusparseStatus_t cusparsecsrmvex_bufferSize(cusparseHandle_t handle, - cusparseAlgMode_t alg, - cusparseOperation_t transA, - int m, - int n, - int nnz, - const double* alpha, - const cusparseMatDescr_t descrA, - const double* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const double* x, - const double* beta, - double* y, - size_t* bufferSizeInBytes, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - -#if CUDART_VERSION >= 11020 - cusparseSpMatDescr_t matA; - cusparsecreatecsr(&matA, - m, - n, - nnz, - const_cast(csrRowPtrA), - const_cast(csrColIndA), - const_cast(csrValA)); - - cusparseDnVecDescr_t vecX; - cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); - - cusparseDnVecDescr_t vecY; - cusparsecreatednvec(&vecY, static_cast(n), y); - - cusparseStatus_t result = cusparseSpMV_bufferSize(handle, - transA, - alpha, - matA, - vecX, - beta, - vecY, - CUDA_R_64F, - CUSPARSE_SPMV_ALG_DEFAULT, - bufferSizeInBytes); - - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); - return result; -#else - return cusparseCsrmvEx_bufferSize(handle, - alg, - transA, - m, - n, - nnz, - alpha, - CUDA_R_64F, - descrA, - csrValA, - CUDA_R_64F, - csrRowPtrA, - csrColIndA, - x, - CUDA_R_64F, - beta, - CUDA_R_64F, - y, - CUDA_R_64F, - CUDA_R_64F, - bufferSizeInBytes); -#endif -} - -template -cusparseStatus_t cusparsecsrmvex(cusparseHandle_t handle, - cusparseAlgMode_t alg, - cusparseOperation_t transA, - int m, - int n, - int nnz, - const T* alpha, - const cusparseMatDescr_t descrA, - const T* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const T* x, - const T* beta, - T* y, - T* buffer, - cudaStream_t stream); -template <> -inline cusparseStatus_t cusparsecsrmvex(cusparseHandle_t handle, - cusparseAlgMode_t alg, - cusparseOperation_t transA, - int m, - int n, - int nnz, - const float* alpha, - const cusparseMatDescr_t descrA, - const float* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const float* x, - const float* beta, - float* y, - float* buffer, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - -#if CUDART_VERSION >= 11020 - cusparseSpMatDescr_t matA; - cusparsecreatecsr(&matA, - m, - n, - nnz, - const_cast(csrRowPtrA), - const_cast(csrColIndA), - const_cast(csrValA)); - - cusparseDnVecDescr_t vecX; - cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); - - cusparseDnVecDescr_t vecY; - cusparsecreatednvec(&vecY, static_cast(n), y); - - cusparseStatus_t result = cusparseSpMV( - handle, transA, alpha, matA, vecX, beta, vecY, CUDA_R_32F, CUSPARSE_SPMV_ALG_DEFAULT, buffer); - - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); - return result; -#else - return cusparseCsrmvEx(handle, - alg, - transA, - m, - n, - nnz, - alpha, - CUDA_R_32F, - descrA, - csrValA, - CUDA_R_32F, - csrRowPtrA, - csrColIndA, - x, - CUDA_R_32F, - beta, - CUDA_R_32F, - y, - CUDA_R_32F, - CUDA_R_32F, - buffer); -#endif -} -template <> -inline cusparseStatus_t cusparsecsrmvex(cusparseHandle_t handle, - cusparseAlgMode_t alg, - cusparseOperation_t transA, - int m, - int n, - int nnz, - const double* alpha, - const cusparseMatDescr_t descrA, - const double* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const double* x, - const double* beta, - double* y, - double* buffer, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - -#if CUDART_VERSION >= 11020 - cusparseSpMatDescr_t matA; - cusparsecreatecsr(&matA, - m, - n, - nnz, - const_cast(csrRowPtrA), - const_cast(csrColIndA), - const_cast(csrValA)); - - cusparseDnVecDescr_t vecX; - cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); - - cusparseDnVecDescr_t vecY; - cusparsecreatednvec(&vecY, static_cast(n), y); - - cusparseStatus_t result = cusparseSpMV( - handle, transA, alpha, matA, vecX, beta, vecY, CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, buffer); - - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); - return result; - -#else - - return cusparseCsrmvEx(handle, - alg, - transA, - m, - n, - nnz, - alpha, - CUDA_R_64F, - descrA, - csrValA, - CUDA_R_64F, - csrRowPtrA, - csrColIndA, - x, - CUDA_R_64F, - beta, - CUDA_R_64F, - y, - CUDA_R_64F, - CUDA_R_64F, - buffer); -#endif -} - -/** @} */ - /** * @defgroup Csr2cscEx2 cusparse csr->csc conversion * @{ @@ -1247,340 +943,6 @@ inline cusparseStatus_t cusparsecsr2csc(cusparseHandle_t handle, /** @} */ -/** - * @defgroup csrgemm2 cusparse sparse gemm operations - * @{ - */ - -template -cusparseStatus_t cusparsecsrgemm2_buffersizeext(cusparseHandle_t handle, - int m, - int n, - int k, - const T* alpha, - const T* beta, - const cusparseMatDescr_t matA, - int nnzA, - const int* rowindA, - const int* indicesA, - const cusparseMatDescr_t matB, - int nnzB, - const int* rowindB, - const int* indicesB, - const cusparseMatDescr_t matD, - int nnzD, - const int* rowindD, - const int* indicesD, - csrgemm2Info_t info, - size_t* pBufferSizeInBytes, - cudaStream_t stream); - -template <> -inline cusparseStatus_t cusparsecsrgemm2_buffersizeext(cusparseHandle_t handle, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const cusparseMatDescr_t matA, - int nnzA, - const int* rowindA, - const int* indicesA, - const cusparseMatDescr_t matB, - int nnzB, - const int* rowindB, - const int* indicesB, - const cusparseMatDescr_t matD, - int nnzD, - const int* rowindD, - const int* indicesD, - csrgemm2Info_t info, - size_t* pBufferSizeInBytes, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseScsrgemm2_bufferSizeExt(handle, - m, - n, - k, - alpha, - matA, - nnzA, - rowindA, - indicesA, - matB, - nnzB, - rowindB, - indicesB, - beta, - matD, - nnzD, - rowindD, - indicesD, - info, - pBufferSizeInBytes); -#pragma GCC diagnostic pop -} - -template <> -inline cusparseStatus_t cusparsecsrgemm2_buffersizeext(cusparseHandle_t handle, - int m, - int n, - int k, - const double* alpha, - const double* beta, - const cusparseMatDescr_t matA, - int nnzA, - const int* rowindA, - const int* indicesA, - const cusparseMatDescr_t matB, - int nnzB, - const int* rowindB, - const int* indicesB, - const cusparseMatDescr_t matD, - int nnzD, - const int* rowindD, - const int* indicesD, - csrgemm2Info_t info, - size_t* pBufferSizeInBytes, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseDcsrgemm2_bufferSizeExt(handle, - m, - n, - k, - alpha, - matA, - nnzA, - rowindA, - indicesA, - matB, - nnzB, - rowindB, - indicesB, - beta, - matD, - nnzD, - rowindD, - indicesD, - info, - pBufferSizeInBytes); -#pragma GCC diagnostic pop -} - -inline cusparseStatus_t cusparsecsrgemm2nnz(cusparseHandle_t handle, - int m, - int n, - int k, - const cusparseMatDescr_t matA, - int nnzA, - const int* rowindA, - const int* indicesA, - const cusparseMatDescr_t matB, - int nnzB, - const int* rowindB, - const int* indicesB, - const cusparseMatDescr_t matD, - int nnzD, - const int* rowindD, - const int* indicesD, - const cusparseMatDescr_t matC, - int* rowindC, - int* nnzC, - const csrgemm2Info_t info, - void* pBuffer, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseXcsrgemm2Nnz(handle, - m, - n, - k, - matA, - nnzA, - rowindA, - indicesA, - matB, - nnzB, - rowindB, - indicesB, - matD, - nnzD, - rowindD, - indicesD, - matC, - rowindC, - nnzC, - info, - pBuffer); -#pragma GCC diagnostic pop -} - -template -cusparseStatus_t cusparsecsrgemm2(cusparseHandle_t handle, - int m, - int n, - int k, - const T* alpha, - const cusparseMatDescr_t descrA, - int nnzA, - const T* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const cusparseMatDescr_t descrB, - int nnzB, - const T* csrValB, - const int* csrRowPtrB, - const int* csrColIndB, - const T* beta, - const cusparseMatDescr_t descrD, - int nnzD, - const T* csrValD, - const int* csrRowPtrD, - const int* csrColIndD, - const cusparseMatDescr_t descrC, - T* csrValC, - const int* csrRowPtrC, - int* csrColIndC, - const csrgemm2Info_t info, - void* pBuffer, - cudaStream_t stream); - -template <> -inline cusparseStatus_t cusparsecsrgemm2(cusparseHandle_t handle, - int m, - int n, - int k, - const float* alpha, - const cusparseMatDescr_t descrA, - int nnzA, - const float* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const cusparseMatDescr_t descrB, - int nnzB, - const float* csrValB, - const int* csrRowPtrB, - const int* csrColIndB, - const float* beta, - const cusparseMatDescr_t descrD, - int nnzD, - const float* csrValD, - const int* csrRowPtrD, - const int* csrColIndD, - const cusparseMatDescr_t descrC, - float* csrValC, - const int* csrRowPtrC, - int* csrColIndC, - const csrgemm2Info_t info, - void* pBuffer, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseScsrgemm2(handle, - m, - n, - k, - alpha, - descrA, - nnzA, - csrValA, - csrRowPtrA, - csrColIndA, - descrB, - nnzB, - csrValB, - csrRowPtrB, - csrColIndB, - beta, - descrD, - nnzD, - csrValD, - csrRowPtrD, - csrColIndD, - descrC, - csrValC, - csrRowPtrC, - csrColIndC, - info, - pBuffer); -#pragma GCC diagnostic pop -} - -template <> -inline cusparseStatus_t cusparsecsrgemm2(cusparseHandle_t handle, - int m, - int n, - int k, - const double* alpha, - const cusparseMatDescr_t descrA, - int nnzA, - const double* csrValA, - const int* csrRowPtrA, - const int* csrColIndA, - const cusparseMatDescr_t descrB, - int nnzB, - const double* csrValB, - const int* csrRowPtrB, - const int* csrColIndB, - const double* beta, - const cusparseMatDescr_t descrD, - int nnzD, - const double* csrValD, - const int* csrRowPtrD, - const int* csrColIndD, - const cusparseMatDescr_t descrC, - double* csrValC, - const int* csrRowPtrC, - int* csrColIndC, - const csrgemm2Info_t info, - void* pBuffer, - cudaStream_t stream) -{ - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return cusparseDcsrgemm2(handle, - m, - n, - k, - alpha, - descrA, - nnzA, - csrValA, - csrRowPtrA, - csrColIndA, - descrB, - nnzB, - csrValB, - csrRowPtrB, - csrColIndB, - beta, - descrD, - nnzD, - csrValD, - csrRowPtrD, - csrColIndD, - descrC, - csrValC, - csrRowPtrC, - csrColIndC, - info, - pBuffer); -#pragma GCC diagnostic pop -} - -/** @} */ - /** * @defgroup csrgemm2 cusparse sparse gemm operations * @{ @@ -1815,4 +1177,4 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, } // namespace detail } // namespace sparse -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/spectral/detail/matrix_wrappers.hpp b/cpp/include/raft/spectral/detail/matrix_wrappers.hpp index 40388eea84..e4e028e9f0 100644 --- a/cpp/include/raft/spectral/detail/matrix_wrappers.hpp +++ b/cpp/include/raft/spectral/detail/matrix_wrappers.hpp @@ -282,9 +282,9 @@ struct sparse_matrix_t { cusparseSpMVAlg_t translate_algorithm(sparse_mv_alg_t alg) const { switch (alg) { - case sparse_mv_alg_t::SPARSE_MV_ALG1: return CUSPARSE_CSRMV_ALG1; - case sparse_mv_alg_t::SPARSE_MV_ALG2: return CUSPARSE_CSRMV_ALG2; - default: return CUSPARSE_MV_ALG_DEFAULT; + case sparse_mv_alg_t::SPARSE_MV_ALG1: return CUSPARSE_SPMV_CSR_ALG1; + case sparse_mv_alg_t::SPARSE_MV_ALG2: return CUSPARSE_SPMV_CSR_ALG2; + default: return CUSPARSE_SPMV_ALG_DEFAULT; } } #endif