diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 3b6165b86a..7988bd3f6f 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -207,8 +207,12 @@ auto constexpr make_device_strided_matrix_view(ElementType* ptr, IndexType stride) { constexpr auto is_row_major = std::is_same_v; - IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; - IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); + constexpr auto is_col_major = std::is_same_v; + + assert(is_row_major || is_col_major); + + IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; + IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows); matrix_extent extents{n_rows, n_cols}; diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index 03c97fdb9d..1e815ba521 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -60,8 +60,11 @@ void spmm(raft::resources const& handle, { bool is_row_major = detail::is_row_major(y, z); - auto z_tmp_view = raft::make_device_strided_matrix_view( - z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + auto z_tmp_view = + is_row_major ? raft::make_device_strided_matrix_view( + z.data_handle(), z.extent(0), z.extent(1), z.stride(0)) + : raft::make_device_strided_matrix_view( + z.data_handle(), z.extent(0), z.extent(1), z.stride(1)); auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); @@ -79,4 +82,4 @@ void spmm(raft::resources const& handle, } // end namespace sparse } // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index ab5264f900..1977b45281 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,8 @@ constexpr bool is_row_or_column_major(mdspan constexpr bool is_row_or_column_major(mdspan m) { - return m.is_exhaustive(); + return m.stride(0) == typename Extents::index_type(1) || + m.stride(1) == typename Extents::index_type(1); } template @@ -63,7 +64,7 @@ constexpr bool is_row_major(mdspan template constexpr bool is_row_major(mdspan m) { - return m.is_exhaustive() && m.stride(1) == typename Extents::index_type(1); + return m.stride(1) == typename Extents::index_type(1); } template @@ -87,7 +88,7 @@ constexpr bool is_col_major(mdspan template constexpr bool is_col_major(mdspan m) { - return m.is_exhaustive() && m.stride(0) == typename Extents::index_type(1); + return m.stride(0) == typename Extents::index_type(1); } template