From 7fb0995adab6b59587c3ff1a2fb84785a009da98 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 12 Feb 2024 11:26:30 +0000 Subject: [PATCH] fix is_row/col_order for strided layouts --- cpp/include/raft/util/input_validation.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index 1977b45281..17bb53f22b 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -39,8 +39,7 @@ constexpr bool is_row_or_column_major(mdspan constexpr bool is_row_or_column_major(mdspan m) { - return m.stride(0) == typename Extents::index_type(1) || - m.stride(1) == typename Extents::index_type(1); + return is_row_major(m) || is_col_major(m); } template @@ -64,7 +63,7 @@ constexpr bool is_row_major(mdspan template constexpr bool is_row_major(mdspan m) { - return m.stride(1) == typename Extents::index_type(1); + return m.stride(1) == typename Extents::index_type(1) && m.stride(0) >= m.extent(1); } template @@ -88,7 +87,7 @@ constexpr bool is_col_major(mdspan template constexpr bool is_col_major(mdspan m) { - return m.stride(0) == typename Extents::index_type(1); + return m.stride(0) == typename Extents::index_type(1) && m.stride(1) >= m.extent(0); } template