diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index c8e4229203..041991521b 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -17,59 +17,47 @@ #pragma once #include -#include #include -#include +#include namespace raft { namespace sparse { namespace detail { /** - * @defgroup gather cusparse gather methods + * @defgroup gthr cusparse gather methods * @{ */ -inline cusparseStatus_t cusparsegather(cusparseHandle_t handle, - cusparseDnVecDescr_t vecY, - cusparseSpVecDescr_t vecX, - cudaStream_t stream) +template +cusparseStatus_t cusparsegthr( + cusparseHandle_t handle, int nnz, const T* vals, T* vals_sorted, int* d_P, cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsegthr(cusparseHandle_t handle, + int nnz, + const double* vals, + double* vals_sorted, + int* d_P, + cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - return cusparseGather(handle, vecY, vecX); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseDgthr(handle, nnz, vals, vals_sorted, d_P, CUSPARSE_INDEX_BASE_ZERO); +#pragma GCC diagnostic pop } - -template < - typename T, - typename std::enable_if_t || std::is_same_v>* = nullptr> -cusparseStatus_t cusparsegthr( - cusparseHandle_t handle, int nnz, const T* vals, T* vals_sorted, int* d_P, cudaStream_t stream) +template <> +inline cusparseStatus_t cusparsegthr(cusparseHandle_t handle, + int nnz, + const float* vals, + float* vals_sorted, + int* d_P, + cudaStream_t stream) { - auto constexpr float_type = []() constexpr - { - if constexpr (std::is_same_v) { - return CUDA_R_32F; - } else if constexpr (std::is_same_v) { - return CUDA_R_64F; - } - } - (); CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - auto dense_vector_descr = cusparseDnVecDescr_t{}; - auto sparse_vector_descr = cusparseSpVecDescr_t{}; - CUSPARSE_CHECK(cusparseCreateDnVec( - &dense_vector_descr, nnz, static_cast(const_cast(vals)), float_type)); - CUSPARSE_CHECK(cusparseCreateSpVec(&sparse_vector_descr, - nnz, - nnz, - static_cast(d_P), - static_cast(vals_sorted), - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_BASE_ZERO, - float_type)); - auto return_value = cusparseGather(handle, dense_vector_descr, sparse_vector_descr); - CUSPARSE_CHECK(cusparseDestroyDnVec(dense_vector_descr)); - CUSPARSE_CHECK(cusparseDestroySpVec(sparse_vector_descr)); - return return_value; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseSgthr(handle, nnz, vals, vals_sorted, d_P, CUSPARSE_INDEX_BASE_ZERO); +#pragma GCC diagnostic pop } /** @} */ @@ -150,6 +138,77 @@ inline void cusparsecoosortByRow( // NOLINT } /** @} */ +/** + * @defgroup Gemmi cusparse gemmi operations + * @{ + */ +template +cusparseStatus_t cusparsegemmi( // NOLINT + cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const T* alpha, + const T* A, + int lda, + const T* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const T* beta, + T* C, + int ldc, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const float* alpha, + const float* A, + int lda, + const float* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const float* beta, + float* C, + int ldc, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseSgemmi( + handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc); +#pragma GCC diagnostic pop +} +template <> +inline cusparseStatus_t cusparsegemmi(cusparseHandle_t handle, + int m, + int n, + int k, + int nnz, + const double* alpha, + const double* A, + int lda, + const double* cscValB, + const int* cscColPtrB, + const int* cscRowIndB, + const double* beta, + double* C, + int ldc, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseDgemmi( + handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, cscRowIndB, beta, C, ldc); +#pragma GCC diagnostic pop +} +/** @} */ + #if not defined CUDA_ENFORCE_LOWER and CUDA_VER_10_1_UP /** * @defgroup cusparse Create CSR operations @@ -534,17 +593,8 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - return cusparseSpMM(handle, - opA, - opB, - static_cast(alpha), - matA, - matB, - static_cast(beta), - matC, - CUDA_R_32F, - alg, - static_cast(externalBuffer)); + return cusparseSpMM( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, externalBuffer); } template <> inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, @@ -560,17 +610,8 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - return cusparseSpMM(handle, - opA, - opB, - static_cast(alpha), - matA, - matB, - static_cast(beta), - matC, - CUDA_R_64F, - alg, - static_cast(externalBuffer)); + return cusparseSpMM( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, externalBuffer); } /** @} */ #else @@ -646,73 +687,6 @@ inline cusparseStatus_t cusparsecsrmm(cusparseHandle_t handle, /** @} */ #endif -/** - * @defgroup Gemmi cusparse gemmi operations - * @{ - */ -template -cusparseStatus_t cusparsegemmi( // NOLINT - cusparseHandle_t handle, - int m, - int n, - int k, - int nnz, - const T* alpha, - const T* A, - int lda, - const T* cscValB, - const int* cscColPtrB, - const int* cscRowIndB, - const T* beta, - T* C, - int ldc, - cudaStream_t stream) -{ - static_assert(std::is_same_v || std::is_same_v, "Unsupported data type"); - - cusparseDnMatDescr_t matA; - cusparseSpMatDescr_t matB; - cusparseDnMatDescr_t matC; - - auto math_type = std::is_same_v ? CUDA_R_32F : CUDA_R_64F; - // Create sparse matrix B - CUSPARSE_CHECK(cusparseCreateCsc(&matB, - k, - n, - nnz, - static_cast(const_cast(cscColPtrB)), - static_cast(const_cast(cscRowIndB)), - static_cast(const_cast(cscValB)), - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_BASE_ZERO, - math_type)); - // Create dense matrices - CUSPARSE_CHECK(cusparseCreateDnMat( - &matA, m, k, lda, static_cast(const_cast(A)), math_type, CUSPARSE_ORDER_ROW)); - CUSPARSE_CHECK(cusparseCreateDnMat( - &matC, m, n, ldc, static_cast(const_cast(C)), math_type, CUSPARSE_ORDER_ROW)); - - cusparseOperation_t opA = CUSPARSE_OPERATION_TRANSPOSE; - cusparseOperation_t opB = CUSPARSE_OPERATION_TRANSPOSE; - cusparseSpMMAlg_t alg = CUSPARSE_SPMM_CSR_ALG2; - size_t buffer_size = 0; - - CUSPARSE_CHECK(cusparsespmm_bufferSize( - handle, opA, opB, alpha, matB, matA, beta, matC, alg, &buffer_size, stream)); - buffer_size = buffer_size / sizeof(T); - rmm::device_uvector external_buffer(buffer_size, stream); - auto return_value = cusparsespmm( - handle, opA, opB, alpha, matB, matA, beta, matC, alg, external_buffer.data(), stream); - - // destroy matrix/vector descriptors - CUSPARSE_CHECK(cusparseDestroyDnMat(matA)); - CUSPARSE_CHECK(cusparseDestroySpMat(matB)); - CUSPARSE_CHECK(cusparseDestroyDnMat(matC)); - return return_value; -} -/** @} */ - /** * @defgroup csr2coo cusparse CSR to COO converter methods * @{ @@ -759,6 +733,336 @@ inline cusparseStatus_t cusparsesetpointermode(cusparseHandle_t handle, } /** @} */ +/** + * @defgroup CsrmvEx cusparse csrmvex operations + * @{ + */ +template +cusparseStatus_t cusparsecsrmvex_bufferSize(cusparseHandle_t handle, + cusparseAlgMode_t alg, + cusparseOperation_t transA, + int m, + int n, + int nnz, + const T* alpha, + const cusparseMatDescr_t descrA, + const T* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const T* x, + const T* beta, + T* y, + size_t* bufferSizeInBytes, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsecsrmvex_bufferSize(cusparseHandle_t handle, + cusparseAlgMode_t alg, + cusparseOperation_t transA, + int m, + int n, + int nnz, + const float* alpha, + const cusparseMatDescr_t descrA, + const float* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const float* x, + const float* beta, + float* y, + size_t* bufferSizeInBytes, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + +#if CUDART_VERSION >= 11020 + cusparseSpMatDescr_t matA; + cusparsecreatecsr(&matA, + m, + n, + nnz, + const_cast(csrRowPtrA), + const_cast(csrColIndA), + const_cast(csrValA)); + + cusparseDnVecDescr_t vecX; + cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); + + cusparseDnVecDescr_t vecY; + cusparsecreatednvec(&vecY, static_cast(n), y); + + cusparseStatus_t result = cusparseSpMV_bufferSize(handle, + transA, + alpha, + matA, + vecX, + beta, + vecY, + CUDA_R_32F, + CUSPARSE_SPMV_ALG_DEFAULT, + bufferSizeInBytes); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); + return result; + +#else + + return cusparseCsrmvEx_bufferSize(handle, + alg, + transA, + m, + n, + nnz, + alpha, + CUDA_R_32F, + descrA, + csrValA, + CUDA_R_32F, + csrRowPtrA, + csrColIndA, + x, + CUDA_R_32F, + beta, + CUDA_R_32F, + y, + CUDA_R_32F, + CUDA_R_32F, + bufferSizeInBytes); +#endif +} +template <> +inline cusparseStatus_t cusparsecsrmvex_bufferSize(cusparseHandle_t handle, + cusparseAlgMode_t alg, + cusparseOperation_t transA, + int m, + int n, + int nnz, + const double* alpha, + const cusparseMatDescr_t descrA, + const double* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const double* x, + const double* beta, + double* y, + size_t* bufferSizeInBytes, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + +#if CUDART_VERSION >= 11020 + cusparseSpMatDescr_t matA; + cusparsecreatecsr(&matA, + m, + n, + nnz, + const_cast(csrRowPtrA), + const_cast(csrColIndA), + const_cast(csrValA)); + + cusparseDnVecDescr_t vecX; + cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); + + cusparseDnVecDescr_t vecY; + cusparsecreatednvec(&vecY, static_cast(n), y); + + cusparseStatus_t result = cusparseSpMV_bufferSize(handle, + transA, + alpha, + matA, + vecX, + beta, + vecY, + CUDA_R_64F, + CUSPARSE_SPMV_ALG_DEFAULT, + bufferSizeInBytes); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); + return result; +#else + return cusparseCsrmvEx_bufferSize(handle, + alg, + transA, + m, + n, + nnz, + alpha, + CUDA_R_64F, + descrA, + csrValA, + CUDA_R_64F, + csrRowPtrA, + csrColIndA, + x, + CUDA_R_64F, + beta, + CUDA_R_64F, + y, + CUDA_R_64F, + CUDA_R_64F, + bufferSizeInBytes); +#endif +} + +template +cusparseStatus_t cusparsecsrmvex(cusparseHandle_t handle, + cusparseAlgMode_t alg, + cusparseOperation_t transA, + int m, + int n, + int nnz, + const T* alpha, + const cusparseMatDescr_t descrA, + const T* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const T* x, + const T* beta, + T* y, + T* buffer, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsecsrmvex(cusparseHandle_t handle, + cusparseAlgMode_t alg, + cusparseOperation_t transA, + int m, + int n, + int nnz, + const float* alpha, + const cusparseMatDescr_t descrA, + const float* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const float* x, + const float* beta, + float* y, + float* buffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + +#if CUDART_VERSION >= 11020 + cusparseSpMatDescr_t matA; + cusparsecreatecsr(&matA, + m, + n, + nnz, + const_cast(csrRowPtrA), + const_cast(csrColIndA), + const_cast(csrValA)); + + cusparseDnVecDescr_t vecX; + cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); + + cusparseDnVecDescr_t vecY; + cusparsecreatednvec(&vecY, static_cast(n), y); + + cusparseStatus_t result = cusparseSpMV( + handle, transA, alpha, matA, vecX, beta, vecY, CUDA_R_32F, CUSPARSE_SPMV_ALG_DEFAULT, buffer); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); + return result; +#else + return cusparseCsrmvEx(handle, + alg, + transA, + m, + n, + nnz, + alpha, + CUDA_R_32F, + descrA, + csrValA, + CUDA_R_32F, + csrRowPtrA, + csrColIndA, + x, + CUDA_R_32F, + beta, + CUDA_R_32F, + y, + CUDA_R_32F, + CUDA_R_32F, + buffer); +#endif +} +template <> +inline cusparseStatus_t cusparsecsrmvex(cusparseHandle_t handle, + cusparseAlgMode_t alg, + cusparseOperation_t transA, + int m, + int n, + int nnz, + const double* alpha, + const cusparseMatDescr_t descrA, + const double* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const double* x, + const double* beta, + double* y, + double* buffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + +#if CUDART_VERSION >= 11020 + cusparseSpMatDescr_t matA; + cusparsecreatecsr(&matA, + m, + n, + nnz, + const_cast(csrRowPtrA), + const_cast(csrColIndA), + const_cast(csrValA)); + + cusparseDnVecDescr_t vecX; + cusparsecreatednvec(&vecX, static_cast(n), const_cast(x)); + + cusparseDnVecDescr_t vecY; + cusparsecreatednvec(&vecY, static_cast(n), y); + + cusparseStatus_t result = cusparseSpMV( + handle, transA, alpha, matA, vecX, beta, vecY, CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, buffer); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(matA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecX)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnVec(vecY)); + return result; + +#else + + return cusparseCsrmvEx(handle, + alg, + transA, + m, + n, + nnz, + alpha, + CUDA_R_64F, + descrA, + csrValA, + CUDA_R_64F, + csrRowPtrA, + csrColIndA, + x, + CUDA_R_64F, + beta, + CUDA_R_64F, + y, + CUDA_R_64F, + CUDA_R_64F, + buffer); +#endif +} + +/** @} */ + /** * @defgroup Csr2cscEx2 cusparse csr->csc conversion * @{ @@ -943,6 +1247,340 @@ inline cusparseStatus_t cusparsecsr2csc(cusparseHandle_t handle, /** @} */ +/** + * @defgroup csrgemm2 cusparse sparse gemm operations + * @{ + */ + +template +cusparseStatus_t cusparsecsrgemm2_buffersizeext(cusparseHandle_t handle, + int m, + int n, + int k, + const T* alpha, + const T* beta, + const cusparseMatDescr_t matA, + int nnzA, + const int* rowindA, + const int* indicesA, + const cusparseMatDescr_t matB, + int nnzB, + const int* rowindB, + const int* indicesB, + const cusparseMatDescr_t matD, + int nnzD, + const int* rowindD, + const int* indicesD, + csrgemm2Info_t info, + size_t* pBufferSizeInBytes, + cudaStream_t stream); + +template <> +inline cusparseStatus_t cusparsecsrgemm2_buffersizeext(cusparseHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const cusparseMatDescr_t matA, + int nnzA, + const int* rowindA, + const int* indicesA, + const cusparseMatDescr_t matB, + int nnzB, + const int* rowindB, + const int* indicesB, + const cusparseMatDescr_t matD, + int nnzD, + const int* rowindD, + const int* indicesD, + csrgemm2Info_t info, + size_t* pBufferSizeInBytes, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseScsrgemm2_bufferSizeExt(handle, + m, + n, + k, + alpha, + matA, + nnzA, + rowindA, + indicesA, + matB, + nnzB, + rowindB, + indicesB, + beta, + matD, + nnzD, + rowindD, + indicesD, + info, + pBufferSizeInBytes); +#pragma GCC diagnostic pop +} + +template <> +inline cusparseStatus_t cusparsecsrgemm2_buffersizeext(cusparseHandle_t handle, + int m, + int n, + int k, + const double* alpha, + const double* beta, + const cusparseMatDescr_t matA, + int nnzA, + const int* rowindA, + const int* indicesA, + const cusparseMatDescr_t matB, + int nnzB, + const int* rowindB, + const int* indicesB, + const cusparseMatDescr_t matD, + int nnzD, + const int* rowindD, + const int* indicesD, + csrgemm2Info_t info, + size_t* pBufferSizeInBytes, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseDcsrgemm2_bufferSizeExt(handle, + m, + n, + k, + alpha, + matA, + nnzA, + rowindA, + indicesA, + matB, + nnzB, + rowindB, + indicesB, + beta, + matD, + nnzD, + rowindD, + indicesD, + info, + pBufferSizeInBytes); +#pragma GCC diagnostic pop +} + +inline cusparseStatus_t cusparsecsrgemm2nnz(cusparseHandle_t handle, + int m, + int n, + int k, + const cusparseMatDescr_t matA, + int nnzA, + const int* rowindA, + const int* indicesA, + const cusparseMatDescr_t matB, + int nnzB, + const int* rowindB, + const int* indicesB, + const cusparseMatDescr_t matD, + int nnzD, + const int* rowindD, + const int* indicesD, + const cusparseMatDescr_t matC, + int* rowindC, + int* nnzC, + const csrgemm2Info_t info, + void* pBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseXcsrgemm2Nnz(handle, + m, + n, + k, + matA, + nnzA, + rowindA, + indicesA, + matB, + nnzB, + rowindB, + indicesB, + matD, + nnzD, + rowindD, + indicesD, + matC, + rowindC, + nnzC, + info, + pBuffer); +#pragma GCC diagnostic pop +} + +template +cusparseStatus_t cusparsecsrgemm2(cusparseHandle_t handle, + int m, + int n, + int k, + const T* alpha, + const cusparseMatDescr_t descrA, + int nnzA, + const T* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const cusparseMatDescr_t descrB, + int nnzB, + const T* csrValB, + const int* csrRowPtrB, + const int* csrColIndB, + const T* beta, + const cusparseMatDescr_t descrD, + int nnzD, + const T* csrValD, + const int* csrRowPtrD, + const int* csrColIndD, + const cusparseMatDescr_t descrC, + T* csrValC, + const int* csrRowPtrC, + int* csrColIndC, + const csrgemm2Info_t info, + void* pBuffer, + cudaStream_t stream); + +template <> +inline cusparseStatus_t cusparsecsrgemm2(cusparseHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const cusparseMatDescr_t descrA, + int nnzA, + const float* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const cusparseMatDescr_t descrB, + int nnzB, + const float* csrValB, + const int* csrRowPtrB, + const int* csrColIndB, + const float* beta, + const cusparseMatDescr_t descrD, + int nnzD, + const float* csrValD, + const int* csrRowPtrD, + const int* csrColIndD, + const cusparseMatDescr_t descrC, + float* csrValC, + const int* csrRowPtrC, + int* csrColIndC, + const csrgemm2Info_t info, + void* pBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseScsrgemm2(handle, + m, + n, + k, + alpha, + descrA, + nnzA, + csrValA, + csrRowPtrA, + csrColIndA, + descrB, + nnzB, + csrValB, + csrRowPtrB, + csrColIndB, + beta, + descrD, + nnzD, + csrValD, + csrRowPtrD, + csrColIndD, + descrC, + csrValC, + csrRowPtrC, + csrColIndC, + info, + pBuffer); +#pragma GCC diagnostic pop +} + +template <> +inline cusparseStatus_t cusparsecsrgemm2(cusparseHandle_t handle, + int m, + int n, + int k, + const double* alpha, + const cusparseMatDescr_t descrA, + int nnzA, + const double* csrValA, + const int* csrRowPtrA, + const int* csrColIndA, + const cusparseMatDescr_t descrB, + int nnzB, + const double* csrValB, + const int* csrRowPtrB, + const int* csrColIndB, + const double* beta, + const cusparseMatDescr_t descrD, + int nnzD, + const double* csrValD, + const int* csrRowPtrD, + const int* csrColIndD, + const cusparseMatDescr_t descrC, + double* csrValC, + const int* csrRowPtrC, + int* csrColIndC, + const csrgemm2Info_t info, + void* pBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + return cusparseDcsrgemm2(handle, + m, + n, + k, + alpha, + descrA, + nnzA, + csrValA, + csrRowPtrA, + csrColIndA, + descrB, + nnzB, + csrValB, + csrRowPtrB, + csrColIndB, + beta, + descrD, + nnzD, + csrValD, + csrRowPtrD, + csrColIndD, + descrC, + csrValC, + csrRowPtrC, + csrColIndC, + info, + pBuffer); +#pragma GCC diagnostic pop +} + +/** @} */ + /** * @defgroup csrgemm2 cusparse sparse gemm operations * @{ @@ -1177,4 +1815,4 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, } // namespace detail } // namespace sparse -} // namespace raft +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/spectral/detail/matrix_wrappers.hpp b/cpp/include/raft/spectral/detail/matrix_wrappers.hpp index e4e028e9f0..40388eea84 100644 --- a/cpp/include/raft/spectral/detail/matrix_wrappers.hpp +++ b/cpp/include/raft/spectral/detail/matrix_wrappers.hpp @@ -282,9 +282,9 @@ struct sparse_matrix_t { cusparseSpMVAlg_t translate_algorithm(sparse_mv_alg_t alg) const { switch (alg) { - case sparse_mv_alg_t::SPARSE_MV_ALG1: return CUSPARSE_SPMV_CSR_ALG1; - case sparse_mv_alg_t::SPARSE_MV_ALG2: return CUSPARSE_SPMV_CSR_ALG2; - default: return CUSPARSE_SPMV_ALG_DEFAULT; + case sparse_mv_alg_t::SPARSE_MV_ALG1: return CUSPARSE_CSRMV_ALG1; + case sparse_mv_alg_t::SPARSE_MV_ALG2: return CUSPARSE_CSRMV_ALG2; + default: return CUSPARSE_MV_ALG_DEFAULT; } } #endif