-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
mdspan/mdarray template functions and utilities #601
Changes from all commits
119d75b
c205a40
b6799f5
3451b07
1eb7565
fb2de54
6539936
eb400c6
27c5c19
ff282f3
b8aec4c
ceb7dd4
fa88e23
e3f52ce
95e66b9
de6bea6
4fc4a5d
f87d02f
30d878c
65e54a5
8659b39
66ade4b
a407aee
7848219
30d0b8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,12 +21,21 @@ | |
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <stddef.h> | ||
|
||
#include <experimental/mdspan> | ||
#include <raft/core/handle.hpp> | ||
#include <raft/detail/mdarray.hpp> | ||
#include <rmm/cuda_stream_view.hpp> | ||
|
||
namespace raft { | ||
/** | ||
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan | ||
*/ | ||
template <size_t... ExtentsPack> | ||
wphicks marked this conversation as resolved.
Show resolved
Hide resolved
|
||
using extents = std::experimental::extents<ExtentsPack...>; | ||
|
||
/** | ||
* @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. | ||
*/ | ||
|
@@ -37,6 +46,30 @@ using layout_c_contiguous = detail::stdex::layout_right; | |
*/ | ||
using layout_f_contiguous = detail::stdex::layout_left; | ||
|
||
/** | ||
* @\brief Template checks and helpers to determine if type T is an std::mdspan | ||
* or a derived type | ||
*/ | ||
|
||
template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy> | ||
void __takes_an_mdspan_ptr( | ||
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*); | ||
|
||
template <typename T, typename = void> | ||
struct __is_mdspan : std::false_type { | ||
}; | ||
|
||
template <typename T> | ||
struct __is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>> | ||
: std::true_type { | ||
}; | ||
|
||
template <typename T> | ||
using __is_mdspan_t = __is_mdspan<std::remove_const_t<T>>; | ||
|
||
template <typename T> | ||
inline constexpr bool __is_mdspan_v = __is_mdspan_t<T>::value; | ||
|
||
/** | ||
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. | ||
*/ | ||
|
@@ -57,6 +90,85 @@ template <typename ElementType, | |
using host_mdspan = | ||
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>; | ||
|
||
template <typename T, bool B> | ||
struct __is_device_mdspan : std::false_type { | ||
}; | ||
|
||
template <typename T> | ||
struct __is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> { | ||
}; | ||
|
||
/** | ||
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type | ||
*/ | ||
template <typename T> | ||
inline constexpr bool is_device_mdspan_v = __is_device_mdspan<T, __is_mdspan_v<T>>::value; | ||
|
||
template <typename T, bool B> | ||
struct __is_host_mdspan : std::false_type { | ||
}; | ||
|
||
template <typename T> | ||
struct __is_host_mdspan<T, true> : T::accessor_type::is_host_type { | ||
}; | ||
|
||
/** | ||
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type | ||
*/ | ||
template <typename T> | ||
inline constexpr bool is_host_mdspan_v = __is_host_mdspan<T, __is_mdspan_v<T>>::value; | ||
|
||
/** | ||
* @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan | ||
* or their derived types | ||
* This is structured such that it will short-circuit if the type is not std::mdspan | ||
* or a derived type, and otherwise it will check whether it is a raft::device_mdspan | ||
* or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type | ||
*/ | ||
template <typename T> | ||
inline constexpr bool is_mdspan_v = | ||
std::conjunction_v<__is_mdspan_t<T>, | ||
std::disjunction<__is_device_mdspan<T, true>, __is_host_mdspan<T, true>>>; | ||
|
||
/** | ||
* @brief Interface to implement an owning multi-dimensional array | ||
* | ||
* raft::array_interace is an interface to owning container types for mdspan. | ||
* Check implementation of raft::mdarray which implements raft::array_interface | ||
* using Curiously Recurring Template Pattern. | ||
* This interface calls into method `view()` whose implementation is provided by | ||
* the implementing class. `view()` must return an object of type raft::host_mdspan | ||
* or raft::device_mdspan or any types derived from the them. | ||
*/ | ||
template <typename Base> | ||
class array_interface { | ||
/** | ||
* @brief Get a mdspan that can be passed down to CUDA kernels. | ||
*/ | ||
auto view() noexcept { return static_cast<Base*>(this)->view(); } | ||
/** | ||
* @brief Get a mdspan that can be passed down to CUDA kernels. | ||
*/ | ||
auto view() const noexcept { return static_cast<Base*>(this)->view(); } | ||
}; | ||
|
||
template <typename T, typename = void> | ||
struct __is_array_interface : std::false_type { | ||
}; | ||
|
||
template <typename T> | ||
struct __is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>> | ||
: std::bool_constant<is_mdspan_v<decltype(std::declval<T>().view())>> { | ||
}; | ||
|
||
/** | ||
* @\brief Boolean to determine if template type T is raft::array_interface or derived type | ||
* or any type that has a member function `view()` that returns either | ||
* raft::host_mdspan or raft::device_mdspan | ||
*/ | ||
template <typename T> | ||
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I propose that we give users two options here:
Point 1 is already implemented, and point 2 allows users to update their existing their code minimally. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Point 2 is solved with recent updates |
||
|
||
/** | ||
* @brief Modified from the c++ mdarray proposal | ||
* | ||
|
@@ -87,7 +199,8 @@ using host_mdspan = | |
* removed. | ||
*/ | ||
template <typename ElementType, typename Extents, typename LayoutPolicy, typename ContainerPolicy> | ||
class mdarray { | ||
class mdarray | ||
: public array_interface<mdarray<ElementType, Extents, LayoutPolicy, ContainerPolicy>> { | ||
static_assert(!std::is_const<ElementType>::value, | ||
"Element type for container must not be const."); | ||
|
||
|
@@ -340,15 +453,15 @@ using device_scalar = device_mdarray<ElementType, detail::scalar_extent>; | |
* @brief Shorthand for 1-dim host mdarray. | ||
* @tparam ElementType the data type of the vector elements | ||
*/ | ||
template <typename ElementType> | ||
using host_vector = host_mdarray<ElementType, detail::vector_extent>; | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
using host_vector = host_mdarray<ElementType, detail::vector_extent, LayoutPolicy>; | ||
|
||
/** | ||
* @brief Shorthand for 1-dim device mdarray. | ||
* @tparam ElementType the data type of the vector elements | ||
*/ | ||
template <typename ElementType> | ||
using device_vector = device_mdarray<ElementType, detail::vector_extent>; | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
using device_vector = device_mdarray<ElementType, detail::vector_extent, LayoutPolicy>; | ||
|
||
/** | ||
* @brief Shorthand for c-contiguous host matrix. | ||
|
@@ -384,15 +497,15 @@ using device_scalar_view = device_mdspan<ElementType, detail::scalar_extent>; | |
* @brief Shorthand for 1-dim host mdspan. | ||
* @tparam ElementType the data type of the vector elements | ||
*/ | ||
template <typename ElementType> | ||
using host_vector_view = host_mdspan<ElementType, detail::vector_extent>; | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
using host_vector_view = host_mdspan<ElementType, detail::vector_extent, LayoutPolicy>; | ||
|
||
/** | ||
* @brief Shorthand for 1-dim device mdspan. | ||
* @tparam ElementType the data type of the vector elements | ||
*/ | ||
template <typename ElementType> | ||
using device_vector_view = device_mdspan<ElementType, detail::vector_extent>; | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
using device_vector_view = device_mdspan<ElementType, detail::vector_extent, LayoutPolicy>; | ||
|
||
/** | ||
* @brief Shorthand for c-contiguous host matrix view. | ||
|
@@ -478,11 +591,11 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) | |
* @param[in] n number of elements in pointer | ||
* @return raft::host_vector_view | ||
*/ | ||
template <typename ElementType> | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
auto make_host_vector_view(ElementType* ptr, size_t n) | ||
{ | ||
detail::vector_extent extents{n}; | ||
return host_vector_view<ElementType>{ptr, extents}; | ||
return host_vector_view<ElementType, LayoutPolicy>{ptr, extents}; | ||
} | ||
|
||
/** | ||
|
@@ -492,11 +605,11 @@ auto make_host_vector_view(ElementType* ptr, size_t n) | |
* @param[in] n number of elements in pointer | ||
* @return raft::device_vector_view | ||
*/ | ||
template <typename ElementType> | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
auto make_device_vector_view(ElementType* ptr, size_t n) | ||
{ | ||
detail::vector_extent extents{n}; | ||
return device_vector_view<ElementType>{ptr, extents}; | ||
return device_vector_view<ElementType, LayoutPolicy>{ptr, extents}; | ||
} | ||
|
||
/** | ||
|
@@ -610,13 +723,13 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) | |
* @param[in] n number of elements in vector | ||
* @return raft::host_vector | ||
*/ | ||
template <typename ElementType> | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
auto make_host_vector(size_t n) | ||
{ | ||
detail::vector_extent extents{n}; | ||
using policy_t = typename host_vector<ElementType>::container_policy_type; | ||
using policy_t = typename host_vector<ElementType, LayoutPolicy>::container_policy_type; | ||
policy_t policy; | ||
return host_vector<ElementType>{extents, policy}; | ||
return host_vector<ElementType, LayoutPolicy>{extents, policy}; | ||
} | ||
|
||
/** | ||
|
@@ -626,13 +739,13 @@ auto make_host_vector(size_t n) | |
* @param[in] stream the cuda stream for ordering events | ||
* @return raft::device_vector | ||
*/ | ||
template <typename ElementType> | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
auto make_device_vector(size_t n, rmm::cuda_stream_view stream) | ||
{ | ||
detail::vector_extent extents{n}; | ||
using policy_t = typename device_vector<ElementType>::container_policy_type; | ||
using policy_t = typename device_vector<ElementType, LayoutPolicy>::container_policy_type; | ||
policy_t policy{stream}; | ||
return device_vector<ElementType>{extents, policy}; | ||
return device_vector<ElementType, LayoutPolicy>{extents, policy}; | ||
} | ||
|
||
/** | ||
|
@@ -642,9 +755,92 @@ auto make_device_vector(size_t n, rmm::cuda_stream_view stream) | |
* @param[in] n number of elements in vector | ||
* @return raft::device_vector | ||
*/ | ||
template <typename ElementType> | ||
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous> | ||
auto make_device_vector(raft::handle_t const& handle, size_t n) | ||
{ | ||
return make_device_vector<ElementType>(n, handle.get_stream()); | ||
return make_device_vector<ElementType, LayoutPolicy>(n, handle.get_stream()); | ||
} | ||
|
||
/** | ||
* @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view | ||
* | ||
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan | ||
* @param mds raft::host_mdspan or raft::device_mdspan object | ||
* @return raft::host_mdspan or raft::device_mdspan with vector_extent | ||
* depending on AccessoryPolicy | ||
*/ | ||
template <typename mdspan_type, std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr> | ||
auto flatten(mdspan_type mds) | ||
{ | ||
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); | ||
|
||
detail::vector_extent ext{mds.size()}; | ||
|
||
return detail::stdex::mdspan<typename mdspan_type::element_type, | ||
detail::vector_extent, | ||
typename mdspan_type::layout_type, | ||
typename mdspan_type::accessor_type>(mds.data(), ext); | ||
} | ||
|
||
/** | ||
* @brief Flatten object implementing raft::array_interface into a 1-dim array view | ||
* | ||
* @tparam array_interface_type Expected type implementing raft::array_interface | ||
* @param mda raft::array_interace implementing object | ||
* @return Either raft::host_mdspan or raft::device_mdspan with vector_extent | ||
* depending on the underlying ContainerPolicy | ||
*/ | ||
template <typename array_interface_type, | ||
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr> | ||
auto flatten(const array_interface_type& mda) | ||
{ | ||
return flatten(mda.view()); | ||
} | ||
|
||
/** | ||
* @brief Reshape raft::host_mdspan or raft::device_mdspan | ||
* | ||
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan | ||
* @tparam Extents raft::extents for dimensions | ||
* @param mds raft::host_mdspan or raft::device_mdspan object | ||
* @param new_shape Desired new shape of the input | ||
* @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy | ||
*/ | ||
template <typename mdspan_type, | ||
size_t... Extents, | ||
std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr> | ||
auto reshape(mdspan_type mds, extents<Extents...> new_shape) | ||
{ | ||
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); | ||
|
||
size_t new_size = 1; | ||
for (size_t i = 0; i < new_shape.rank(); ++i) { | ||
new_size *= new_shape.extent(i); | ||
} | ||
RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); | ||
|
||
return detail::stdex::mdspan<typename mdspan_type::element_type, | ||
decltype(new_shape), | ||
typename mdspan_type::layout_type, | ||
typename mdspan_type::accessor_type>(mds.data(), new_shape); | ||
} | ||
|
||
/** | ||
* @brief Reshape object implementing raft::array_interface | ||
* | ||
* @tparam array_interface_type Expected type implementing raft::array_interface | ||
* @tparam Extents raft::extents for dimensions | ||
* @param mda raft::array_interace implementing object | ||
* @param new_shape Desired new shape of the input | ||
* @return raft::host_mdspan or raft::device_mdspan, depending on the underlying | ||
* ContainerPolicy | ||
*/ | ||
template <typename array_interface_type, | ||
size_t... Extents, | ||
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr> | ||
auto reshape(const array_interface_type& mda, extents<Extents...> new_shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this function accepts (or will potentially accept in the future) an owning type, and might change the size of the underlying memory buffer, one might want to accept a raft handle as a parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we are going to be explicit about size changes with a different function called |
||
{ | ||
return reshape(mda.view(), new_shape); | ||
} | ||
|
||
} // namespace raft |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include
stddef.h
forsize_t
or includecstddef
and usestd::size_t
where possible