From 2483c4a26f8dfe07103d8b0d89f741eca91e37e9 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 7 Feb 2024 16:23:47 +0000 Subject: [PATCH 1/2] fix test --- cpp/include/raft/core/device_mdspan.hpp | 8 ++++++-- cpp/include/raft/sparse/linalg/spmm.hpp | 10 +++++++--- cpp/include/raft/util/input_validation.hpp | 9 +++++---- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 3b6165b86a..7988bd3f6f 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -207,8 +207,12 @@ auto constexpr make_device_strided_matrix_view(ElementType* ptr, IndexType stride) { constexpr auto is_row_major = std::is_same_v; - IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; - IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); + constexpr auto is_col_major = std::is_same_v; + + assert(is_row_major || is_col_major); + + IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; + IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows); matrix_extent extents{n_rows, n_cols}; diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index 03c97fdb9d..e6789fda3c 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -60,8 +60,12 @@ void spmm(raft::resources const& handle, { bool is_row_major = detail::is_row_major(y, z); - auto z_tmp_view = raft::make_device_strided_matrix_view( - z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + auto z_tmp_view = + is_row_major + ? raft::make_device_strided_matrix_view( + z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)) + : raft::make_device_strided_matrix_view( + z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); @@ -79,4 +83,4 @@ void spmm(raft::resources const& handle, } // end namespace sparse } // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index ab5264f900..1977b45281 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,8 @@ constexpr bool is_row_or_column_major(mdspan constexpr bool is_row_or_column_major(mdspan m) { - return m.is_exhaustive(); + return m.stride(0) == typename Extents::index_type(1) || + m.stride(1) == typename Extents::index_type(1); } template @@ -63,7 +64,7 @@ constexpr bool is_row_major(mdspan template constexpr bool is_row_major(mdspan m) { - return m.is_exhaustive() && m.stride(1) == typename Extents::index_type(1); + return m.stride(1) == typename Extents::index_type(1); } template @@ -87,7 +88,7 @@ constexpr bool is_col_major(mdspan template constexpr bool is_col_major(mdspan m) { - return m.is_exhaustive() && m.stride(0) == typename Extents::index_type(1); + return m.stride(0) == typename Extents::index_type(1); } template From 2e5ab617cfcddba0ebb1822d651aeea2d56b9459 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 7 Feb 2024 16:51:30 +0000 Subject: [PATCH 2/2] review suggestion --- cpp/include/raft/sparse/linalg/spmm.hpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index e6789fda3c..1e815ba521 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -61,11 +61,10 @@ void spmm(raft::resources const& handle, bool is_row_major = detail::is_row_major(y, z); auto z_tmp_view = - is_row_major - ? raft::make_device_strided_matrix_view( - z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)) - : raft::make_device_strided_matrix_view( - z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + is_row_major ? raft::make_device_strided_matrix_view( + z.data_handle(), z.extent(0), z.extent(1), z.stride(0)) + : raft::make_device_strided_matrix_view( + z.data_handle(), z.extent(0), z.extent(1), z.stride(1)); auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y);