From dc755907675e36e026535cf3bb48dc19d2e42470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Mon, 12 Feb 2024 20:58:02 +0100 Subject: [PATCH] fix is_row/col_order for strided layouts (#2173) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add additional constraint to is_row/col_major check. Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) --- cpp/include/raft/util/input_validation.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index 1977b45281..17bb53f22b 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -39,8 +39,7 @@ constexpr bool is_row_or_column_major(mdspan constexpr bool is_row_or_column_major(mdspan m) { - return m.stride(0) == typename Extents::index_type(1) || - m.stride(1) == typename Extents::index_type(1); + return is_row_major(m) || is_col_major(m); } template @@ -64,7 +63,7 @@ constexpr bool is_row_major(mdspan template constexpr bool is_row_major(mdspan m) { - return m.stride(1) == typename Extents::index_type(1); + return m.stride(1) == typename Extents::index_type(1) && m.stride(0) >= m.extent(1); } template @@ -88,7 +87,7 @@ constexpr bool is_col_major(mdspan template constexpr bool is_col_major(mdspan m) { - return m.stride(0) == typename Extents::index_type(1); + return m.stride(0) == typename Extents::index_type(1) && m.stride(1) >= m.extent(0); } template