From 80283d44f29b79b205769ddfc89ae103516a87d1 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 24 Jan 2024 14:20:27 +0100 Subject: [PATCH] Fix spmm strided view --- cpp/include/raft/sparse/linalg/spmm.hpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index 2812b6b325..03c97fdb9d 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -42,7 +42,7 @@ namespace linalg { * @param[in] x input raft::device_csr_matrix_view * @param[in] y input raft::device_matrix_view * @param[in] beta scalar - * @param[out] z output raft::device_matrix_view + * @param[inout] z input-output raft::device_matrix_view */ template (handle, z.extent(0), z.extent(1)); auto z_tmp_view = raft::make_device_strided_matrix_view( - z_tmp.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); @@ -70,9 +69,6 @@ void spmm(raft::resources const& handle, detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); - raft::copy( - z.data_handle(), z_tmp.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));