From edc26042b570165b659398b7de8d34a563c30ef4 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 3 May 2022 10:42:38 -0700 Subject: [PATCH 1/3] variadic template types for mdspan/mdarray --- cpp/include/raft/core/mdarray.hpp | 125 ++++++++++++++++++++++++------ cpp/test/mdspan_utils.cu | 4 + 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index ab6a04587a..f50301c038 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -46,6 +46,7 @@ using layout_c_contiguous = detail::stdex::layout_right; */ using layout_f_contiguous = detail::stdex::layout_left; +namespace detail { /** * @\brief Template checks and helpers to determine if type T is an std::mdspan * or a derived type @@ -58,7 +59,6 @@ void __takes_an_mdspan_ptr( template struct __is_mdspan : std::false_type { }; - template struct __is_mdspan()))>> : std::true_type { @@ -70,6 +70,21 @@ using __is_mdspan_t = __is_mdspan>; template inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; +template +struct is_mdspan : std::true_type { +}; +template +struct is_mdspan : __is_mdspan_t { +}; +template +struct is_mdspan + : std::conditional_t<__is_mdspan_v, is_mdspan, std::false_type> { +}; + +template +inline constexpr bool is_mdspan_v = is_mdspan::value; +} // namespace detail + /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. */ @@ -90,45 +105,88 @@ template >; +namespace detail { template -struct __is_device_mdspan : std::false_type { +struct is_device_mdspan : std::false_type { }; - template -struct __is_device_mdspan : std::bool_constant { +struct is_device_mdspan : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type */ template -inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; +inline constexpr bool is_device_mdspan_v = is_device_mdspan>::value; template -struct __is_host_mdspan : std::false_type { +struct is_host_mdspan : std::false_type { }; - template -struct __is_host_mdspan : T::accessor_type::is_host_type { +struct is_host_mdspan : T::accessor_type::is_host_type { }; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type */ template -inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; +inline constexpr bool is_host_mdspan_v = is_host_mdspan>::value; +} // namespace detail + +template +struct is_device_mdspan : std::true_type { +}; +template +struct is_device_mdspan : detail::is_device_mdspan> { +}; +template +struct is_device_mdspan + : std::conditional_t, is_device_mdspan, std::false_type> { +}; /** - * @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan - * or their derived types - * This is structured such that it will short-circuit if the type is not std::mdspan - * or a derived type, and otherwise it will check whether it is a raft::device_mdspan - * or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type + * @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a + * derived type */ -template -inline constexpr bool is_mdspan_v = - std::conjunction_v<__is_mdspan_t, - std::disjunction<__is_device_mdspan, __is_host_mdspan>>; +template +inline constexpr bool is_device_mdspan_v = is_device_mdspan::value; + +template +struct is_host_mdspan : std::true_type { +}; +template +struct is_host_mdspan : detail::is_host_mdspan> { +}; +template +struct is_host_mdspan + : std::conditional_t, is_host_mdspan, std::false_type> { +}; + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a + * derived type + */ +template +inline constexpr bool is_host_mdspan_v = is_host_mdspan::value; + +template +struct is_mdspan : std::true_type { +}; +template +struct is_mdspan : std::disjunction, is_host_mdspan> { +}; +template +struct is_mdspan + : std::conditional_t, is_host_mdspan>, + is_mdspan, + std::false_type> { +}; +/** + * @\brief Boolean to determine if variadic template types Tn are either + * raft::host_mdspan/raft::device_mdspan or their derived types + */ +template +inline constexpr bool is_mdspan_v = is_mdspan::value; /** * @brief Interface to implement an owning multi-dimensional array @@ -152,22 +210,45 @@ class array_interface { auto view() const noexcept { return static_cast(this)->view(); } }; +namespace detail { template -struct __is_array_interface : std::false_type { +struct is_array_interface : std::false_type { }; - template -struct __is_array_interface().view())>> +struct is_array_interface().view())>> : std::bool_constant().view())>> { }; +template +using is_array_interface_t = is_array_interface>; + /** * @\brief Boolean to determine if template type T is raft::array_interface or derived type * or any type that has a member function `view()` that returns either * raft::host_mdspan or raft::device_mdspan */ template -inline constexpr bool is_array_interface_v = __is_array_interface>::value; +inline constexpr bool is_array_interface_v = is_array_interface>::value; +} // namespace detail + +template +struct is_array_interface : std::true_type { +}; +template +struct is_array_interface : detail::is_array_interface_t { +}; +template +struct is_array_interface : std::conditional_t, + is_array_interface, + std::false_type> { +}; +/** + * @\brief Boolean to determine if variadic template types Tn are raft::array_interface + * or derived type or any type that has a member function `view()` that returns either + * raft::host_mdspan or raft::device_mdspan + */ +template +inline constexpr bool is_array_interface_v = is_array_interface::value; /** * @brief Modified from the c++ mdarray proposal diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 15388a5cef..7e7812b7a6 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -64,6 +64,10 @@ void test_template_asserts() "device_matrix_view type not a host_mdspan"); static_assert(is_host_mdspan_v>, "host_matrix_view type is a host_mdspan"); + + // checking variadics + static_assert(!is_mdspan_v>, "variadics mdspans"); + static_assert(is_mdspan_v, "variadics not mdspans"); } TEST(MDSpan, TemplateAsserts) { test_template_asserts(); } From 6d9620a4bed7d394755d0e9777945503d079b0e7 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 16 Jun 2022 14:39:46 -0700 Subject: [PATCH 2/3] introducting raft::mdspan alias --- cpp/include/raft/core/mdarray.hpp | 44 ++++++++++++------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index f50301c038..6e5a5426a5 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -46,6 +46,13 @@ using layout_c_contiguous = detail::stdex::layout_right; */ using layout_f_contiguous = detail::stdex::layout_left; +template > +using mdspan = detail::stdex:: + mdspan + namespace detail { /** * @\brief Template checks and helpers to determine if type T is an std::mdspan @@ -54,7 +61,7 @@ namespace detail { template void __takes_an_mdspan_ptr( - detail::stdex::mdspan*); + mdspan*); template struct __is_mdspan : std::false_type { @@ -69,21 +76,25 @@ using __is_mdspan_t = __is_mdspan>; template inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; +} // namespace detail template struct is_mdspan : std::true_type { }; template -struct is_mdspan : __is_mdspan_t { +struct is_mdspan : detail::__is_mdspan_t { }; template struct is_mdspan - : std::conditional_t<__is_mdspan_v, is_mdspan, std::false_type> { + : std::conditional_t, is_mdspan, std::false_type> { }; +/** + * @\brief Boolean to determine if variadic template types Tn are either + * raft::host_mdspan/raft::device_mdspan or their derived types + */ template inline constexpr bool is_mdspan_v = is_mdspan::value; -} // namespace detail /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. @@ -92,8 +103,7 @@ template > -using device_mdspan = detail::stdex:: - mdspan>; +using device_mdspan = mdspan>; /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. @@ -102,8 +112,7 @@ template > -using host_mdspan = - detail::stdex::mdspan>; +using host_mdspan = mdspan>; namespace detail { template @@ -169,25 +178,6 @@ struct is_host_mdspan template inline constexpr bool is_host_mdspan_v = is_host_mdspan::value; -template -struct is_mdspan : std::true_type { -}; -template -struct is_mdspan : std::disjunction, is_host_mdspan> { -}; -template -struct is_mdspan - : std::conditional_t, is_host_mdspan>, - is_mdspan, - std::false_type> { -}; -/** - * @\brief Boolean to determine if variadic template types Tn are either - * raft::host_mdspan/raft::device_mdspan or their derived types - */ -template -inline constexpr bool is_mdspan_v = is_mdspan::value; - /** * @brief Interface to implement an owning multi-dimensional array * From fb6e952d9005a6da81855ffebc62dfe98c61dbeb Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 16 Jun 2022 14:51:07 -0700 Subject: [PATCH 3/3] fixing compile failures --- cpp/include/raft/core/mdarray.hpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index a4b1f46c2c..0ab882e7a0 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -49,9 +49,8 @@ using layout_f_contiguous = detail::stdex::layout_left; template > -using mdspan = detail::stdex:: - mdspan + typename AccessorPolicy = detail::stdex::default_accessor> +using mdspan = detail::stdex::mdspan; namespace detail { /** @@ -60,33 +59,32 @@ namespace detail { */ template -void __takes_an_mdspan_ptr( - mdspan*); +void __takes_an_mdspan_ptr(mdspan*); template -struct __is_mdspan : std::false_type { +struct is_mdspan : std::false_type { }; template -struct __is_mdspan()))>> +struct is_mdspan()))>> : std::true_type { }; template -using __is_mdspan_t = __is_mdspan>; +using is_mdspan_t = is_mdspan>; template -inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; -} // namespace detail +inline constexpr bool is_mdspan_v = is_mdspan_t::value; +} // namespace detail template struct is_mdspan : std::true_type { }; template -struct is_mdspan : detail::__is_mdspan_t { +struct is_mdspan : detail::is_mdspan_t { }; template struct is_mdspan - : std::conditional_t, is_mdspan, std::false_type> { + : std::conditional_t, is_mdspan, std::false_type> { }; /** @@ -103,7 +101,8 @@ template > -using device_mdspan = mdspan>; +using device_mdspan = + mdspan>; /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. @@ -112,7 +111,8 @@ template > -using host_mdspan = mdspan>; +using host_mdspan = + mdspan>; namespace detail { template