From 8743ac934740a6ff306ac9069f1a914d236ab44c Mon Sep 17 00:00:00 2001 From: Micka Date: Wed, 24 Jan 2024 21:20:35 +0100 Subject: [PATCH] [BUG] Fix `SPMM` strided view (#2124) With the bug fix #2117 there can be an issue with `z_tmp` memory being uninitialized. SPMM formula is `Z = alpha . X * Y + beta . Z` so when `beta` is not zero, Z is being read. The proposed solution in this PR remove the need for an extra allocation and a copy from/to an external buffer, by creating a strided view of the original Z. Authors: - Micka (https://github.com/lowener) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2124 --- cpp/include/raft/sparse/linalg/spmm.hpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index 2812b6b325..03c97fdb9d 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -42,7 +42,7 @@ namespace linalg { * @param[in] x input raft::device_csr_matrix_view * @param[in] y input raft::device_matrix_view * @param[in] beta scalar - * @param[out] z output raft::device_matrix_view + * @param[inout] z input-output raft::device_matrix_view */ template (handle, z.extent(0), z.extent(1)); auto z_tmp_view = raft::make_device_strided_matrix_view( - z_tmp.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); @@ -70,9 +69,6 @@ void spmm(raft::resources const& handle, detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); - raft::copy( - z.data_handle(), z_tmp.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));