Skip to content

Commit

Permalink
Fix template types for create_descriptor function. (#1680)
Browse files Browse the repository at this point in the history
Use correct types for indices pointers and indices when creating the cusparse
descriptor from a device_csr_matrix_view.

Authors:
  - Simon Adorf (https://github.com/csadorf)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1680
  • Loading branch information
csadorf authored Aug 10, 2023
1 parent bc8850a commit b7d2a9a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
38 changes: 19 additions & 19 deletions cpp/include/raft/sparse/detail/cusparse_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,21 @@ inline void cusparsecoosortByRow( // NOLINT
* @defgroup cusparse Create CSR operations
* @{
*/
template <typename IndexT, typename ValueT>
template <typename ValueT, typename IndptrType, typename IndicesType>
cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,
int64_t rows,
int64_t cols,
int64_t nnz,
IndexT* csrRowOffsets,
IndexT* csrColInd,
IndptrType* csrRowOffsets,
IndicesType* csrColInd,
ValueT* csrValues);
template <>
inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,
int64_t rows,
int64_t cols,
int64_t nnz,
int* csrRowOffsets,
int* csrColInd,
int32_t* csrRowOffsets,
int32_t* csrColInd,
float* csrValues)
{
return cusparseCreateCsr(spMatDescr,
Expand All @@ -188,8 +188,8 @@ inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,
int64_t rows,
int64_t cols,
int64_t nnz,
int* csrRowOffsets,
int* csrColInd,
int32_t* csrRowOffsets,
int32_t* csrColInd,
double* csrValues)
{
return cusparseCreateCsr(spMatDescr,
Expand Down Expand Up @@ -1058,9 +1058,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle,

cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<float*>(csrValA));
Expand Down Expand Up @@ -1107,9 +1107,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle,
cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<double*>(csrValA));
Expand Down Expand Up @@ -1173,9 +1173,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle,
cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<float*>(csrValA));
Expand Down Expand Up @@ -1220,9 +1220,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle,
cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<double*>(csrValA));
Expand Down
16 changes: 9 additions & 7 deletions cpp/include/raft/sparse/linalg/detail/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,25 @@ cusparseDnMatDescr_t create_descriptor(
/**
* @brief create a cuSparse sparse descriptor
* @tparam ValueType Data type of sparse_view (float/double)
* @tparam IndptrType Data type of csr_matrix_view index pointers
* @tparam IndicesType Data type of csr_matrix_view indices
* @tparam NZType Type of sparse_view
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
* @returns sparse matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename NZType>
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
cusparseSpMatDescr_t create_descriptor(
raft::device_csr_matrix_view<ValueType, int, int, NZType>& sparse_view)
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
{
cusparseSpMatDescr_t descr;
auto csr_structure = sparse_view.structure_view();
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
&descr,
csr_structure.get_n_rows(),
csr_structure.get_n_cols(),
csr_structure.get_nnz(),
const_cast<int*>(csr_structure.get_indptr().data()),
const_cast<int*>(csr_structure.get_indices().data()),
static_cast<int64_t>(csr_structure.get_n_rows()),
static_cast<int64_t>(csr_structure.get_n_cols()),
static_cast<int64_t>(csr_structure.get_nnz()),
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
const_cast<IndicesType*>(csr_structure.get_indices().data()),
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
return descr;
}
Expand Down

0 comments on commit b7d2a9a

Please sign in to comment.