Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing raft::mdspan as an alias #715

Merged
merged 5 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 101 additions & 30 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,53 @@ using layout_c_contiguous = detail::stdex::layout_right;
*/
using layout_f_contiguous = detail::stdex::layout_left;

template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using mdspan = detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>;

namespace detail {
/**
* @\brief Template checks and helpers to determine if type T is an std::mdspan
* or a derived type
*/

template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void __takes_an_mdspan_ptr(
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*);
void __takes_an_mdspan_ptr(mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*);

template <typename T, typename = void>
struct __is_mdspan : std::false_type {
struct is_mdspan : std::false_type {
};

template <typename T>
struct __is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
struct is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::true_type {
};

template <typename T>
using __is_mdspan_t = __is_mdspan<std::remove_const_t<T>>;
using is_mdspan_t = is_mdspan<std::remove_const_t<T>>;

template <typename T>
inline constexpr bool __is_mdspan_v = __is_mdspan_t<T>::value;
inline constexpr bool is_mdspan_v = is_mdspan_t<T>::value;
} // namespace detail

template <typename...>
struct is_mdspan : std::true_type {
};
template <typename T1>
struct is_mdspan<T1> : detail::is_mdspan_t<T1> {
};
template <typename T1, typename... Tn>
struct is_mdspan<T1, Tn...>
: std::conditional_t<detail::is_mdspan_v<T1>, is_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if variadic template types Tn are either
* raft::host_mdspan/raft::device_mdspan or their derived types
*/
template <typename... Tn>
inline constexpr bool is_mdspan_v = is_mdspan<Tn...>::value;

/**
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location.
Expand All @@ -77,7 +101,7 @@ template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using device_mdspan = detail::stdex::
using device_mdspan =
mdspan<ElementType, Extents, LayoutPolicy, detail::device_accessor<AccessorPolicy>>;

/**
Expand All @@ -88,47 +112,71 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using host_mdspan =
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;
mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;

namespace detail {
template <typename T, bool B>
struct __is_device_mdspan : std::false_type {
struct is_device_mdspan : std::false_type {
};

template <typename T>
struct __is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
struct is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
};

/**
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_device_mdspan_v = __is_device_mdspan<T, __is_mdspan_v<T>>::value;
inline constexpr bool is_device_mdspan_v = is_device_mdspan<T, is_mdspan_v<T>>::value;

template <typename T, bool B>
struct __is_host_mdspan : std::false_type {
struct is_host_mdspan : std::false_type {
};

template <typename T>
struct __is_host_mdspan<T, true> : T::accessor_type::is_host_type {
struct is_host_mdspan<T, true> : T::accessor_type::is_host_type {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_host_mdspan_v = __is_host_mdspan<T, __is_mdspan_v<T>>::value;
inline constexpr bool is_host_mdspan_v = is_host_mdspan<T, is_mdspan_v<T>>::value;
} // namespace detail

template <typename...>
struct is_device_mdspan : std::true_type {
};
template <typename T1>
struct is_device_mdspan<T1> : detail::is_device_mdspan<T1, detail::is_mdspan_v<T1>> {
};
template <typename T1, typename... Tn>
struct is_device_mdspan<T1, Tn...>
: std::conditional_t<detail::is_device_mdspan_v<T1>, is_device_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan
* or their derived types
* This is structured such that it will short-circuit if the type is not std::mdspan
* or a derived type, and otherwise it will check whether it is a raft::device_mdspan
* or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type
* @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
*/
template <typename T>
inline constexpr bool is_mdspan_v =
std::conjunction_v<__is_mdspan_t<T>,
std::disjunction<__is_device_mdspan<T, true>, __is_host_mdspan<T, true>>>;
template <typename... Tn>
inline constexpr bool is_device_mdspan_v = is_device_mdspan<Tn...>::value;

template <typename...>
struct is_host_mdspan : std::true_type {
};
template <typename T1>
struct is_host_mdspan<T1> : detail::is_host_mdspan<T1, detail::is_mdspan_v<T1>> {
};
template <typename T1, typename... Tn>
struct is_host_mdspan<T1, Tn...>
: std::conditional_t<detail::is_host_mdspan_v<T1>, is_host_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_host_mdspan_v = is_host_mdspan<Tn...>::value;

/**
* @brief Interface to implement an owning multi-dimensional array
Expand All @@ -152,22 +200,45 @@ class array_interface {
auto view() const noexcept { return static_cast<Base*>(this)->view(); }
};

namespace detail {
template <typename T, typename = void>
struct __is_array_interface : std::false_type {
struct is_array_interface : std::false_type {
};

template <typename T>
struct __is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>>
struct is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>>
: std::bool_constant<is_mdspan_v<decltype(std::declval<T>().view())>> {
};

template <typename T>
using is_array_interface_t = is_array_interface<std::remove_const_t<T>>;

/**
* @\brief Boolean to determine if template type T is raft::array_interface or derived type
* or any type that has a member function `view()` that returns either
* raft::host_mdspan or raft::device_mdspan
*/
template <typename T>
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value;
inline constexpr bool is_array_interface_v = is_array_interface<std::remove_const_t<T>>::value;
} // namespace detail

template <typename...>
struct is_array_interface : std::true_type {
};
template <typename T1>
struct is_array_interface<T1> : detail::is_array_interface_t<T1> {
};
template <typename T1, typename... Tn>
struct is_array_interface<T1, Tn...> : std::conditional_t<detail::is_array_interface_v<T1>,
is_array_interface<Tn...>,
std::false_type> {
};
/**
* @\brief Boolean to determine if variadic template types Tn are raft::array_interface
* or derived type or any type that has a member function `view()` that returns either
* raft::host_mdspan or raft::device_mdspan
*/
template <typename... Tn>
inline constexpr bool is_array_interface_v = is_array_interface<Tn...>::value;

/**
* @brief Modified from the c++ mdarray proposal
Expand Down
4 changes: 4 additions & 0 deletions cpp/test/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ void test_template_asserts()
"device_matrix_view type not a host_mdspan");
static_assert(is_host_mdspan_v<host_matrix_view<float>>,
"host_matrix_view type is a host_mdspan");

// checking variadics
static_assert(!is_mdspan_v<three_d_mdspan, std::vector<int>>, "variadics mdspans");
static_assert(is_mdspan_v<three_d_mdspan, d_mdspan>, "variadics not mdspans");
}

TEST(MDSpan, TemplateAsserts) { test_template_asserts(); }
Expand Down