Skip to content

Commit

Permalink
Fix spmm strided view
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Jan 24, 2024
1 parent 9c35f73 commit 2bfd0da
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions cpp/include/raft/sparse/linalg/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace linalg {
* @param[in] x input raft::device_csr_matrix_view
* @param[in] y input raft::device_matrix_view
* @param[in] beta scalar
* @param[out] z output raft::device_matrix_view
* @param[inout] z input-output raft::device_matrix_view
*/
template <typename ValueType,
typename IndexType,
Expand All @@ -60,9 +60,12 @@ void spmm(raft::resources const& handle,
{
bool is_row_major = detail::is_row_major(y, z);

auto z_tmp_view = raft::make_device_strided_matrix_view<ValueType, IndexType, LayoutPolicyZ>(
z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1));

auto descr_x = detail::create_descriptor(x);
auto descr_y = detail::create_descriptor(y);
auto descr_z = detail::create_descriptor(z);
auto descr_z = detail::create_descriptor(z_tmp_view);

detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z);

Expand Down

0 comments on commit 2bfd0da

Please sign in to comment.