From bad4b75ea97dd6df11399e14bfeb9e50028dae61 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Wed, 26 Jul 2023 12:55:51 -0700 Subject: [PATCH 1/4] Fix template types for cuSparse create_descriptor function. --- cpp/include/raft/sparse/linalg/detail/spmm.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index 4ad8623076..9c8b2fc1e6 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -77,13 +77,15 @@ cusparseDnMatDescr_t create_descriptor( /** * @brief create a cuSparse sparse descriptor * @tparam ValueType Data type of sparse_view (float/double) + * @tparam IndptrType Data type of csr_matrix_view index pointers + * @tparam IndicesType Data type of csr_matrix_view indices * @tparam NZType Type of sparse_view * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns * @returns sparse matrix descriptor to be used by cuSparse API */ -template +template cusparseSpMatDescr_t create_descriptor( - raft::device_csr_matrix_view& sparse_view) + raft::device_csr_matrix_view& sparse_view) { cusparseSpMatDescr_t descr; auto csr_structure = sparse_view.structure_view(); @@ -92,8 +94,8 @@ cusparseSpMatDescr_t create_descriptor( csr_structure.get_n_rows(), csr_structure.get_n_cols(), csr_structure.get_nnz(), - const_cast(csr_structure.get_indptr().data()), - const_cast(csr_structure.get_indices().data()), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), const_cast*>(sparse_view.get_elements().data()))); return descr; } From d0f6078f1504d7da2088c4c888eaa52a855d7801 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Mon, 31 Jul 2023 16:51:14 -0700 Subject: [PATCH 2/4] Use int32_t instead of int for consistency. --- cpp/include/raft/sparse/detail/cusparse_wrappers.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 0740e2ab8c..fbe921a809 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -167,8 +167,8 @@ inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, int64_t nnz, - int* csrRowOffsets, - int* csrColInd, + int32_t* csrRowOffsets, + int32_t* csrColInd, float* csrValues) { return cusparseCreateCsr(spMatDescr, @@ -188,8 +188,8 @@ inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, int64_t nnz, - int* csrRowOffsets, - int* csrColInd, + int32_t* csrRowOffsets, + int32_t* csrColInd, double* csrValues) { return cusparseCreateCsr(spMatDescr, From 2bd07e674b147711e6c14cabd7dfaf25d77caa70 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Mon, 31 Jul 2023 16:51:53 -0700 Subject: [PATCH 3/4] Explicitly cast rows, cols, and nnz for descriptor creation. --- .../raft/sparse/detail/cusparse_wrappers.h | 24 +++++++++---------- .../raft/sparse/linalg/detail/spmm.hpp | 6 ++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index fbe921a809..38e1a3fb8d 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -1058,9 +1058,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle, cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); @@ -1107,9 +1107,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle, cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); @@ -1173,9 +1173,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); @@ -1220,9 +1220,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index 9c8b2fc1e6..d8d73ee83f 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -91,9 +91,9 @@ cusparseSpMatDescr_t create_descriptor( auto csr_structure = sparse_view.structure_view(); RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( &descr, - csr_structure.get_n_rows(), - csr_structure.get_n_cols(), - csr_structure.get_nnz(), + static_cast(csr_structure.get_n_rows()), + static_cast(csr_structure.get_n_cols()), + static_cast(csr_structure.get_nnz()), const_cast(csr_structure.get_indptr().data()), const_cast(csr_structure.get_indices().data()), const_cast*>(sparse_view.get_elements().data()))); From a0e74d84ad8662f663c444fb1beea75dffc3cc2d Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Mon, 31 Jul 2023 16:58:19 -0700 Subject: [PATCH 4/4] Distinguish between Indptr- and IndicesType in cusparsecreatecsr. Similar to the calling functions. --- cpp/include/raft/sparse/detail/cusparse_wrappers.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 38e1a3fb8d..e8bf9c6de5 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -154,13 +154,13 @@ inline void cusparsecoosortByRow( // NOLINT * @defgroup cusparse Create CSR operations * @{ */ -template +template cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, int64_t nnz, - IndexT* csrRowOffsets, - IndexT* csrColInd, + IndptrType* csrRowOffsets, + IndicesType* csrColInd, ValueT* csrValues); template <> inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,