diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 0740e2ab8c..e8bf9c6de5 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -154,21 +154,21 @@ inline void cusparsecoosortByRow( // NOLINT * @defgroup cusparse Create CSR operations * @{ */ -template +template cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, int64_t nnz, - IndexT* csrRowOffsets, - IndexT* csrColInd, + IndptrType* csrRowOffsets, + IndicesType* csrColInd, ValueT* csrValues); template <> inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, int64_t nnz, - int* csrRowOffsets, - int* csrColInd, + int32_t* csrRowOffsets, + int32_t* csrColInd, float* csrValues) { return cusparseCreateCsr(spMatDescr, @@ -188,8 +188,8 @@ inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, int64_t nnz, - int* csrRowOffsets, - int* csrColInd, + int32_t* csrRowOffsets, + int32_t* csrColInd, double* csrValues) { return cusparseCreateCsr(spMatDescr, @@ -1058,9 +1058,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle, cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); @@ -1107,9 +1107,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle, cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); @@ -1173,9 +1173,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); @@ -1220,9 +1220,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle, cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; cusparseSpMatDescr_t matA; cusparsecreatecsr(&matA, - m, - n, - nnz, + static_cast(m), + static_cast(n), + static_cast(nnz), const_cast(csrRowPtrA), const_cast(csrColIndA), const_cast(csrValA)); diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index 4ad8623076..d8d73ee83f 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -77,23 +77,25 @@ cusparseDnMatDescr_t create_descriptor( /** * @brief create a cuSparse sparse descriptor * @tparam ValueType Data type of sparse_view (float/double) + * @tparam IndptrType Data type of csr_matrix_view index pointers + * @tparam IndicesType Data type of csr_matrix_view indices * @tparam NZType Type of sparse_view * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns * @returns sparse matrix descriptor to be used by cuSparse API */ -template +template cusparseSpMatDescr_t create_descriptor( - raft::device_csr_matrix_view& sparse_view) + raft::device_csr_matrix_view& sparse_view) { cusparseSpMatDescr_t descr; auto csr_structure = sparse_view.structure_view(); RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( &descr, - csr_structure.get_n_rows(), - csr_structure.get_n_cols(), - csr_structure.get_nnz(), - const_cast(csr_structure.get_indptr().data()), - const_cast(csr_structure.get_indices().data()), + static_cast(csr_structure.get_n_rows()), + static_cast(csr_structure.get_n_cols()), + static_cast(csr_structure.get_nnz()), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), const_cast*>(sparse_view.get_elements().data()))); return descr; }