From 70e806ad2b73bd0cd44e23ec775c1ef430d03253 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 24 Jan 2024 03:03:29 +0100 Subject: [PATCH] Fixing small bug in CUSPARSE spmm w/ CUDA 12.2 (#2117) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/2117 --- cpp/include/raft/sparse/linalg/spmm.hpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index c2fdd64574..2812b6b325 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -60,12 +60,19 @@ void spmm(raft::resources const& handle, { bool is_row_major = detail::is_row_major(y, z); + auto z_tmp = raft::make_device_matrix(handle, z.extent(0), z.extent(1)); + auto z_tmp_view = raft::make_device_strided_matrix_view( + z_tmp.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); - auto descr_z = detail::create_descriptor(z); + auto descr_z = detail::create_descriptor(z_tmp_view); detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); + raft::copy( + z.data_handle(), z_tmp.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); @@ -76,4 +83,4 @@ void spmm(raft::resources const& handle, } // end namespace sparse } // end namespace raft -#endif +#endif \ No newline at end of file