diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index 1977b45281..17bb53f22b 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -39,8 +39,7 @@ constexpr bool is_row_or_column_major(mdspan constexpr bool is_row_or_column_major(mdspan m) { - return m.stride(0) == typename Extents::index_type(1) || - m.stride(1) == typename Extents::index_type(1); + return is_row_major(m) || is_col_major(m); } template @@ -64,7 +63,7 @@ constexpr bool is_row_major(mdspan template constexpr bool is_row_major(mdspan m) { - return m.stride(1) == typename Extents::index_type(1); + return m.stride(1) == typename Extents::index_type(1) && m.stride(0) >= m.extent(1); } template @@ -88,7 +87,7 @@ constexpr bool is_col_major(mdspan template constexpr bool is_col_major(mdspan m) { - return m.stride(0) == typename Extents::index_type(1); + return m.stride(0) == typename Extents::index_type(1) && m.stride(1) >= m.extent(0); } template