From 5f7382d2cb9d2a02a5d5d06375af685935fd7be2 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 25 Jan 2023 23:08:17 +0100 Subject: [PATCH 1/8] Add `make_const_mdspan` --- cpp/include/raft/core/mdspan.hpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 786ce69f89..ab2ec2b769 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -304,4 +304,20 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } +/** + * @brief Create a copy of the given mdspan with const element type + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @param mds raft::host_mdspan or raft::device_mdspan object + * @return raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on AccessoryPolicy + */ +template > +auto make_const_mdspan(mdspan_type mds) +{ + return std::experimental::mdspan, + typename mdspan_type::extents_type, + typename mdspan_type::layout_type, + typename mdspan_type::accessor_type>(mds); +} + } // namespace raft From bcdb67fe3bced6966278066f22e620a0d55b7557 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 26 Jan 2023 12:21:12 +0100 Subject: [PATCH 2/8] Add const accessor --- cpp/include/raft/core/mdspan.hpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index ab2ec2b769..f3e0aa23e2 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -314,10 +314,14 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, template > auto make_const_mdspan(mdspan_type mds) { - return std::experimental::mdspan, + using const_element_t = std::add_const_t; + using const_accessor_t = + host_device_accessor, + mdspan_type::accessor_type::mem_type>; + return std::experimental::mdspan(mds); + const_accessor_t>(mds); } } // namespace raft From 7071cf9103e0d42c96d3660064ffdfae2a6cd320 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 26 Jan 2023 12:36:17 +0100 Subject: [PATCH 3/8] Add examples --- cpp/bench/matrix/argmin.cu | 4 +--- cpp/bench/matrix/gather.cu | 11 ++++------- cpp/test/cluster/kmeans.cu | 18 ++++++------------ 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 52f5aab7f3..3869f0c5e1 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -46,9 +46,7 @@ struct Argmin : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto matrix_const_view = raft::make_device_matrix_view( - matrix.data_handle(), matrix.extent(0), matrix.extent(1)); - raft::matrix::argmin(handle, matrix_const_view, indices.view()); + raft::matrix::argmin(handle, raft::make_const_mdspan(matrix.view()), indices.view()); }); } diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu index 97812c20a1..c5d80744cd 100644 --- a/cpp/bench/matrix/gather.cu +++ b/cpp/bench/matrix/gather.cu @@ -64,14 +64,11 @@ struct Gather : public fixture { state.SetLabel(label_stream.str()); loop_on_state(state, [this]() { - auto matrix_const_view = raft::make_device_matrix_view( - matrix.data_handle(), matrix.extent(0), matrix.extent(1)); - auto map_const_view = - raft::make_device_vector_view(map.data_handle(), map.extent(0)); + auto matrix_const_view = raft::make_const_mdspan(matrix.view()); + auto map_const_view = raft::make_const_mdspan(map.view()); if constexpr (Conditional) { - auto stencil_const_view = - raft::make_device_vector_view(stencil.data_handle(), stencil.extent(0)); - auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); + auto stencil_const_view = raft::make_const_mdspan(stencil.view()); + auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index abc4cd6e13..43e96f8a59 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -112,8 +112,7 @@ class KmeansTest : public ::testing::TestWithParam> { rmm::device_uvector workspace(0, stream); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); rmm::device_uvector inRankCp(0, stream); - auto X_view = - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); + auto X_view = raft::make_const_mdspan(X.view()); auto centroids_view = raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); auto miniX = raft::make_device_matrix(handle, n_samples / 4, n_features); @@ -126,12 +125,8 @@ class KmeansTest : public ::testing::TestWithParam> { miniX.extent(0), params.rng_state.seed); - raft::cluster::kmeans::init_plus_plus(handle, - params, - raft::make_device_matrix_view( - miniX.data_handle(), miniX.extent(0), miniX.extent(1)), - centroids_view, - workspace); + raft::cluster::kmeans::init_plus_plus( + handle, params, raft::make_const_mdspan(miniX.view()), centroids_view, workspace); auto minClusterDistance = raft::make_device_vector(handle, n_samples); auto minClusterAndDistance = @@ -285,10 +280,9 @@ class KmeansTest : public ::testing::TestWithParam> { raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); - T inertia = 0; - int n_iter = 0; - auto X_view = - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); + T inertia = 0; + int n_iter = 0; + auto X_view = raft::make_const_mdspan(X.view()); raft::cluster::kmeans_fit_predict( handle, From 3e3facbe3a6f6aa01f4d25c634b16094a6b8c50b Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 26 Jan 2023 18:51:42 +0100 Subject: [PATCH 4/8] Add doxygen doc --- docs/source/cpp_api/mdspan_mdspan.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/cpp_api/mdspan_mdspan.rst b/docs/source/cpp_api/mdspan_mdspan.rst index 272a724833..619150f538 100644 --- a/docs/source/cpp_api/mdspan_mdspan.rst +++ b/docs/source/cpp_api/mdspan_mdspan.rst @@ -22,6 +22,9 @@ mdspan: Multi-dimensional Non-owning View .. doxygenfunction:: raft::unravel_index :project: RAFT +.. doxygenfunction:: raft::make_const_mdspan(mdspan_type mds) + :project: RAFT + Device Vocabulary ----------------- From 351b09ee3a2a66c2342749b85cfaaaef267a11d9 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 31 Jan 2023 18:16:30 +0100 Subject: [PATCH 5/8] Add `accessor_of_const` --- cpp/include/raft/core/mdspan.hpp | 57 ++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index f3e0aa23e2..3fbea56a3f 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -304,24 +304,53 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } +/** + * @brief Const accessor specialization for default_accessor + * + * @tparam ElementType + * @param a + * @return std::experimental::default_accessor> + */ +template +std::experimental::default_accessor> +accessor_of_const(std::experimental::default_accessor a) +{ + return {a}; +} + +/** + * @brief Const accessor specialization for host_device_accessor + * + * @tparam ElementType the data type of the mdspan elements + * @tparam MemType the type of memory where the elements are stored. + * @param a host_device_accessor + * @return host_device_accessor>, MemType> + */ +template +host_device_accessor>, MemType> +accessor_of_const(host_device_accessor, MemType> a) +{ + return {a}; +} + /** * @brief Create a copy of the given mdspan with const element type - * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan - * @param mds raft::host_mdspan or raft::device_mdspan object - * @return raft::host_mdspan or raft::device_mdspan with vector_extent - * depending on AccessoryPolicy + * + * @tparam ElementType the data type of the mdspan elements + * @tparam Extents raft::extents for dimensions + * @tparam Layout policy for strides and layout ordering + * @tparam Accessor Accessor policy for the input and output + * @param mds raft::mdspan object + * @return raft::mdspan */ -template > -auto make_const_mdspan(mdspan_type mds) +template +auto make_const_mdspan(mdspan mds) { - using const_element_t = std::add_const_t; - using const_accessor_t = - host_device_accessor, - mdspan_type::accessor_type::mem_type>; - return std::experimental::mdspan(mds); + auto acc_c = accessor_of_const(mds.accessor()); + return mdspan, Extents, Layout, decltype(acc_c)>{ + mds.data_handle(), + mds.mapping(), + acc_c}; } } // namespace raft From 5bd10a10e35045a1d2cc81100e126a16f4d67539 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 31 Jan 2023 14:31:26 -0500 Subject: [PATCH 6/8] Fixing style --- cpp/include/raft/core/mdspan.hpp | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 3fbea56a3f..d6245a29fe 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -306,27 +306,28 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, /** * @brief Const accessor specialization for default_accessor - * - * @tparam ElementType - * @param a - * @return std::experimental::default_accessor> + * + * @tparam ElementType + * @param a + * @return std::experimental::default_accessor> */ -template -std::experimental::default_accessor> -accessor_of_const(std::experimental::default_accessor a) +template +std::experimental::default_accessor> accessor_of_const( + std::experimental::default_accessor a) { return {a}; } /** * @brief Const accessor specialization for host_device_accessor - * + * * @tparam ElementType the data type of the mdspan elements * @tparam MemType the type of memory where the elements are stored. * @param a host_device_accessor - * @return host_device_accessor>, MemType> + * @return host_device_accessor>, + * MemType> */ -template +template host_device_accessor>, MemType> accessor_of_const(host_device_accessor, MemType> a) { @@ -343,14 +344,12 @@ accessor_of_const(host_device_accessor +template auto make_const_mdspan(mdspan mds) { - auto acc_c = accessor_of_const(mds.accessor()); + auto acc_c = accessor_of_const(mds.accessor()); return mdspan, Extents, Layout, decltype(acc_c)>{ - mds.data_handle(), - mds.mapping(), - acc_c}; + mds.data_handle(), mds.mapping(), acc_c}; } } // namespace raft From dd541ae5a77f6c5c72e211119ce2e045c5e3e8d7 Mon Sep 17 00:00:00 2001 From: Micka Date: Wed, 1 Feb 2023 13:33:56 +0100 Subject: [PATCH 7/8] Update cpp/include/raft/core/mdspan.hpp Co-authored-by: Mark Hoemmen --- cpp/include/raft/core/mdspan.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index d6245a29fe..f805d20064 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -337,7 +337,7 @@ accessor_of_const(host_device_accessor Date: Wed, 1 Feb 2023 16:50:26 +0100 Subject: [PATCH 8/8] Add test for `make_const_mdspan` --- cpp/test/core/mdspan_utils.cu | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/cpp/test/core/mdspan_utils.cu b/cpp/test/core/mdspan_utils.cu index f428da4b31..4bb689c8c0 100644 --- a/cpp/test/core/mdspan_utils.cu +++ b/cpp/test/core/mdspan_utils.cu @@ -214,4 +214,31 @@ void test_reshape() TEST(MDArray, Reshape) { test_reshape(); } +void test_const_mdspan() +{ + // 3d host array + { + using two_d_extents = extents; + using two_d_mdarray = host_mdarray; + + typename two_d_mdarray::mapping_type layout{two_d_extents{}}; + typename two_d_mdarray::container_policy_type policy; + two_d_mdarray mda{layout, policy}; + + auto const_mda = make_const_mdspan(mda.view()); + + static_assert(std::is_same_v, + "elements not the same"); + static_assert(std::is_same_v, + "extents not the same"); + static_assert(std::is_same_v, + "layouts not the same"); + ASSERT_EQ(mda.size(), const_mda.size()); + } +} + +TEST(MDSpan, ConstMDSpan) { test_const_mdspan(); } + } // namespace raft \ No newline at end of file