diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index c2fdd64574..59f0bcef81 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -42,7 +42,7 @@ namespace linalg { * @param[in] x input raft::device_csr_matrix_view * @param[in] y input raft::device_matrix_view * @param[in] beta scalar - * @param[out] z output raft::device_matrix_view + * @param[inout] z input-output raft::device_matrix_view */ template ( + z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); - auto descr_z = detail::create_descriptor(z); + auto descr_z = detail::create_descriptor(z_tmp_view); detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z);