Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix template types for create_descriptor function. #1680

Merged
38 changes: 19 additions & 19 deletions cpp/include/raft/sparse/detail/cusparse_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,21 @@ inline void cusparsecoosortByRow( // NOLINT
* @defgroup cusparse Create CSR operations
* @{
*/
template <typename IndexT, typename ValueT>
template <typename ValueT, typename IndptrType, typename IndicesType>
cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,
int64_t rows,
int64_t cols,
int64_t nnz,
IndexT* csrRowOffsets,
IndexT* csrColInd,
IndptrType* csrRowOffsets,
IndicesType* csrColInd,
ValueT* csrValues);
template <>
inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,
int64_t rows,
int64_t cols,
int64_t nnz,
int* csrRowOffsets,
int* csrColInd,
int32_t* csrRowOffsets,
int32_t* csrColInd,
float* csrValues)
{
return cusparseCreateCsr(spMatDescr,
Expand All @@ -188,8 +188,8 @@ inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr,
int64_t rows,
int64_t cols,
int64_t nnz,
int* csrRowOffsets,
int* csrColInd,
int32_t* csrRowOffsets,
int32_t* csrColInd,
double* csrValues)
{
return cusparseCreateCsr(spMatDescr,
Expand Down Expand Up @@ -1058,9 +1058,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle,

cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<float*>(csrValA));
Expand Down Expand Up @@ -1107,9 +1107,9 @@ inline cusparseStatus_t cusparsecsr2dense_buffersize(cusparseHandle_t handle,
cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<double*>(csrValA));
Expand Down Expand Up @@ -1173,9 +1173,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle,
cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<float*>(csrValA));
Expand Down Expand Up @@ -1220,9 +1220,9 @@ inline cusparseStatus_t cusparsecsr2dense(cusparseHandle_t handle,
cusparseOrder_t order = row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
cusparseSpMatDescr_t matA;
cusparsecreatecsr(&matA,
m,
n,
nnz,
static_cast<int64_t>(m),
static_cast<int64_t>(n),
static_cast<int64_t>(nnz),
const_cast<int*>(csrRowPtrA),
const_cast<int*>(csrColIndA),
const_cast<double*>(csrValA));
Expand Down
16 changes: 9 additions & 7 deletions cpp/include/raft/sparse/linalg/detail/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,25 @@ cusparseDnMatDescr_t create_descriptor(
/**
* @brief create a cuSparse sparse descriptor
* @tparam ValueType Data type of sparse_view (float/double)
* @tparam IndptrType Data type of csr_matrix_view index pointers
* @tparam IndicesType Data type of csr_matrix_view indices
* @tparam NZType Type of sparse_view
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
* @returns sparse matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename NZType>
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
cusparseSpMatDescr_t create_descriptor(
raft::device_csr_matrix_view<ValueType, int, int, NZType>& sparse_view)
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
{
cusparseSpMatDescr_t descr;
auto csr_structure = sparse_view.structure_view();
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
&descr,
csr_structure.get_n_rows(),
csr_structure.get_n_cols(),
csr_structure.get_nnz(),
const_cast<int*>(csr_structure.get_indptr().data()),
const_cast<int*>(csr_structure.get_indices().data()),
static_cast<int64_t>(csr_structure.get_n_rows()),
static_cast<int64_t>(csr_structure.get_n_cols()),
static_cast<int64_t>(csr_structure.get_nnz()),
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
const_cast<IndicesType*>(csr_structure.get_indices().data()),
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
return descr;
}
Expand Down