From 119d75bc70e4798c98d7b874845d5b7054eef5a0 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 29 Mar 2022 18:01:42 -0700 Subject: [PATCH 01/21] working through is_device_mdspan --- cpp/include/raft/distance/distance.cuh | 4 ++-- cpp/include/raft/mdarray.hpp | 23 +++++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/mdspan_utils.cu | 17 +++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 cpp/test/mdspan_utils.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index e13cfd94f8..0ac2faf917 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -253,7 +253,7 @@ void distance(raft::handle_t const& handle, RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous."); RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); - auto is_rowmajor = std::is_same::value; + constexpr auto is_rowmajor = std::is_same_v; distance(x.data(), y.data(), @@ -435,7 +435,7 @@ void pairwise_distance(raft::handle_t const& handle, RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); RAFT_EXPECTS(dist.is_contiguous(), "Output must be contiguous."); - bool rowmajor = x.stride(0) == 0; + constexpr auto rowmajor = std::is_same_v; rmm::device_uvector workspace(0, handle.get_stream()); diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index f92a0e5e59..aade230a08 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -47,6 +47,17 @@ template >; +template +struct is_device_mdspan + : std::conditional_t>, + std::true_type, + std::false_type> { +}; + /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. */ @@ -185,6 +196,18 @@ class mdarray { return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); } + /** + * @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. + */ + operator view_type() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } + /** + * @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. + */ + operator const_view_type() const noexcept + { + return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); + } + [[nodiscard]] constexpr auto size() const noexcept -> index_type { return this->view().size(); } [[nodiscard]] auto data() noexcept -> pointer { return c_.data(); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index c03e5d6bcd..bc706ccc84 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -75,6 +75,7 @@ add_executable(test_raft test/matrix/columnSort.cu test/matrix/linewise_op.cu test/mdarray.cu + test/mdspan_utils.cu test/mst.cu test/random/make_blobs.cu test/random/make_regression.cu diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu new file mode 100644 index 0000000000..e0f5668010 --- /dev/null +++ b/cpp/test/mdspan_utils.cu @@ -0,0 +1,17 @@ +#include +#include + +namespace raft { + +void test_template_asserts() { + using three_d_extents = stdex::extents; + using three_d_mdspan = device_mdspan; + + static_assert(is_device_mdspan::value, "Not a device_mdspan"); +} + +TEST(MDspan, TemplateAsserts) { + test_template_asserts(); +} + +} // namespace raft \ No newline at end of file From b6799f5f66447fe068baf0e79cf4249aef3bd745 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 30 Mar 2022 09:12:45 -0700 Subject: [PATCH 02/21] specializations for is_mdspan --- cpp/include/raft/mdarray.hpp | 33 +++++++++++++++++++++++---------- cpp/test/mdspan_utils.cu | 10 +++++++++- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 4322eb31da..b7299a47cf 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -37,6 +37,18 @@ using layout_c_contiguous = detail::stdex::layout_right; */ using layout_f_contiguous = detail::stdex::layout_left; +template +struct __is_mdspan : std::false_type {}; + +template +struct __is_mdspan> : std::true_type {}; + +template +inline constexpr bool is_mdspan_v = __is_mdspan>::value; + +template +using is_mdspan_t = std::enable_if_t, U>; + /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. */ @@ -47,16 +59,17 @@ template >; -template -struct is_device_mdspan - : std::conditional_t>, - std::true_type, - std::false_type> { -}; +// template +// struct is_mdspan +// : std::conditional_t>, +// std::true_type, +// std::false_type> { +// }; + /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index e0f5668010..63729410bc 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -3,11 +3,19 @@ namespace raft { +namespace stdex = std::experimental; + void test_template_asserts() { + // Testing 3d device mdspan to be an mdspan using three_d_extents = stdex::extents; using three_d_mdspan = device_mdspan; - static_assert(is_device_mdspan::value, "Not a device_mdspan"); + static_assert(is_mdspan_v, "3d mdspan type not an mdspan"); + static_assert(is_mdspan_v>, "device_matrix_view type not an mdspan"); + static_assert(is_mdspan_v>, "const host_vector_view type not an mdspan"); + static_assert(is_mdspan_v>, "const host_scalar_view type not an mdspan"); + + // static_assert(is_mdspan::value, "Not an mdspan"); } TEST(MDspan, TemplateAsserts) { From 3451b076065ad808c7fb91f73b045258cf078e6e Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 30 Mar 2022 12:17:15 -0700 Subject: [PATCH 03/21] flatten and tests --- cpp/include/raft/detail/mdarray.hpp | 16 ++++++ cpp/include/raft/mdarray.hpp | 86 +++++++++++++++++++---------- cpp/test/mdspan_utils.cu | 82 +++++++++++++++++++++++---- 3 files changed, 144 insertions(+), 40 deletions(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 624c7a4d07..784f2eb378 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -230,9 +230,25 @@ struct accessor_mixin : public AccessorPolicy { template using host_accessor = accessor_mixin; +template +struct __is_host_accessor : std::false_type { +}; + +template +struct __is_host_accessor> : std::true_type { +}; + template using device_accessor = accessor_mixin; +template +struct __is_device_accessor : std::false_type { +}; + +template +struct __is_device_accessor> : std::true_type { +}; + namespace stdex = std::experimental; using vector_extent = stdex::extents; diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index b7299a47cf..665aef7861 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -38,10 +38,12 @@ using layout_c_contiguous = detail::stdex::layout_right; using layout_f_contiguous = detail::stdex::layout_left; template -struct __is_mdspan : std::false_type {}; +struct __is_mdspan : std::false_type { +}; template -struct __is_mdspan> : std::true_type {}; +struct __is_mdspan> : std::true_type { +}; template inline constexpr bool is_mdspan_v = __is_mdspan>::value; @@ -59,17 +61,12 @@ template >; -// template -// struct is_mdspan -// : std::conditional_t>, -// std::true_type, -// std::false_type> { -// }; +template +inline constexpr bool is_device_mdspan_v = + is_mdspan_v&& detail::__is_device_accessor::value; +template +using is_device_mdspan_t = std::enable_if_t, U>; /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. @@ -81,6 +78,13 @@ template >; +template +inline constexpr bool is_host_mdspan_v = + is_mdspan_v&& detail::__is_host_accessor::value; + +template +using is_host_mdspan_t = std::enable_if_t, U>; + /** * @brief Modified from the c++ mdarray proposal * @@ -376,15 +380,15 @@ using device_scalar = device_mdarray; * @brief Shorthand for 1-dim host mdarray. * @tparam ElementType the data type of the vector elements */ -template -using host_vector = host_mdarray; +template +using host_vector = host_mdarray; /** * @brief Shorthand for 1-dim device mdarray. * @tparam ElementType the data type of the vector elements */ -template -using device_vector = device_mdarray; +template +using device_vector = device_mdarray; /** * @brief Shorthand for c-contiguous host matrix. @@ -420,15 +424,15 @@ using device_scalar_view = device_mdspan; * @brief Shorthand for 1-dim host mdspan. * @tparam ElementType the data type of the vector elements */ -template -using host_vector_view = host_mdspan; +template +using host_vector_view = host_mdspan; /** * @brief Shorthand for 1-dim device mdspan. * @tparam ElementType the data type of the vector elements */ -template -using device_vector_view = device_mdspan; +template +using device_vector_view = device_mdspan; /** * @brief Shorthand for c-contiguous host matrix view. @@ -514,11 +518,11 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) * @param[in] n number of elements in pointer * @return raft::host_vector_view */ -template +template auto make_host_vector_view(ElementType* ptr, size_t n) { detail::vector_extent extents{n}; - return host_vector_view{ptr, extents}; + return host_vector_view{ptr, extents}; } /** @@ -528,11 +532,11 @@ auto make_host_vector_view(ElementType* ptr, size_t n) * @param[in] n number of elements in pointer * @return raft::device_vector_view */ -template +template auto make_device_vector_view(ElementType* ptr, size_t n) { detail::vector_extent extents{n}; - return device_vector_view{ptr, extents}; + return device_vector_view{ptr, extents}; } /** @@ -646,13 +650,13 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) * @param[in] n number of elements in vector * @return raft::host_vector */ -template +template auto make_host_vector(size_t n) { detail::vector_extent extents{n}; using policy_t = typename host_vector::container_policy_type; policy_t policy; - return host_vector{extents, policy}; + return host_vector{extents, policy}; } /** @@ -662,13 +666,13 @@ auto make_host_vector(size_t n) * @param[in] stream the cuda stream for ordering events * @return raft::device_vector */ -template +template auto make_device_vector(size_t n, rmm::cuda_stream_view stream) { detail::vector_extent extents{n}; using policy_t = typename device_vector::container_policy_type; policy_t policy{stream}; - return device_vector{extents, policy}; + return device_vector{extents, policy}; } /** @@ -683,4 +687,30 @@ auto make_device_vector(raft::handle_t const& handle, size_t n) { return make_device_vector(n, handle.get_stream()); } + +template > +auto flatten(host_mdspan_type h_mds) +{ + size_t flat_dimension = 1; + for (size_t i = 0; i < h_mds.extents().rank(); ++i) { + flat_dimension *= h_mds.extent(i); + } + + return make_host_vector_view(h_mds.data(), + flat_dimension); +} + +template +constexpr auto flatten(host_vector_view h_vv) +{ + return h_vv; +} + +template +constexpr auto flatten(host_scalar_view h_sv) +{ + return h_sv; +} + } // namespace raft diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 63729410bc..2520a04f0b 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -5,21 +5,79 @@ namespace raft { namespace stdex = std::experimental; -void test_template_asserts() { - // Testing 3d device mdspan to be an mdspan - using three_d_extents = stdex::extents; - using three_d_mdspan = device_mdspan; +void test_template_asserts() +{ + // Testing 3d device mdspan to be an mdspan + using three_d_extents = stdex::extents; + using three_d_mdspan = device_mdspan; + + // Checking if types are mdspan, supposed to fail for std::vector + static_assert(is_mdspan_v, "3d mdspan type not an mdspan"); + static_assert(is_mdspan_v>, "device_matrix_view type not an mdspan"); + static_assert(is_mdspan_v>, + "const host_vector_view type not an mdspan"); + static_assert(is_mdspan_v>, + "const host_scalar_view type not an mdspan"); + static_assert(!is_mdspan_v>, "std::vector is an mdspan"); - static_assert(is_mdspan_v, "3d mdspan type not an mdspan"); - static_assert(is_mdspan_v>, "device_matrix_view type not an mdspan"); - static_assert(is_mdspan_v>, "const host_vector_view type not an mdspan"); - static_assert(is_mdspan_v>, "const host_scalar_view type not an mdspan"); + // Checking if types are device_mdspan + static_assert(is_device_mdspan_v>, + "device_matrix_view type not a device_mdspan"); + static_assert(!is_device_mdspan_v>, + "host_matrix_view type is a device_mdspan"); - // static_assert(is_mdspan::value, "Not an mdspan"); + // Checking if types are host_mdspan + static_assert(!is_host_mdspan_v>, + "device_matrix_view type not a host_mdspan"); + static_assert(is_host_mdspan_v>, + "host_matrix_view type is a host_mdspan"); } -TEST(MDspan, TemplateAsserts) { - test_template_asserts(); +TEST(MDSpan, TemplateAsserts) { test_template_asserts(); } + +void test_host_flatten() +{ + // flatten 3d host matrix + { + using three_d_extents = stdex::extents; + using three_d_mdarray = host_mdarray; + + three_d_extents extents{3, 3, 3}; + three_d_mdarray::container_policy_type policy; + three_d_mdarray mda{extents, policy}; + + auto flat_view = flatten(mda.view()); + + static_assert(std::is_same_v, + "layouts not the same"); + + ASSERT_EQ(flat_view.extents().rank(), 1); + ASSERT_EQ(flat_view.extent(0), 27); + } + + // flatten host vector + { + auto hv = make_host_vector(27); + auto flat_view = flatten(hv.view()); + + static_assert(std::is_same_v, "types not the same"); + + ASSERT_EQ(hv.extents().rank(), flat_view.extents().rank()); + ASSERT_EQ(hv.extent(0), flat_view.extent(0)); + } + + // flatten host scalar + { + auto hs = make_host_scalar(27); + auto flat_view = flatten(hs.view()); + + static_assert(std::is_same_v, "types not the same"); + + ASSERT_EQ(flat_view.extent(0), 1); + } } -} // namespace raft \ No newline at end of file +TEST(MDArray, HostFlatten) { test_host_flatten(); } + +} // namespace raft \ No newline at end of file From 1eb756519a6f30a2b67d687d4951bf53c128b4a2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 1 Apr 2022 11:25:57 -0700 Subject: [PATCH 04/21] checking for derived mdspan --- cpp/include/raft/mdarray.hpp | 16 +++++++++++++++- cpp/test/mdspan_utils.cu | 10 ++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 665aef7861..2f86329231 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -45,8 +45,22 @@ template struct __is_mdspan> : std::true_type { }; +template +void __takes_an_mdspan_ptr( + detail::stdex::mdspan*); + +template +struct __is_derived_mdspan : std::false_type { +}; + +template +struct __is_derived_mdspan()))>> + : std::true_type { +}; + template -inline constexpr bool is_mdspan_v = __is_mdspan>::value; +inline constexpr bool is_mdspan_v = + __is_mdspan>::value || __is_derived_mdspan>::value; template using is_mdspan_t = std::enable_if_t, U>; diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 2520a04f0b..0aa83d11a2 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -5,11 +5,20 @@ namespace raft { namespace stdex = std::experimental; +template > +struct derived_mdspan + : public detail::stdex::mdspan { +}; + void test_template_asserts() { // Testing 3d device mdspan to be an mdspan using three_d_extents = stdex::extents; using three_d_mdspan = device_mdspan; + using d_mdspan = derived_mdspan; // Checking if types are mdspan, supposed to fail for std::vector static_assert(is_mdspan_v, "3d mdspan type not an mdspan"); @@ -19,6 +28,7 @@ void test_template_asserts() static_assert(is_mdspan_v>, "const host_scalar_view type not an mdspan"); static_assert(!is_mdspan_v>, "std::vector is an mdspan"); + static_assert(is_mdspan_v, "Derived mdspan type is not mdspan"); // Checking if types are device_mdspan static_assert(is_device_mdspan_v>, From fb2de54770b883fcc3c4701e24b4ef3980df56a8 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 6 Apr 2022 16:09:48 -0700 Subject: [PATCH 05/21] working flatten for mdarray and mdspan --- cpp/include/raft/detail/mdarray.hpp | 6 ++ cpp/include/raft/mdarray.hpp | 89 ++++++++++++++++++++++++----- cpp/test/mdspan_utils.cu | 2 +- 3 files changed, 81 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 784f2eb378..38f77042c0 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -238,6 +238,9 @@ template struct __is_host_accessor> : std::true_type { }; +template +inline constexpr bool is_host_accessor_v = __is_host_accessor::value; + template using device_accessor = accessor_mixin; @@ -249,6 +252,9 @@ template struct __is_device_accessor> : std::true_type { }; +template +inline constexpr bool is_device_accessor_v = __is_device_accessor::value; + namespace stdex = std::experimental; using vector_extent = stdex::extents; diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 2f86329231..837c9ed8c2 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -59,8 +59,8 @@ struct __is_derived_mdspan -inline constexpr bool is_mdspan_v = - __is_mdspan>::value || __is_derived_mdspan>::value; +inline constexpr bool is_mdspan_v = std::disjunction_v<__is_mdspan>, + __is_derived_mdspan>>; template using is_mdspan_t = std::enable_if_t, U>; @@ -75,9 +75,17 @@ template >; +template +struct __is_device_mdspan : std::false_type { +}; + +template +struct __is_device_mdspan + : std::bool_constant> { +}; + template -inline constexpr bool is_device_mdspan_v = - is_mdspan_v&& detail::__is_device_accessor::value; +inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; template using is_device_mdspan_t = std::enable_if_t, U>; @@ -92,9 +100,17 @@ template >; +template +struct __is_host_mdspan : std::false_type { +}; + +template +struct __is_host_mdspan + : std::bool_constant> { +}; + template -inline constexpr bool is_host_mdspan_v = - is_mdspan_v&& detail::__is_host_accessor::value; +inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; template using is_host_mdspan_t = std::enable_if_t, U>; @@ -130,9 +146,6 @@ using is_host_mdspan_t = std::enable_if_t, U>; */ template class mdarray { - static_assert(!std::is_const::value, - "Element type for container must not be const."); - public: using extents_type = Extents; using layout_type = LayoutPolicy; @@ -169,7 +182,6 @@ class mdarray { using view_type = view_type_impl; using const_view_type = view_type_impl; - public: constexpr mdarray() noexcept(std::is_nothrow_default_constructible_v) : cp_{rmm::cuda_stream_default}, c_{cp_.create(0)} {}; constexpr mdarray(mdarray const&) noexcept(std::is_nothrow_copy_constructible_v) = @@ -218,11 +230,11 @@ class mdarray { /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } + view_type view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - auto view() const noexcept + const_view_type view() const noexcept { return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); } @@ -348,6 +360,20 @@ class mdarray { container_type c_; }; +template +struct __is_mdarray : std::false_type { +}; + +template +struct __is_mdarray> : std::true_type { +}; + +template +inline constexpr bool is_mdarray_v = __is_mdarray>::value; + +template +using is_mdarray_t = std::enable_if_t, U>; + /** * @brief mdarray with host container policy * @tparam ElementType the data type of the elements @@ -702,7 +728,8 @@ auto make_device_vector(raft::handle_t const& handle, size_t n) return make_device_vector(n, handle.get_stream()); } -template > +template >* = nullptr> auto flatten(host_mdspan_type h_mds) { size_t flat_dimension = 1; @@ -715,16 +742,48 @@ auto flatten(host_mdspan_type h_mds) flat_dimension); } -template -constexpr auto flatten(host_vector_view h_vv) +template >* = nullptr> +auto flatten(device_mdspan_type d_mds) +{ + size_t flat_dimension = 1; + for (size_t i = 0; i < d_mds.extents().rank(); ++i) { + flat_dimension *= d_mds.extent(i); + } + + return make_device_vector_view(d_mds.data(), + flat_dimension); +} + +template >* = nullptr> +auto flatten(const mdarray_type& mda) +{ + return flatten(mda.view()); +} + +template +constexpr auto flatten(host_vector_view h_vv) { return h_vv; } +template +auto flatten(const host_vector& h_v) +{ + return flatten(h_v.view()); +} + template constexpr auto flatten(host_scalar_view h_sv) { return h_sv; } +template +auto flatten(const host_scalar& h_s) +{ + return flatten(h_s.view()); +} + } // namespace raft diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 0aa83d11a2..7a88dbc256 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -56,7 +56,7 @@ void test_host_flatten() three_d_mdarray::container_policy_type policy; three_d_mdarray mda{extents, policy}; - auto flat_view = flatten(mda.view()); + auto flat_view = flatten(mda); static_assert(std::is_same_v, From 65399366aa27f6a7d92fd7d1a1c833a6cb8f9cd3 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 6 Apr 2022 16:12:16 -0700 Subject: [PATCH 06/21] add copyright --- cpp/test/mdspan_utils.cu | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 7a88dbc256..3795fdcae3 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include From eb400c6e1ec798126cdc88aee4e4b005c56ad35f Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 18 Apr 2022 17:25:15 -0700 Subject: [PATCH 07/21] working through reshape --- cpp/include/raft/detail/mdarray.hpp | 32 +++++----- cpp/include/raft/mdarray.hpp | 95 +++++++++++++++++++++++------ cpp/test/mdspan_utils.cu | 15 +++-- 3 files changed, 102 insertions(+), 40 deletions(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 38f77042c0..e8aaa26d5d 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -230,30 +230,30 @@ struct accessor_mixin : public AccessorPolicy { template using host_accessor = accessor_mixin; -template -struct __is_host_accessor : std::false_type { -}; +// template +// struct __is_host_accessor : std::false_type { +// }; -template -struct __is_host_accessor> : std::true_type { -}; +// template +// struct __is_host_accessor> : std::true_type { +// }; -template -inline constexpr bool is_host_accessor_v = __is_host_accessor::value; +// template +// inline constexpr bool is_host_accessor_v = __is_host_accessor::value; template using device_accessor = accessor_mixin; -template -struct __is_device_accessor : std::false_type { -}; +// template +// struct __is_device_accessor : std::false_type { +// }; -template -struct __is_device_accessor> : std::true_type { -}; +// template +// struct __is_device_accessor> : std::true_type { +// }; -template -inline constexpr bool is_device_accessor_v = __is_device_accessor::value; +// template +// inline constexpr bool is_device_accessor_v = __is_device_accessor::value; namespace stdex = std::experimental; diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 837c9ed8c2..520831a3c7 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -80,8 +80,7 @@ struct __is_device_mdspan : std::false_type { }; template -struct __is_device_mdspan - : std::bool_constant> { +struct __is_device_mdspan : std::bool_constant { }; template @@ -105,8 +104,7 @@ struct __is_host_mdspan : std::false_type { }; template -struct __is_host_mdspan - : std::bool_constant> { +struct __is_host_mdspan : T::accessor_type::is_host_type { }; template @@ -115,6 +113,10 @@ inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::va template using is_host_mdspan_t = std::enable_if_t, U>; +// template +// inline constexpr bool is_host_or_device_mdspan_v = std::conjunction_v<__is_device_mdspan, +// __is_host_mdspan>; + /** * @brief Modified from the c++ mdarray proposal * @@ -146,6 +148,9 @@ using is_host_mdspan_t = std::enable_if_t, U>; */ template class mdarray { + static_assert(!std::is_const::value, + "Element type for container must not be const."); + public: using extents_type = Extents; using layout_type = LayoutPolicy; @@ -230,11 +235,11 @@ class mdarray { /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - view_type view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } + auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - const_view_type view() const noexcept + auto view() const noexcept { return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); } @@ -722,38 +727,30 @@ auto make_device_vector(size_t n, rmm::cuda_stream_view stream) * @param[in] n number of elements in vector * @return raft::device_vector */ -template +template auto make_device_vector(raft::handle_t const& handle, size_t n) { - return make_device_vector(n, handle.get_stream()); + return make_device_vector(n, handle.get_stream()); } template >* = nullptr> auto flatten(host_mdspan_type h_mds) { - size_t flat_dimension = 1; - for (size_t i = 0; i < h_mds.extents().rank(); ++i) { - flat_dimension *= h_mds.extent(i); - } + RAFT_EXPECTS(h_mds.is_contiguous(), "Input must be contiguous."); return make_host_vector_view(h_mds.data(), - flat_dimension); + typename host_mdspan_type::layout_type>(h_mds.data(), h_mds.size()); } template >* = nullptr> auto flatten(device_mdspan_type d_mds) { - size_t flat_dimension = 1; - for (size_t i = 0; i < d_mds.extents().rank(); ++i) { - flat_dimension *= d_mds.extent(i); - } - + RAFT_EXPECTS(d_mds.is_contiguous(), "Input must be contiguous."); return make_device_vector_view(d_mds.data(), - flat_dimension); + d_mds.size()); } template >* = nullptr> @@ -786,4 +783,62 @@ auto flatten(const host_scalar& h_s) return flatten(h_s.view()); } +template >* = nullptr> +auto reshape(host_mdspan_type h_mds, std::experimental::extents new_shape) +{ + RAFT_EXPECTS(h_mds.is_contiguous(), "Input must be contiguous."); + + if (new_shape == h_mds.extents()) { + return h_mds; + } else if (new_shape.rank(1)) { + auto new_size = new_shape.extent(0); + RAFT_EXPECTS(new_size <= h_mds.size(), + "Cannot reshape array of size %ul into %ul", + h_mds.size(), + new_size()); + + if (new_size == 1) { + return make_host_scalar_view(h_mds.data()); + } else { + return make_host_vector_view(h_mds.data(), new_size); + } + } else if (new_shape.rank(2)) { + auto new_size = new_shape.extent(0) * new_shape.extent(1); + RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch"); + + return make_host_matrix_view( + h_mds.data(), new_shape.extent(0), new_shape.extent(1)); + } else { + size_t new_size = 1; + for (size_t i = 0; i < new_shape.rank(); ++i) { + new_size *= new_shape.extent(i); + } + RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch"); + + return detail::stdex::mdspan(h_mds.data(), new_shape); + } +} + +// template >* = nullptr> +// auto reshape(device_mdspan_type d_mds) +// { +// RAFT_EXPECTS(d_mds.is_contiguous(), "Input must be contiguous."); +// return make_device_vector_view(d_mds.data(), +// d_mds.size()); +// } + +// template >* = nullptr> +// auto reshape(const mdarray_type& mda) +// { +// return reshape(mda.view()); +// } + } // namespace raft diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 3795fdcae3..237a2e7222 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -25,8 +25,8 @@ template > -struct derived_mdspan - : public detail::stdex::mdspan { +struct derived_device_mdspan + : public device_mdspan { }; void test_template_asserts() @@ -34,7 +34,13 @@ void test_template_asserts() // Testing 3d device mdspan to be an mdspan using three_d_extents = stdex::extents; using three_d_mdspan = device_mdspan; - using d_mdspan = derived_mdspan; + using d_mdspan = derived_device_mdspan; + + static_assert(std::is_same_v, device_mdspan>, + "not same"); + static_assert(std::is_same_v, + device_mdspan>>, + "not same"); // Checking if types are mdspan, supposed to fail for std::vector static_assert(is_mdspan_v, "3d mdspan type not an mdspan"); @@ -44,13 +50,14 @@ void test_template_asserts() static_assert(is_mdspan_v>, "const host_scalar_view type not an mdspan"); static_assert(!is_mdspan_v>, "std::vector is an mdspan"); - static_assert(is_mdspan_v, "Derived mdspan type is not mdspan"); + static_assert(is_mdspan_v, "Derived device mdspan type is not mdspan"); // Checking if types are device_mdspan static_assert(is_device_mdspan_v>, "device_matrix_view type not a device_mdspan"); static_assert(!is_device_mdspan_v>, "host_matrix_view type is a device_mdspan"); + static_assert(is_device_mdspan_v, "Derived device mdspan type is not device_mdspan"); // Checking if types are host_mdspan static_assert(!is_host_mdspan_v>, From 27c5c19965dd48e70d262697aacb9a4297652dc0 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 19 Apr 2022 08:53:30 -0700 Subject: [PATCH 08/21] finishing up host/device flatten with tests --- cpp/include/raft/detail/mdarray.hpp | 3 +- cpp/include/raft/mdarray.hpp | 24 +++++++++++++++ cpp/test/mdspan_utils.cu | 46 +++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index e8aaa26d5d..90c4dbb3cc 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -22,6 +22,7 @@ */ #pragma once #include +#include #include // dynamic_extent #include #include @@ -58,7 +59,7 @@ class device_reference { auto operator=(T const& other) -> device_reference& { auto* raw = ptr_.get(); - update_device(raw, &other, 1, stream_); + raft::update_device(raw, &other, 1, stream_); return *this; } }; diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 520831a3c7..f19f4e355f 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -783,6 +783,30 @@ auto flatten(const host_scalar& h_s) return flatten(h_s.view()); } +template +constexpr auto flatten(device_vector_view d_vv) +{ + return d_vv; +} + +template +auto flatten(const device_vector& d_v) +{ + return flatten(d_v.view()); +} + +template +constexpr auto flatten(device_scalar_view d_sv) +{ + return d_sv; +} + +template +auto flatten(const device_scalar& d_s) +{ + return flatten(d_s.view()); +} + template >* = nullptr> diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 237a2e7222..671f46d853 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -113,4 +113,50 @@ void test_host_flatten() TEST(MDArray, HostFlatten) { test_host_flatten(); } +void test_device_flatten() +{ + raft::handle_t handle{}; + // flatten 3d host matrix + { + using three_d_extents = stdex::extents; + using three_d_mdarray = device_mdarray; + + three_d_extents extents{3, 3, 3}; + three_d_mdarray::container_policy_type policy{handle.get_stream()}; + three_d_mdarray mda{extents, policy}; + + auto flat_view = flatten(mda); + + static_assert(std::is_same_v, + "layouts not the same"); + + ASSERT_EQ(flat_view.extents().rank(), 1); + ASSERT_EQ(flat_view.extent(0), 27); + } + + // flatten host vector + { + auto dv = make_device_vector(27, handle.get_stream()); + auto flat_view = flatten(dv.view()); + + static_assert(std::is_same_v, "types not the same"); + + ASSERT_EQ(dv.extents().rank(), flat_view.extents().rank()); + ASSERT_EQ(dv.extent(0), flat_view.extent(0)); + } + + // flatten host scalar + { + auto ds = make_device_scalar(27, handle.get_stream()); + auto flat_view = flatten(ds.view()); + + static_assert(std::is_same_v, "types not the same"); + + ASSERT_EQ(flat_view.extent(0), 1); + } +} + +TEST(MDArray, DeviceFlatten) { test_device_flatten(); } + } // namespace raft \ No newline at end of file From ff282f3d16f2796db9dd690f79a1175879662c68 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 19 Apr 2022 10:44:43 -0700 Subject: [PATCH 09/21] working host reshape with tests --- cpp/include/raft/mdarray.hpp | 82 ++++++++++++++++++++---------------- cpp/test/mdspan_utils.cu | 67 +++++++++++++++++++++++++++-- 2 files changed, 108 insertions(+), 41 deletions(-) diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index f19f4e355f..a3a3c00062 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -27,6 +27,10 @@ #include namespace raft { + +template +using extents = std::experimental::extents; + /** * @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. */ @@ -814,39 +818,41 @@ auto reshape(host_mdspan_type h_mds, std::experimental::extents new_ { RAFT_EXPECTS(h_mds.is_contiguous(), "Input must be contiguous."); - if (new_shape == h_mds.extents()) { - return h_mds; - } else if (new_shape.rank(1)) { - auto new_size = new_shape.extent(0); - RAFT_EXPECTS(new_size <= h_mds.size(), - "Cannot reshape array of size %ul into %ul", - h_mds.size(), - new_size()); - - if (new_size == 1) { - return make_host_scalar_view(h_mds.data()); - } else { - return make_host_vector_view(h_mds.data(), new_size); - } - } else if (new_shape.rank(2)) { - auto new_size = new_shape.extent(0) * new_shape.extent(1); - RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch"); - - return make_host_matrix_view( - h_mds.data(), new_shape.extent(0), new_shape.extent(1)); - } else { - size_t new_size = 1; - for (size_t i = 0; i < new_shape.rank(); ++i) { - new_size *= new_shape.extent(i); - } - RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch"); - - return detail::stdex::mdspan(h_mds.data(), new_shape); + // if (new_shape == h_mds.extents()) { + // return h_mds; + // } else if (new_shape.rank() == 1) { + // auto new_size = new_shape.extent(0); + // RAFT_EXPECTS(new_size <= h_mds.size(), + // "Cannot reshape array of size %ul into %ul", + // h_mds.size(), + // new_size()); + + // if (new_size == 1) { + // return make_host_scalar_view(h_mds.data()); + // } else { + // return make_host_vector_view(h_mds.data(), + // new_size); + // } + // } else if (new_shape.rank() == 2) { + // auto new_size = new_shape.extent(0) * new_shape.extent(1); + // RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch"); + + // return make_host_matrix_view( + // h_mds.data(), new_shape.extent(0), new_shape.extent(1)); + // } else { + size_t new_size = 1; + for (size_t i = 0; i < new_shape.rank(); ++i) { + new_size *= new_shape.extent(i); } + RAFT_EXPECTS(new_size <= h_mds.size(), "Cannot reshape array with size mismatch"); + + return detail::stdex::mdspan(h_mds.data(), new_shape); + // } } // template new_ // d_mds.size()); // } -// template >* = nullptr> -// auto reshape(const mdarray_type& mda) -// { -// return reshape(mda.view()); -// } +template >* = nullptr> +auto reshape(const mdarray_type& mda, extents new_shape) +{ + return reshape(mda.view(), new_shape); +} } // namespace raft diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 671f46d853..a2cd16d390 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -70,7 +70,7 @@ TEST(MDSpan, TemplateAsserts) { test_template_asserts(); } void test_host_flatten() { - // flatten 3d host matrix + // flatten 3d host mdspan { using three_d_extents = stdex::extents; using three_d_mdarray = host_mdarray; @@ -116,7 +116,7 @@ TEST(MDArray, HostFlatten) { test_host_flatten(); } void test_device_flatten() { raft::handle_t handle{}; - // flatten 3d host matrix + // flatten 3d device mdspan { using three_d_extents = stdex::extents; using three_d_mdarray = device_mdarray; @@ -135,7 +135,7 @@ void test_device_flatten() ASSERT_EQ(flat_view.extent(0), 27); } - // flatten host vector + // flatten device vector { auto dv = make_device_vector(27, handle.get_stream()); auto flat_view = flatten(dv.view()); @@ -146,7 +146,7 @@ void test_device_flatten() ASSERT_EQ(dv.extent(0), flat_view.extent(0)); } - // flatten host scalar + // flatten device scalar { auto ds = make_device_scalar(27, handle.get_stream()); auto flat_view = flatten(ds.view()); @@ -159,4 +159,63 @@ void test_device_flatten() TEST(MDArray, DeviceFlatten) { test_device_flatten(); } +void test_host_reshape() +{ + // reshape 3d host matrix to vector + { + using three_d_extents = stdex::extents; + using three_d_mdarray = host_mdarray; + + three_d_extents extents{3, 3, 3}; + three_d_mdarray::container_policy_type policy; + three_d_mdarray mda{extents, policy}; + + auto flat_view = reshape(mda, raft::extents{27}); + // this confirms aliasing works as intended + static_assert(std::is_same_v>, + "types not the same"); + + ASSERT_EQ(flat_view.extents().rank(), 1); + ASSERT_EQ(flat_view.extent(0), 27); + } + + // reshape 4d host matrix to 2d + { + using four_d_extents = + stdex::extents; + using four_d_mdarray = host_mdarray; + + four_d_extents extents{2, 2, 2, 2}; + four_d_mdarray::container_policy_type policy; + four_d_mdarray mda{extents, policy}; + + auto matrix = reshape(mda, raft::extents{4, 4}); + // this confirms aliasing works as intended + static_assert(std::is_same_v>, + "types not the same"); + + ASSERT_EQ(matrix.extents().rank(), 2); + ASSERT_EQ(matrix.extent(0), 4); + ASSERT_EQ(matrix.extent(1), 4); + } + + // shrink host vector + { + auto hv = make_host_vector(27); + auto shrunk_vector = reshape(hv.view(), raft::extents(20)); + + static_assert(std::is_same_v, + "types not the same"); + + ASSERT_EQ(hv.extents().rank(), shrunk_vector.extents().rank()); + ASSERT_EQ(shrunk_vector.extent(0), 20); + } +} + +TEST(MDArray, HostReshape) { test_host_reshape(); } + } // namespace raft \ No newline at end of file From b8aec4ca04d24d894f9f977bfbb36cde0eda5988 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 19 Apr 2022 11:17:04 -0700 Subject: [PATCH 10/21] working reshape for device arrays --- cpp/include/raft/mdarray.hpp | 65 ++++++------------------------------ cpp/test/mdspan_utils.cu | 29 +++++----------- 2 files changed, 19 insertions(+), 75 deletions(-) diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index a3a3c00062..bf3323598c 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -66,9 +66,6 @@ template inline constexpr bool is_mdspan_v = std::disjunction_v<__is_mdspan>, __is_derived_mdspan>>; -template -using is_mdspan_t = std::enable_if_t, U>; - /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. */ @@ -90,9 +87,6 @@ struct __is_device_mdspan : std::bool_constant inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; -template -using is_device_mdspan_t = std::enable_if_t, U>; - /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. */ @@ -114,12 +108,8 @@ struct __is_host_mdspan : T::accessor_type::is_host_type { template inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; -template -using is_host_mdspan_t = std::enable_if_t, U>; - -// template -// inline constexpr bool is_host_or_device_mdspan_v = std::conjunction_v<__is_device_mdspan, -// __is_host_mdspan>; +template +inline constexpr bool is_host_or_device_mdspan_v = is_device_mdspan_v or is_host_mdspan_v; /** * @brief Modified from the c++ mdarray proposal @@ -811,60 +801,25 @@ auto flatten(const device_scalar& d_s) return flatten(d_s.view()); } -template >* = nullptr> -auto reshape(host_mdspan_type h_mds, std::experimental::extents new_shape) + std::enable_if_t>* = nullptr> +auto reshape(mdspan_type mds, extents new_shape) { - RAFT_EXPECTS(h_mds.is_contiguous(), "Input must be contiguous."); + RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); - // if (new_shape == h_mds.extents()) { - // return h_mds; - // } else if (new_shape.rank() == 1) { - // auto new_size = new_shape.extent(0); - // RAFT_EXPECTS(new_size <= h_mds.size(), - // "Cannot reshape array of size %ul into %ul", - // h_mds.size(), - // new_size()); - - // if (new_size == 1) { - // return make_host_scalar_view(h_mds.data()); - // } else { - // return make_host_vector_view(h_mds.data(), - // new_size); - // } - // } else if (new_shape.rank() == 2) { - // auto new_size = new_shape.extent(0) * new_shape.extent(1); - // RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch"); - - // return make_host_matrix_view( - // h_mds.data(), new_shape.extent(0), new_shape.extent(1)); - // } else { size_t new_size = 1; for (size_t i = 0; i < new_shape.rank(); ++i) { new_size *= new_shape.extent(i); } - RAFT_EXPECTS(new_size <= h_mds.size(), "Cannot reshape array with size mismatch"); + RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); - return detail::stdex::mdspan(h_mds.data(), new_shape); - // } + typename mdspan_type::layout_type, + typename mdspan_type::accessor_type>(mds.data(), new_shape); } -// template >* = nullptr> -// auto reshape(device_mdspan_type d_mds) -// { -// RAFT_EXPECTS(d_mds.is_contiguous(), "Input must be contiguous."); -// return make_device_vector_view(d_mds.data(), -// d_mds.size()); -// } - template >* = nullptr> diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index a2cd16d390..2861a1c004 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -159,9 +159,9 @@ void test_device_flatten() TEST(MDArray, DeviceFlatten) { test_device_flatten(); } -void test_host_reshape() +void test_reshape() { - // reshape 3d host matrix to vector + // reshape 3d host array to vector { using three_d_extents = stdex::extents; using three_d_mdarray = host_mdarray; @@ -181,41 +181,30 @@ void test_host_reshape() ASSERT_EQ(flat_view.extent(0), 27); } - // reshape 4d host matrix to 2d + // reshape 4d device array to 2d { + raft::handle_t handle{}; using four_d_extents = stdex::extents; - using four_d_mdarray = host_mdarray; + using four_d_mdarray = device_mdarray; four_d_extents extents{2, 2, 2, 2}; - four_d_mdarray::container_policy_type policy; + four_d_mdarray::container_policy_type policy{handle.get_stream()}; four_d_mdarray mda{extents, policy}; auto matrix = reshape(mda, raft::extents{4, 4}); // this confirms aliasing works as intended static_assert(std::is_same_v>, + device_matrix_view>, "types not the same"); ASSERT_EQ(matrix.extents().rank(), 2); ASSERT_EQ(matrix.extent(0), 4); ASSERT_EQ(matrix.extent(1), 4); } - - // shrink host vector - { - auto hv = make_host_vector(27); - auto shrunk_vector = reshape(hv.view(), raft::extents(20)); - - static_assert(std::is_same_v, - "types not the same"); - - ASSERT_EQ(hv.extents().rank(), shrunk_vector.extents().rank()); - ASSERT_EQ(shrunk_vector.extent(0), 20); - } } -TEST(MDArray, HostReshape) { test_host_reshape(); } +TEST(MDArray, Reshape) { test_reshape(); } } // namespace raft \ No newline at end of file From ceb7dd47a0f354559c2cd02b24a2775f1a7c54c0 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 19 Apr 2022 15:04:22 -0700 Subject: [PATCH 11/21] adding docstrings --- cpp/include/raft/detail/mdarray.hpp | 22 ------------ cpp/include/raft/mdarray.hpp | 55 +++++++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 90c4dbb3cc..03a56839a7 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -231,31 +231,9 @@ struct accessor_mixin : public AccessorPolicy { template using host_accessor = accessor_mixin; -// template -// struct __is_host_accessor : std::false_type { -// }; - -// template -// struct __is_host_accessor> : std::true_type { -// }; - -// template -// inline constexpr bool is_host_accessor_v = __is_host_accessor::value; - template using device_accessor = accessor_mixin; -// template -// struct __is_device_accessor : std::false_type { -// }; - -// template -// struct __is_device_accessor> : std::true_type { -// }; - -// template -// inline constexpr bool is_device_accessor_v = __is_device_accessor::value; - namespace stdex = std::experimental; using vector_extent = stdex::extents; diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index bf3323598c..79675d70d6 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -28,6 +28,9 @@ namespace raft { +/** + * @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan + */ template using extents = std::experimental::extents; @@ -62,6 +65,9 @@ struct __is_derived_mdspan inline constexpr bool is_mdspan_v = std::disjunction_v<__is_mdspan>, __is_derived_mdspan>>; @@ -84,6 +90,9 @@ template struct __is_device_mdspan : std::bool_constant { }; +/** + * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type + */ template inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; @@ -105,9 +114,16 @@ template struct __is_host_mdspan : T::accessor_type::is_host_type { }; +/** + * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type + */ template inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; +/** + * @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan + * or their derived types + */ template inline constexpr bool is_host_or_device_mdspan_v = is_device_mdspan_v or is_host_mdspan_v; @@ -367,12 +383,12 @@ template struct __is_mdarray> : std::true_type { }; +/** + * @\brief Boolean to determine if template type T is raft::mdarray + */ template inline constexpr bool is_mdarray_v = __is_mdarray>::value; -template -using is_mdarray_t = std::enable_if_t, U>; - /** * @brief mdarray with host container policy * @tparam ElementType the data type of the elements @@ -727,6 +743,13 @@ auto make_device_vector(raft::handle_t const& handle, size_t n) return make_device_vector(n, handle.get_stream()); } +/** + * @brief + * + * @tparam host_mdspan_type Expected type raft::host_mdspan + * @param h_mds raft::host_mdspan object + * @return raft::host_mdspan + */ template >* = nullptr> auto flatten(host_mdspan_type h_mds) @@ -747,6 +770,14 @@ auto flatten(device_mdspan_type d_mds) d_mds.size()); } +/** + * @brief + * + * @tparam mdarray_type Expected type raft::mdarray + * @param mda raft::mdarray object + * @return Either raft::host_mdspan or raft::device_mdspan depending on the underlying + * ContainerType + */ template >* = nullptr> auto flatten(const mdarray_type& mda) { @@ -801,6 +832,15 @@ auto flatten(const device_scalar& d_s) return flatten(d_s.view()); } +/** + * @brief + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @tparam Extents raft::extents for dimensions + * @param mds raft::host_mdspan or raft::device_mdspan object + * @param new_shape Desired new shape of the input + * @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy + */ template >* = nullptr> @@ -820,6 +860,15 @@ auto reshape(mdspan_type mds, extents new_shape) typename mdspan_type::accessor_type>(mds.data(), new_shape); } +/** + * @brief + * + * @tparam mdarray_type Expected type raft::mdarray + * @tparam Extents raft::extents for dimensions + * @param mda raft::mdarray object + * @param new_shape Desired new shape of the input + * @return raft::host_mdspan or raft::device_mdspan, depending on ContainerPolicy + */ template >* = nullptr> From fa88e2313efd0544ebc3ff17f161bdbaa59d2833 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Wed, 20 Apr 2022 08:23:21 -0700 Subject: [PATCH 12/21] Apply suggestions from code review Co-authored-by: Jiaming Yuan --- cpp/include/raft/mdarray.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 79675d70d6..64ed714634 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -257,13 +257,13 @@ class mdarray { /** * @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. */ - operator view_type() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } + operator view_type() noexcept { return view(); } /** * @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. */ operator const_view_type() const noexcept { - return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); + return view(); } [[nodiscard]] constexpr auto size() const noexcept -> index_type { return this->view().size(); } From e3f52cea2a5f627928f4105f48c53108b5554956 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 21 Apr 2022 07:22:57 -0700 Subject: [PATCH 13/21] static extents tests, some variable renaming --- cpp/include/raft/mdarray.hpp | 50 +++++++++++++++++++++--------------- cpp/test/mdspan_utils.cu | 33 ++++++++++++++++-------- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 79675d70d6..2cdc9b13e6 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -44,6 +44,10 @@ using layout_c_contiguous = detail::stdex::layout_right; */ using layout_f_contiguous = detail::stdex::layout_left; +/** + * @\brief Template checks and helpers to determine if type T is an std::mdspan + * or a derived type + */ template struct __is_mdspan : std::false_type { }; @@ -65,12 +69,12 @@ struct __is_derived_mdspan -inline constexpr bool is_mdspan_v = std::disjunction_v<__is_mdspan>, - __is_derived_mdspan>>; +using __is_mdspan_t = std::disjunction<__is_mdspan>, + __is_derived_mdspan>>; + +template +inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. @@ -94,7 +98,7 @@ struct __is_device_mdspan : std::bool_constant -inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; +inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; /** * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. @@ -118,14 +122,19 @@ struct __is_host_mdspan : T::accessor_type::is_host_type { * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type */ template -inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; +inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan * or their derived types + * This is structured such that it will short-circuit if the type is not std::mdspan + * or a derived type, and otherwise it will check whether it is a raft::device_mdspan + * or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type */ template -inline constexpr bool is_host_or_device_mdspan_v = is_device_mdspan_v or is_host_mdspan_v; +inline constexpr bool is_mdspan_v = + std::conjunction_v<__is_mdspan_t, + std::disjunction<__is_device_mdspan, __is_host_mdspan>>; /** * @brief Modified from the c++ mdarray proposal @@ -744,11 +753,11 @@ auto make_device_vector(raft::handle_t const& handle, size_t n) } /** - * @brief - * + * @brief Flatten raft::host_mdspan into a 1-dim array view + * * @tparam host_mdspan_type Expected type raft::host_mdspan * @param h_mds raft::host_mdspan object - * @return raft::host_mdspan + * @return raft::host_mdspan */ template >* = nullptr> @@ -771,12 +780,12 @@ auto flatten(device_mdspan_type d_mds) } /** - * @brief - * + * @brief Flatten raft::mdarray into a 1-dim array view + * * @tparam mdarray_type Expected type raft::mdarray * @param mda raft::mdarray object * @return Either raft::host_mdspan or raft::device_mdspan depending on the underlying - * ContainerType + * ContainerPolicy */ template >* = nullptr> auto flatten(const mdarray_type& mda) @@ -833,8 +842,8 @@ auto flatten(const device_scalar& d_s) } /** - * @brief - * + * @brief Reshape raft::host_mdspan or raft::device_mdspan + * * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan * @tparam Extents raft::extents for dimensions * @param mds raft::host_mdspan or raft::device_mdspan object @@ -843,7 +852,7 @@ auto flatten(const device_scalar& d_s) */ template >* = nullptr> + std::enable_if_t>* = nullptr> auto reshape(mdspan_type mds, extents new_shape) { RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); @@ -861,13 +870,14 @@ auto reshape(mdspan_type mds, extents new_shape) } /** - * @brief - * + * @brief Reshape raft::mdarray + * * @tparam mdarray_type Expected type raft::mdarray * @tparam Extents raft::extents for dimensions * @param mda raft::mdarray object * @param new_shape Desired new shape of the input - * @return raft::host_mdspan or raft::device_mdspan, depending on ContainerPolicy + * @return raft::host_mdspan or raft::device_mdspan, depending on the underlying + * ContainerPolicy */ template ; + using three_d_extents = extents; using three_d_mdspan = device_mdspan; using d_mdspan = derived_device_mdspan; static_assert(std::is_same_v, device_mdspan>, "not same"); static_assert(std::is_same_v, - device_mdspan>>, + device_mdspan>>, "not same"); // Checking if types are mdspan, supposed to fail for std::vector @@ -72,7 +72,7 @@ void test_host_flatten() { // flatten 3d host mdspan { - using three_d_extents = stdex::extents; + using three_d_extents = extents; using three_d_mdarray = host_mdarray; three_d_extents extents{3, 3, 3}; @@ -86,7 +86,7 @@ void test_host_flatten() "layouts not the same"); ASSERT_EQ(flat_view.extents().rank(), 1); - ASSERT_EQ(flat_view.extent(0), 27); + ASSERT_EQ(flat_view.size(), mda.size()); } // flatten host vector @@ -118,7 +118,7 @@ void test_device_flatten() raft::handle_t handle{}; // flatten 3d device mdspan { - using three_d_extents = stdex::extents; + using three_d_extents = extents; using three_d_mdarray = device_mdarray; three_d_extents extents{3, 3, 3}; @@ -132,7 +132,7 @@ void test_device_flatten() "layouts not the same"); ASSERT_EQ(flat_view.extents().rank(), 1); - ASSERT_EQ(flat_view.extent(0), 27); + ASSERT_EQ(flat_view.size(), mda.size()); } // flatten device vector @@ -163,7 +163,7 @@ void test_reshape() { // reshape 3d host array to vector { - using three_d_extents = stdex::extents; + using three_d_extents = extents; using three_d_mdarray = host_mdarray; three_d_extents extents{3, 3, 3}; @@ -178,14 +178,13 @@ void test_reshape() "types not the same"); ASSERT_EQ(flat_view.extents().rank(), 1); - ASSERT_EQ(flat_view.extent(0), 27); + ASSERT_EQ(flat_view.size(), mda.size()); } // reshape 4d device array to 2d { raft::handle_t handle{}; - using four_d_extents = - stdex::extents; + using four_d_extents = extents; using four_d_mdarray = device_mdarray; four_d_extents extents{2, 2, 2, 2}; @@ -203,6 +202,20 @@ void test_reshape() ASSERT_EQ(matrix.extent(0), 4); ASSERT_EQ(matrix.extent(1), 4); } + + // reshape 2d host matrix with static extents to vector + { + using two_d_extents = extents<5, 5>; + using two_d_mdarray = host_mdarray; + + two_d_mdarray::container_policy_type policy; + two_d_mdarray mda{two_d_extents{}, policy}; + + auto vector = reshape(mda, raft::extents<25>{}); + + ASSERT_EQ(vector.extents().rank(), 1); + ASSERT_EQ(vector.size(), mda.size()); + } } TEST(MDArray, Reshape) { test_reshape(); } From 4fc4a5d3f7774a0c3602c72d4a46efe1e30125f1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 21 Apr 2022 10:05:59 -0700 Subject: [PATCH 14/21] working array_interface --- cpp/include/raft/core/mdarray.hpp | 137 +++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 39 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 92f4ce0fe6..a12c023108 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -47,30 +47,22 @@ using layout_f_contiguous = detail::stdex::layout_left; * @\brief Template checks and helpers to determine if type T is an std::mdspan * or a derived type */ -template -struct __is_mdspan : std::false_type { -}; - -template -struct __is_mdspan> : std::true_type { -}; template void __takes_an_mdspan_ptr( detail::stdex::mdspan*); template -struct __is_derived_mdspan : std::false_type { +struct __is_mdspan : std::false_type { }; template -struct __is_derived_mdspan()))>> +struct __is_mdspan()))>> : std::true_type { }; template -using __is_mdspan_t = std::disjunction<__is_mdspan>, - __is_derived_mdspan>>; +using __is_mdspan_t = __is_mdspan>; template inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; @@ -135,6 +127,86 @@ inline constexpr bool is_mdspan_v = std::conjunction_v<__is_mdspan_t, std::disjunction<__is_device_mdspan, __is_host_mdspan>>; +/** + * @brief Interface to implement an owning multi-dimensional array + * + * raft::array_interace is an interface to owning container types for mdspan. + * Check implementation of raft::mdarray which implements raft::array_interface. + * This interface provides virtual method `view()` which is to be overridden by + * the implementing class. `view()` must return an object of type raft::host_mdspan + * or raft::device_mdspan or any types derived from the them. + */ +template +class array_interface { + static_assert(!std::is_const::value, + "Element type for container must not be const."); + + public: + using extents_type = Extents; + using layout_type = LayoutPolicy; + using mapping_type = typename layout_type::template mapping; + using element_type = ElementType; + + using value_type = std::remove_cv_t; + using index_type = std::size_t; + using difference_type = std::ptrdiff_t; + // Naming: ref impl: container_policy_type, proposal: container_policy + using container_policy_type = ContainerPolicy; + using container_type = typename container_policy_type::container_type; + + using pointer = typename container_policy_type::pointer; + using const_pointer = typename container_policy_type::const_pointer; + using reference = typename container_policy_type::reference; + using const_reference = typename container_policy_type::const_reference; + + private: + template , + typename container_policy_type::const_accessor_policy, + typename container_policy_type::accessor_policy>> + using view_type_impl = + std::conditional_t, + device_mdspan>; + + public: + /** + * \brief the mdspan type returned by view method. + */ + using view_type = view_type_impl; + using const_view_type = view_type_impl; + + /** + * @brief Get a mdspan that can be passed down to CUDA kernels. + */ + virtual view_type view() noexcept = 0; + /** + * @brief Get a mdspan that can be passed down to CUDA kernels. + */ + virtual const_view_type view() const noexcept = 0; +}; + +template +void __takes_an_array_interface_ptr( + array_interface*); + +template +struct __is_array_interface : std::false_type { +}; + +template +struct __is_array_interface< + T, + std::void_t()))>> : std::true_type { +}; + +/** + * @\brief Boolean to determine if template type T is raft::array_interface or derived type + */ +template +inline constexpr bool is_array_interface_v = __is_array_interface>::value; + /** * @brief Modified from the c++ mdarray proposal * @@ -165,7 +237,7 @@ inline constexpr bool is_mdspan_v = * removed. */ template -class mdarray { +class mdarray : public array_interface { static_assert(!std::is_const::value, "Element type for container must not be const."); @@ -254,11 +326,11 @@ class mdarray { /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } + view_type view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - auto view() const noexcept + const_view_type view() const noexcept { return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); } @@ -381,20 +453,6 @@ class mdarray { container_type c_; }; -template -struct __is_mdarray : std::false_type { -}; - -template -struct __is_mdarray> : std::true_type { -}; - -/** - * @\brief Boolean to determine if template type T is raft::mdarray - */ -template -inline constexpr bool is_mdarray_v = __is_mdarray>::value; - /** * @brief mdarray with host container policy * @tparam ElementType the data type of the elements @@ -777,15 +835,16 @@ auto flatten(device_mdspan_type d_mds) } /** - * @brief Flatten raft::mdarray into a 1-dim array view + * @brief Flatten object implementing raft::array_interface into a 1-dim array view * - * @tparam mdarray_type Expected type raft::mdarray - * @param mda raft::mdarray object + * @tparam mdarray_type Expected type implementing raft::array_interface + * @param mda raft::array_interace implementing object * @return Either raft::host_mdspan or raft::device_mdspan depending on the underlying * ContainerPolicy */ -template >* = nullptr> -auto flatten(const mdarray_type& mda) +template >* = nullptr> +auto flatten(const array_interface_type& mda) { return flatten(mda.view()); } @@ -867,19 +926,19 @@ auto reshape(mdspan_type mds, extents new_shape) } /** - * @brief Reshape raft::mdarray + * @brief Reshape object implementing raft::array_interface * - * @tparam mdarray_type Expected type raft::mdarray + * @tparam mdarray_type Expected type implementing raft::array_interface * @tparam Extents raft::extents for dimensions - * @param mda raft::mdarray object + * @param mda raft::array_interace implementing object * @param new_shape Desired new shape of the input * @return raft::host_mdspan or raft::device_mdspan, depending on the underlying * ContainerPolicy */ -template >* = nullptr> -auto reshape(const mdarray_type& mda, extents new_shape) + std::enable_if_t>* = nullptr> +auto reshape(const array_interface_type& mda, extents new_shape) { return reshape(mda.view(), new_shape); } From f87d02f336e6954564a191b25ac0e8d0e07b8aed Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 21 Apr 2022 10:10:41 -0700 Subject: [PATCH 15/21] small fix to docstring --- cpp/include/raft/core/mdarray.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index a12c023108..632093eb4e 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -837,7 +837,7 @@ auto flatten(device_mdspan_type d_mds) /** * @brief Flatten object implementing raft::array_interface into a 1-dim array view * - * @tparam mdarray_type Expected type implementing raft::array_interface + * @tparam array_interface_type Expected type implementing raft::array_interface * @param mda raft::array_interace implementing object * @return Either raft::host_mdspan or raft::device_mdspan depending on the underlying * ContainerPolicy @@ -928,7 +928,7 @@ auto reshape(mdspan_type mds, extents new_shape) /** * @brief Reshape object implementing raft::array_interface * - * @tparam mdarray_type Expected type implementing raft::array_interface + * @tparam array_interface_type Expected type implementing raft::array_interface * @tparam Extents raft::extents for dimensions * @param mda raft::array_interace implementing object * @param new_shape Desired new shape of the input From 30d878c4fb6a9004479cca9af175a5dac632ee89 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 21 Apr 2022 10:16:21 -0700 Subject: [PATCH 16/21] removing unneeded aliases from array_interface --- cpp/include/raft/core/mdarray.hpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 632093eb4e..7b3ca95e73 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -144,21 +144,12 @@ class array_interface { public: using extents_type = Extents; using layout_type = LayoutPolicy; - using mapping_type = typename layout_type::template mapping; using element_type = ElementType; - using value_type = std::remove_cv_t; - using index_type = std::size_t; - using difference_type = std::ptrdiff_t; // Naming: ref impl: container_policy_type, proposal: container_policy using container_policy_type = ContainerPolicy; using container_type = typename container_policy_type::container_type; - using pointer = typename container_policy_type::pointer; - using const_pointer = typename container_policy_type::const_pointer; - using reference = typename container_policy_type::reference; - using const_reference = typename container_policy_type::const_reference; - private: template Date: Wed, 27 Apr 2022 09:12:19 -0700 Subject: [PATCH 17/21] using array_interface with CRTP --- cpp/include/raft/core/mdarray.hpp | 59 ++++++++----------------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 7b3ca95e73..e8bba6dc10 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -131,69 +131,37 @@ inline constexpr bool is_mdspan_v = * @brief Interface to implement an owning multi-dimensional array * * raft::array_interace is an interface to owning container types for mdspan. - * Check implementation of raft::mdarray which implements raft::array_interface. - * This interface provides virtual method `view()` which is to be overridden by + * Check implementation of raft::mdarray which implements raft::array_interface + * using Curiously Recurring Template Pattern. + * This interface calls into method `view()` whose implementation is provided by * the implementing class. `view()` must return an object of type raft::host_mdspan * or raft::device_mdspan or any types derived from the them. */ -template +template class array_interface { - static_assert(!std::is_const::value, - "Element type for container must not be const."); - - public: - using extents_type = Extents; - using layout_type = LayoutPolicy; - using element_type = ElementType; - - // Naming: ref impl: container_policy_type, proposal: container_policy - using container_policy_type = ContainerPolicy; - using container_type = typename container_policy_type::container_type; - - private: - template , - typename container_policy_type::const_accessor_policy, - typename container_policy_type::accessor_policy>> - using view_type_impl = - std::conditional_t, - device_mdspan>; - - public: - /** - * \brief the mdspan type returned by view method. - */ - using view_type = view_type_impl; - using const_view_type = view_type_impl; - /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - virtual view_type view() noexcept = 0; + auto view() noexcept { return static_cast(this)->view(); } /** * @brief Get a mdspan that can be passed down to CUDA kernels. */ - virtual const_view_type view() const noexcept = 0; + auto view() const noexcept { return static_cast(this)->view(); } }; -template -void __takes_an_array_interface_ptr( - array_interface*); - template struct __is_array_interface : std::false_type { }; template -struct __is_array_interface< - T, - std::void_t()))>> : std::true_type { +struct __is_array_interface().view())>> + : std::bool_constant().view())>> { }; /** * @\brief Boolean to determine if template type T is raft::array_interface or derived type + * or any type that has a member function `view()` that returns either + * raft::host_mdspan or raft::device_mdspan */ template inline constexpr bool is_array_interface_v = __is_array_interface>::value; @@ -228,7 +196,8 @@ inline constexpr bool is_array_interface_v = __is_array_interface -class mdarray : public array_interface { +class mdarray + : public array_interface> { static_assert(!std::is_const::value, "Element type for container must not be const."); @@ -317,11 +286,11 @@ class mdarray : public array_interface Date: Wed, 27 Apr 2022 09:15:22 -0700 Subject: [PATCH 18/21] remove implict operator converters of mdspan from mdarray --- cpp/include/raft/core/mdarray.hpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index e8bba6dc10..6562ee5115 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -295,15 +295,6 @@ class mdarray return const_view_type(c_.data(), map_, cp_.make_accessor_policy()); } - /** - * @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. - */ - operator view_type() noexcept { return view(); } - /** - * @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. - */ - operator const_view_type() const noexcept { return view(); } - [[nodiscard]] constexpr auto size() const noexcept -> index_type { return this->view().size(); } [[nodiscard]] auto data() noexcept -> pointer { return c_.data(); } From a407aee6a98957edf7a7eac2aa13d3e32dbbbf12 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 27 Apr 2022 09:33:10 -0700 Subject: [PATCH 19/21] remove flatten overloads --- cpp/include/raft/core/mdarray.hpp | 84 ++++++------------------------- cpp/test/mdspan_utils.cu | 4 -- 2 files changed, 15 insertions(+), 73 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 6562ee5115..652da24e19 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -759,30 +759,24 @@ auto make_device_vector(raft::handle_t const& handle, size_t n) } /** - * @brief Flatten raft::host_mdspan into a 1-dim array view + * @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view * - * @tparam host_mdspan_type Expected type raft::host_mdspan - * @param h_mds raft::host_mdspan object - * @return raft::host_mdspan + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @param mds raft::host_mdspan or raft::device_mdspan object + * @return raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on AccessoryPolicy */ -template >* = nullptr> -auto flatten(host_mdspan_type h_mds) +template >* = nullptr> +auto flatten(mdspan_type mds) { - RAFT_EXPECTS(h_mds.is_contiguous(), "Input must be contiguous."); + RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); - return make_host_vector_view(h_mds.data(), h_mds.size()); -} + detail::vector_extent ext{mds.size()}; -template >* = nullptr> -auto flatten(device_mdspan_type d_mds) -{ - RAFT_EXPECTS(d_mds.is_contiguous(), "Input must be contiguous."); - return make_device_vector_view(d_mds.data(), - d_mds.size()); + return detail::stdex::mdspan(mds.data(), ext); } /** @@ -790,8 +784,8 @@ auto flatten(device_mdspan_type d_mds) * * @tparam array_interface_type Expected type implementing raft::array_interface * @param mda raft::array_interace implementing object - * @return Either raft::host_mdspan or raft::device_mdspan depending on the underlying - * ContainerPolicy + * @return Either raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on the underlying ContainerPolicy */ template >* = nullptr> @@ -800,54 +794,6 @@ auto flatten(const array_interface_type& mda) return flatten(mda.view()); } -template -constexpr auto flatten(host_vector_view h_vv) -{ - return h_vv; -} - -template -auto flatten(const host_vector& h_v) -{ - return flatten(h_v.view()); -} - -template -constexpr auto flatten(host_scalar_view h_sv) -{ - return h_sv; -} - -template -auto flatten(const host_scalar& h_s) -{ - return flatten(h_s.view()); -} - -template -constexpr auto flatten(device_vector_view d_vv) -{ - return d_vv; -} - -template -auto flatten(const device_vector& d_v) -{ - return flatten(d_v.view()); -} - -template -constexpr auto flatten(device_scalar_view d_sv) -{ - return d_sv; -} - -template -auto flatten(const device_scalar& d_s) -{ - return flatten(d_s.view()); -} - /** * @brief Reshape raft::host_mdspan or raft::device_mdspan * diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index f9d60a694d..15388a5cef 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -105,8 +105,6 @@ void test_host_flatten() auto hs = make_host_scalar(27); auto flat_view = flatten(hs.view()); - static_assert(std::is_same_v, "types not the same"); - ASSERT_EQ(flat_view.extent(0), 1); } } @@ -151,8 +149,6 @@ void test_device_flatten() auto ds = make_device_scalar(27, handle.get_stream()); auto flat_view = flatten(ds.view()); - static_assert(std::is_same_v, "types not the same"); - ASSERT_EQ(flat_view.extent(0), 1); } } From 7848219a99ec7923a2cf6e93efab71e84943fa8b Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 28 Apr 2022 08:01:50 -0700 Subject: [PATCH 20/21] explicit cstddef include --- cpp/include/raft/core/mdarray.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 652da24e19..be6e0bb416 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -21,6 +21,9 @@ * limitations under the License. */ #pragma once + +#include + #include #include #include From 30d0b8b2a119c4c3c6cc251ec353a3988dc2e7ae Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 28 Apr 2022 08:09:00 -0700 Subject: [PATCH 21/21] stddef.h --- cpp/include/raft/core/mdarray.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index be6e0bb416..ab6a04587a 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -22,7 +22,7 @@ */ #pragma once -#include +#include #include #include