diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index f1a1adb916..15414b9af3 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -312,6 +312,41 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, /** @} */ +/** + * @defgroup mdspan_contiguous Whether the strides imply a contiguous layout. + * @{ + */ + +/** + * @brief Whether the strides imply a c-contiguous layout. + */ +template +[[nodiscard]] auto is_c_contiguous(Extents const& extents, Strides const& strides) -> bool +{ + typename Extents::index_type stride = 1; + for (auto r = extents.rank(); r > 0; r--) { + if (stride != strides[r - 1]) { return false; } + stride *= extents.extent(r - 1); + } + return true; +} + +/** + * @brief Whether the strides imply a f-contiguous layout. + */ +template +[[nodiscard]] auto is_f_contiguous(Extents const& extents, Strides const& strides) -> bool +{ + typename Extents::index_type stride = 1; + for (typename Extents::rank_type r = 0; r < extents.rank(); r++) { + if (stride != strides[r]) { return false; } + stride *= extents.extent(r); + } + return true; +} + +/** @} */ + /** * @brief Const accessor specialization for default_accessor * diff --git a/cpp/test/core/mdspan_utils.cu b/cpp/test/core/mdspan_utils.cu index ad212569c2..90ecba1a10 100644 --- a/cpp/test/core/mdspan_utils.cu +++ b/cpp/test/core/mdspan_utils.cu @@ -247,4 +247,29 @@ void test_const_mdspan() TEST(MDSpan, ConstMDSpan) { test_const_mdspan(); } +void test_contiguous_predicates() +{ + raft::resources handle; + extents exts{4, 4, 4}; + + { + std::array strides{16, 4, 1}; + ASSERT_TRUE(is_c_contiguous(exts, strides)); + ASSERT_FALSE(is_f_contiguous(exts, strides)); + + // ensure that we are using the same stride unit (elements v.s. bytes) as mdarray + auto arr = make_host_mdarray(handle, exts); + for (std::int32_t i = 0; i < 3; ++i) { + auto s = arr.stride(i); + ASSERT_EQ(s, strides[i]); + } + } + { + std::array strides{1, 4, 16}; + ASSERT_FALSE(is_c_contiguous(exts, strides)); + ASSERT_TRUE(is_f_contiguous(exts, strides)); + } +} + +TEST(MDArray, Contiguous) { test_contiguous_predicates(); } } // namespace raft diff --git a/docs/source/conf.py b/docs/source/conf.py index 822bce12fc..6c523cceb5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -174,7 +174,7 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { "python": ("https://docs.python.org/", None), - "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), } # Config numpydoc diff --git a/docs/source/cpp_api/mdspan_mdspan.rst b/docs/source/cpp_api/mdspan_mdspan.rst index 6011a9f103..f9f972aa74 100644 --- a/docs/source/cpp_api/mdspan_mdspan.rst +++ b/docs/source/cpp_api/mdspan_mdspan.rst @@ -24,6 +24,11 @@ mdspan: Multi-dimensional Non-owning View :members: :content-only: +.. doxygengroup:: mdspan_contiguous + :project: RAFT + :members: + :content-only: + .. doxygengroup:: mdspan_make_const :project: RAFT :members: