Skip to content

Commit

Permalink
Fixing small bug in CUSPARSE spmm w/ CUDA 12.2 (#2117)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #2117
  • Loading branch information
cjnolet authored Jan 24, 2024
1 parent 0586fc3 commit 70e806a
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions cpp/include/raft/sparse/linalg/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,19 @@ void spmm(raft::resources const& handle,
{
bool is_row_major = detail::is_row_major(y, z);

auto z_tmp = raft::make_device_matrix<ValueType, IndexType>(handle, z.extent(0), z.extent(1));
auto z_tmp_view = raft::make_device_strided_matrix_view<ValueType, IndexType, LayoutPolicyZ>(
z_tmp.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1));

auto descr_x = detail::create_descriptor(x);
auto descr_y = detail::create_descriptor(y);
auto descr_z = detail::create_descriptor(z);
auto descr_z = detail::create_descriptor(z_tmp_view);

detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z);

raft::copy(
z.data_handle(), z_tmp.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle));

RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));
Expand All @@ -76,4 +83,4 @@ void spmm(raft::resources const& handle,
} // end namespace sparse
} // end namespace raft

#endif
#endif

0 comments on commit 70e806a

Please sign in to comment.