Skip to content

Commit

Permalink
Add function to convert mdspan to a const view (rapidsai#1188)
Browse files Browse the repository at this point in the history
`make_const_mdspan` is a helper function to convert `mdspan<T>` into `mdspan<const T>`.
I added examples of it's usage
@mhoemmen @Nyrio

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Louis Sugy (https://github.com/Nyrio)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Mark Hoemmen (https://github.com/mhoemmen)

URL: rapidsai#1188
  • Loading branch information
lowener authored and ahendriksen committed Feb 2, 2023
1 parent 6307c86 commit 0c0c2d6
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 23 deletions.
4 changes: 1 addition & 3 deletions cpp/bench/matrix/argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ struct Argmin : public fixture {
void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
matrix.data_handle(), matrix.extent(0), matrix.extent(1));
raft::matrix::argmin(handle, matrix_const_view, indices.view());
raft::matrix::argmin(handle, raft::make_const_mdspan(matrix.view()), indices.view());
});
}

Expand Down
11 changes: 4 additions & 7 deletions cpp/bench/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,11 @@ struct Gather : public fixture {
state.SetLabel(label_stream.str());

loop_on_state(state, [this]() {
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
matrix.data_handle(), matrix.extent(0), matrix.extent(1));
auto map_const_view =
raft::make_device_vector_view<const MapT, IdxT>(map.data_handle(), map.extent(0));
auto matrix_const_view = raft::make_const_mdspan(matrix.view());
auto map_const_view = raft::make_const_mdspan(map.view());
if constexpr (Conditional) {
auto stencil_const_view =
raft::make_device_vector_view<const T, IdxT>(stencil.data_handle(), stencil.extent(0));
auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
auto stencil_const_view = raft::make_const_mdspan(stencil.view());
auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
raft::matrix::gather_if(
handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op);
} else {
Expand Down
50 changes: 49 additions & 1 deletion cpp/include/raft/core/mdspan.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -304,4 +304,52 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
}
}

/**
* @brief Const accessor specialization for default_accessor
*
* @tparam ElementType
* @param a
* @return std::experimental::default_accessor<std::add_const_t<ElementType>>
*/
template <class ElementType>
std::experimental::default_accessor<std::add_const_t<ElementType>> accessor_of_const(
std::experimental::default_accessor<ElementType> a)
{
return {a};
}

/**
* @brief Const accessor specialization for host_device_accessor
*
* @tparam ElementType the data type of the mdspan elements
* @tparam MemType the type of memory where the elements are stored.
* @param a host_device_accessor
* @return host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>,
* MemType>
*/
template <class ElementType, memory_type MemType>
host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>, MemType>
accessor_of_const(host_device_accessor<std::experimental::default_accessor<ElementType>, MemType> a)
{
return {a};
}

/**
* @brief Create a copy of the given mdspan with const element type
*
* @tparam ElementType the const-qualified data type of the mdspan elements
* @tparam Extents raft::extents for dimensions
* @tparam Layout policy for strides and layout ordering
* @tparam Accessor Accessor policy for the input and output
* @param mds raft::mdspan object
* @return raft::mdspan
*/
template <class ElementType, class Extents, class Layout, class Accessor>
auto make_const_mdspan(mdspan<ElementType, Extents, Layout, Accessor> mds)
{
auto acc_c = accessor_of_const(mds.accessor());
return mdspan<std::add_const_t<ElementType>, Extents, Layout, decltype(acc_c)>{
mds.data_handle(), mds.mapping(), acc_c};
}

} // namespace raft
18 changes: 6 additions & 12 deletions cpp/test/cluster/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
rmm::device_uvector<char> workspace(0, stream);
rmm::device_uvector<T> L2NormBuf_OR_DistBuf(0, stream);
rmm::device_uvector<T> inRankCp(0, stream);
auto X_view =
raft::make_device_matrix_view<const T, int>(X.data_handle(), X.extent(0), X.extent(1));
auto X_view = raft::make_const_mdspan(X.view());
auto centroids_view =
raft::make_device_matrix_view<T, int>(d_centroids.data(), params.n_clusters, n_features);
auto miniX = raft::make_device_matrix<T, int>(handle, n_samples / 4, n_features);
Expand All @@ -126,12 +125,8 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
miniX.extent(0),
params.rng_state.seed);

raft::cluster::kmeans::init_plus_plus(handle,
params,
raft::make_device_matrix_view<const T, int>(
miniX.data_handle(), miniX.extent(0), miniX.extent(1)),
centroids_view,
workspace);
raft::cluster::kmeans::init_plus_plus(
handle, params, raft::make_const_mdspan(miniX.view()), centroids_view, workspace);

auto minClusterDistance = raft::make_device_vector<T, int>(handle, n_samples);
auto minClusterAndDistance =
Expand Down Expand Up @@ -285,10 +280,9 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {

raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream);

T inertia = 0;
int n_iter = 0;
auto X_view =
raft::make_device_matrix_view<const T, int>(X.data_handle(), X.extent(0), X.extent(1));
T inertia = 0;
int n_iter = 0;
auto X_view = raft::make_const_mdspan(X.view());

raft::cluster::kmeans_fit_predict<T, int>(
handle,
Expand Down
27 changes: 27 additions & 0 deletions cpp/test/core/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,31 @@ void test_reshape()

TEST(MDArray, Reshape) { test_reshape(); }

void test_const_mdspan()
{
// 3d host array
{
using two_d_extents = extents<int, 5, 5>;
using two_d_mdarray = host_mdarray<float, two_d_extents>;

typename two_d_mdarray::mapping_type layout{two_d_extents{}};
typename two_d_mdarray::container_policy_type policy;
two_d_mdarray mda{layout, policy};

auto const_mda = make_const_mdspan(mda.view());

static_assert(std::is_same_v<const float, typename decltype(const_mda)::element_type>,
"elements not the same");
static_assert(std::is_same_v<typename decltype(mda)::extents_type,
typename decltype(const_mda)::extents_type>,
"extents not the same");
static_assert(std::is_same_v<typename decltype(mda)::layout_type,
typename decltype(const_mda)::layout_type>,
"layouts not the same");
ASSERT_EQ(mda.size(), const_mda.size());
}
}

TEST(MDSpan, ConstMDSpan) { test_const_mdspan(); }

} // namespace raft
3 changes: 3 additions & 0 deletions docs/source/cpp_api/mdspan_mdspan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ mdspan: Multi-dimensional Non-owning View
.. doxygenfunction:: raft::unravel_index
:project: RAFT

.. doxygenfunction:: raft::make_const_mdspan(mdspan_type mds)
:project: RAFT


Device Vocabulary
-----------------
Expand Down

0 comments on commit 0c0c2d6

Please sign in to comment.