diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index c2fdd64574..2812b6b325 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -60,12 +60,19 @@ void spmm(raft::resources const& handle, { bool is_row_major = detail::is_row_major(y, z); + auto z_tmp = raft::make_device_matrix(handle, z.extent(0), z.extent(1)); + auto z_tmp_view = raft::make_device_strided_matrix_view( + z_tmp.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); - auto descr_z = detail::create_descriptor(z); + auto descr_z = detail::create_descriptor(z_tmp_view); detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); + raft::copy( + z.data_handle(), z_tmp.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); @@ -76,4 +83,4 @@ void spmm(raft::resources const& handle, } // end namespace sparse } // end namespace raft -#endif +#endif \ No newline at end of file