Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to convert mdspan to a const view #1188

Merged
merged 10 commits into from
Feb 1, 2023
57 changes: 43 additions & 14 deletions cpp/include/raft/core/mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,24 +304,53 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
}
}

/**
* @brief Const accessor specialization for default_accessor
*
* @tparam ElementType
* @param a
* @return std::experimental::default_accessor<std::add_const_t<ElementType>>
*/
template<class ElementType>
std::experimental::default_accessor<std::add_const_t<ElementType>>
accessor_of_const(std::experimental::default_accessor<ElementType> a)
{
return {a};
}

/**
* @brief Const accessor specialization for host_device_accessor
*
* @tparam ElementType the data type of the mdspan elements
* @tparam MemType the type of memory where the elements are stored.
* @param a host_device_accessor
* @return host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>, MemType>
*/
template<class ElementType, memory_type MemType>
host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>, MemType>
accessor_of_const(host_device_accessor<std::experimental::default_accessor<ElementType>, MemType> a)
{
return {a};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the correct way to do it, as long as host_device_accessor has an element type converting constructor like default_accessor does (see the first constructor here). A straightforward unit test would fail to compile without that constructor.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great point, @mhoemmen. It's nice to see the new function being used in the k-means tests but we should probably have a dedicated testcase for this (ideally in the mdspan cpp test file).

}

/**
* @brief Create a copy of the given mdspan with const element type
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan
* @param mds raft::host_mdspan or raft::device_mdspan object
* @return raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on AccessoryPolicy
*
* @tparam ElementType the data type of the mdspan elements
lowener marked this conversation as resolved.
Show resolved Hide resolved
* @tparam Extents raft::extents for dimensions
* @tparam Layout policy for strides and layout ordering
* @tparam Accessor Accessor policy for the input and output
* @param mds raft::mdspan object
* @return raft::mdspan
*/
template <typename mdspan_type, typename = enable_if_mdspan<mdspan_type>>
auto make_const_mdspan(mdspan_type mds)
template<class ElementType, class Extents, class Layout, class Accessor>
auto make_const_mdspan(mdspan<ElementType, Extents, Layout, Accessor> mds)
{
using const_element_t = std::add_const_t<typename mdspan_type::element_type>;
using const_accessor_t =
host_device_accessor<std::experimental::default_accessor<const_element_t>,
mdspan_type::accessor_type::mem_type>;
return std::experimental::mdspan<const_element_t,
typename mdspan_type::extents_type,
typename mdspan_type::layout_type,
const_accessor_t>(mds);
auto acc_c = accessor_of_const(mds.accessor());
return mdspan<std::add_const_t<ElementType>, Extents, Layout, decltype(acc_c)>{
mds.data_handle(),
mds.mapping(),
acc_c};
}

} // namespace raft