Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to convert mdspan to a const view #1188

Merged
merged 10 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cpp/bench/matrix/argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ struct Argmin : public fixture {
void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
matrix.data_handle(), matrix.extent(0), matrix.extent(1));
raft::matrix::argmin(handle, matrix_const_view, indices.view());
raft::matrix::argmin(handle, raft::make_const_mdspan(matrix.view()), indices.view());
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
});
}

Expand Down
11 changes: 4 additions & 7 deletions cpp/bench/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,11 @@ struct Gather : public fixture {
state.SetLabel(label_stream.str());

loop_on_state(state, [this]() {
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
matrix.data_handle(), matrix.extent(0), matrix.extent(1));
auto map_const_view =
raft::make_device_vector_view<const MapT, IdxT>(map.data_handle(), map.extent(0));
auto matrix_const_view = raft::make_const_mdspan(matrix.view());
auto map_const_view = raft::make_const_mdspan(map.view());
if constexpr (Conditional) {
auto stencil_const_view =
raft::make_device_vector_view<const T, IdxT>(stencil.data_handle(), stencil.extent(0));
auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
auto stencil_const_view = raft::make_const_mdspan(stencil.view());
auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
raft::matrix::gather_if(
handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op);
} else {
Expand Down
50 changes: 49 additions & 1 deletion cpp/include/raft/core/mdspan.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -304,4 +304,52 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
}
}

/**
* @brief Const accessor specialization for default_accessor
*
* @tparam ElementType
* @param a
* @return std::experimental::default_accessor<std::add_const_t<ElementType>>
*/
template <class ElementType>
std::experimental::default_accessor<std::add_const_t<ElementType>> accessor_of_const(
std::experimental::default_accessor<ElementType> a)
{
return {a};
}

/**
* @brief Const accessor specialization for host_device_accessor
*
* @tparam ElementType the data type of the mdspan elements
* @tparam MemType the type of memory where the elements are stored.
* @param a host_device_accessor
* @return host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>,
* MemType>
*/
template <class ElementType, memory_type MemType>
host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>, MemType>
accessor_of_const(host_device_accessor<std::experimental::default_accessor<ElementType>, MemType> a)
{
return {a};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the correct way to do it, as long as host_device_accessor has an element type converting constructor like default_accessor does (see the first constructor here). A straightforward unit test would fail to compile without that constructor.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great point, @mhoemmen. It's nice to see the new function being used in the k-means tests but we should probably have a dedicated testcase for this (ideally in the mdspan cpp test file).

}

/**
* @brief Create a copy of the given mdspan with const element type
*
* @tparam ElementType the data type of the mdspan elements
lowener marked this conversation as resolved.
Show resolved Hide resolved
* @tparam Extents raft::extents for dimensions
* @tparam Layout policy for strides and layout ordering
* @tparam Accessor Accessor policy for the input and output
* @param mds raft::mdspan object
* @return raft::mdspan
*/
template <class ElementType, class Extents, class Layout, class Accessor>
auto make_const_mdspan(mdspan<ElementType, Extents, Layout, Accessor> mds)
{
auto acc_c = accessor_of_const(mds.accessor());
return mdspan<std::add_const_t<ElementType>, Extents, Layout, decltype(acc_c)>{
mds.data_handle(), mds.mapping(), acc_c};
}

} // namespace raft
18 changes: 6 additions & 12 deletions cpp/test/cluster/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
rmm::device_uvector<char> workspace(0, stream);
rmm::device_uvector<T> L2NormBuf_OR_DistBuf(0, stream);
rmm::device_uvector<T> inRankCp(0, stream);
auto X_view =
raft::make_device_matrix_view<const T, int>(X.data_handle(), X.extent(0), X.extent(1));
auto X_view = raft::make_const_mdspan(X.view());
auto centroids_view =
raft::make_device_matrix_view<T, int>(d_centroids.data(), params.n_clusters, n_features);
auto miniX = raft::make_device_matrix<T, int>(handle, n_samples / 4, n_features);
Expand All @@ -126,12 +125,8 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
miniX.extent(0),
params.rng_state.seed);

raft::cluster::kmeans::init_plus_plus(handle,
params,
raft::make_device_matrix_view<const T, int>(
miniX.data_handle(), miniX.extent(0), miniX.extent(1)),
centroids_view,
workspace);
raft::cluster::kmeans::init_plus_plus(
handle, params, raft::make_const_mdspan(miniX.view()), centroids_view, workspace);

auto minClusterDistance = raft::make_device_vector<T, int>(handle, n_samples);
auto minClusterAndDistance =
Expand Down Expand Up @@ -285,10 +280,9 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {

raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream);

T inertia = 0;
int n_iter = 0;
auto X_view =
raft::make_device_matrix_view<const T, int>(X.data_handle(), X.extent(0), X.extent(1));
T inertia = 0;
int n_iter = 0;
auto X_view = raft::make_const_mdspan(X.view());

raft::cluster::kmeans_fit_predict<T, int>(
handle,
Expand Down
3 changes: 3 additions & 0 deletions docs/source/cpp_api/mdspan_mdspan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ mdspan: Multi-dimensional Non-owning View
.. doxygenfunction:: raft::unravel_index
:project: RAFT

.. doxygenfunction:: raft::make_const_mdspan(mdspan_type mds)
:project: RAFT


Device Vocabulary
-----------------
Expand Down