Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Helpers for identifying contiguous layouts. #1861

Merged
merged 10 commits into from
Oct 12, 2023
35 changes: 35 additions & 0 deletions cpp/include/raft/core/mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,41 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,

/** @} */

/**
* @defgroup mdspan_contiguous Whether the strides imply a contiguous layout.
* @{
*/

/**
* @brief Whether the strides imply a c-contiguous layout.
*/
template <typename Extents, typename Strides>
[[nodiscard]] auto is_c_contiguous(Extents const& extents, Strides const& strides) -> bool
{
typename Extents::index_type stride = 1;
for (auto r = extents.rank(); r > 0; r--) {
if (stride != strides[r - 1]) { return false; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In your test case, in the first iteration isn't stride = 1 and strides[r - 1] = 8 where r = 3? I don't understand how this doesn't exit in the first iteration itself

stride *= extents.extent(r - 1);
}
return true;
}

/**
* @brief Whether the strides imply a f-contiguous layout.
*/
template <typename Extents, typename Strides>
[[nodiscard]] auto is_f_contiguous(Extents const& extents, Strides const& strides) -> bool
{
typename Extents::index_type stride = 1;
for (typename Extents::rank_type r = 0; r < extents.rank(); r++) {
if (stride != strides[r]) { return false; }
stride *= extents.extent(r);
}
return true;
}

/** @} */

/**
* @brief Const accessor specialization for default_accessor
*
Expand Down
25 changes: 25 additions & 0 deletions cpp/test/core/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,29 @@ void test_const_mdspan()

TEST(MDSpan, ConstMDSpan) { test_const_mdspan(); }

void test_contiguous_predicates()
{
raft::resources handle;
extents<std::int64_t, dynamic_extent, dynamic_extent, dynamic_extent> exts{4, 4, 4};

{
std::array<std::int64_t, 3> strides{16, 4, 1};
ASSERT_TRUE(is_c_contiguous(exts, strides));
ASSERT_FALSE(is_f_contiguous(exts, strides));

// ensure that we are using the same stride unit (elements v.s. bytes) as mdarray
auto arr = make_host_mdarray<float>(handle, exts);
for (std::int32_t i = 0; i < 3; ++i) {
auto s = arr.stride(i);
ASSERT_EQ(s, strides[i]);
}
}
{
std::array<std::int64_t, 3> strides{1, 4, 16};
ASSERT_FALSE(is_c_contiguous(exts, strides));
ASSERT_TRUE(is_f_contiguous(exts, strides));
}
}

TEST(MDArray, Contiguous) { test_contiguous_predicates(); }
} // namespace raft
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
}

# Config numpydoc
Expand Down
5 changes: 5 additions & 0 deletions docs/source/cpp_api/mdspan_mdspan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ mdspan: Multi-dimensional Non-owning View
:members:
:content-only:

.. doxygengroup:: mdspan_contiguous
:project: RAFT
:members:
:content-only:

.. doxygengroup:: mdspan_make_const
:project: RAFT
:members:
Expand Down