diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 70c78a81fb..0ab882e7a0 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -46,29 +46,53 @@ using layout_c_contiguous = detail::stdex::layout_right; */ using layout_f_contiguous = detail::stdex::layout_left; +template > +using mdspan = detail::stdex::mdspan; + +namespace detail { /** * @\brief Template checks and helpers to determine if type T is an std::mdspan * or a derived type */ template -void __takes_an_mdspan_ptr( - detail::stdex::mdspan*); +void __takes_an_mdspan_ptr(mdspan*); template -struct __is_mdspan : std::false_type { +struct is_mdspan : std::false_type { }; - template -struct __is_mdspan()))>> +struct is_mdspan()))>> : std::true_type { }; template -using __is_mdspan_t = __is_mdspan>; +using is_mdspan_t = is_mdspan>; template -inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; +inline constexpr bool is_mdspan_v = is_mdspan_t::value; +} // namespace detail + +template +struct is_mdspan : std::true_type { +}; +template +struct is_mdspan : detail::is_mdspan_t { +}; +template +struct is_mdspan + : std::conditional_t, is_mdspan, std::false_type> { +}; + +/** + * @\brief Boolean to determine if variadic template types Tn are either + * raft::host_mdspan/raft::device_mdspan or their derived types + */ +template +inline constexpr bool is_mdspan_v = is_mdspan::value; /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. @@ -77,7 +101,7 @@ template > -using device_mdspan = detail::stdex:: +using device_mdspan = mdspan>; /** @@ -88,47 +112,71 @@ template > using host_mdspan = - detail::stdex::mdspan>; + mdspan>; +namespace detail { template -struct __is_device_mdspan : std::false_type { +struct is_device_mdspan : std::false_type { }; - template -struct __is_device_mdspan : std::bool_constant { +struct is_device_mdspan : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type */ template -inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; +inline constexpr bool is_device_mdspan_v = is_device_mdspan>::value; template -struct __is_host_mdspan : std::false_type { +struct is_host_mdspan : std::false_type { }; - template -struct __is_host_mdspan : T::accessor_type::is_host_type { +struct is_host_mdspan : T::accessor_type::is_host_type { }; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type */ template -inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; +inline constexpr bool is_host_mdspan_v = is_host_mdspan>::value; +} // namespace detail + +template +struct is_device_mdspan : std::true_type { +}; +template +struct is_device_mdspan : detail::is_device_mdspan> { +}; +template +struct is_device_mdspan + : std::conditional_t, is_device_mdspan, std::false_type> { +}; /** - * @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan - * or their derived types - * This is structured such that it will short-circuit if the type is not std::mdspan - * or a derived type, and otherwise it will check whether it is a raft::device_mdspan - * or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type + * @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a + * derived type */ -template -inline constexpr bool is_mdspan_v = - std::conjunction_v<__is_mdspan_t, - std::disjunction<__is_device_mdspan, __is_host_mdspan>>; +template +inline constexpr bool is_device_mdspan_v = is_device_mdspan::value; + +template +struct is_host_mdspan : std::true_type { +}; +template +struct is_host_mdspan : detail::is_host_mdspan> { +}; +template +struct is_host_mdspan + : std::conditional_t, is_host_mdspan, std::false_type> { +}; + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a + * derived type + */ +template +inline constexpr bool is_host_mdspan_v = is_host_mdspan::value; /** * @brief Interface to implement an owning multi-dimensional array @@ -152,22 +200,45 @@ class array_interface { auto view() const noexcept { return static_cast(this)->view(); } }; +namespace detail { template -struct __is_array_interface : std::false_type { +struct is_array_interface : std::false_type { }; - template -struct __is_array_interface().view())>> +struct is_array_interface().view())>> : std::bool_constant().view())>> { }; +template +using is_array_interface_t = is_array_interface>; + /** * @\brief Boolean to determine if template type T is raft::array_interface or derived type * or any type that has a member function `view()` that returns either * raft::host_mdspan or raft::device_mdspan */ template -inline constexpr bool is_array_interface_v = __is_array_interface>::value; +inline constexpr bool is_array_interface_v = is_array_interface>::value; +} // namespace detail + +template +struct is_array_interface : std::true_type { +}; +template +struct is_array_interface : detail::is_array_interface_t { +}; +template +struct is_array_interface : std::conditional_t, + is_array_interface, + std::false_type> { +}; +/** + * @\brief Boolean to determine if variadic template types Tn are raft::array_interface + * or derived type or any type that has a member function `view()` that returns either + * raft::host_mdspan or raft::device_mdspan + */ +template +inline constexpr bool is_array_interface_v = is_array_interface::value; /** * @brief Modified from the c++ mdarray proposal diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 15388a5cef..7e7812b7a6 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -64,6 +64,10 @@ void test_template_asserts() "device_matrix_view type not a host_mdspan"); static_assert(is_host_mdspan_v>, "host_matrix_view type is a host_mdspan"); + + // checking variadics + static_assert(!is_mdspan_v>, "variadics mdspans"); + static_assert(is_mdspan_v, "variadics not mdspans"); } TEST(MDSpan, TemplateAsserts) { test_template_asserts(); }