Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mdspan/mdarray template functions and utilities #601

Merged
merged 25 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
119d75b
working through is_device_mdspan
divyegala Mar 30, 2022
c205a40
Merge remote-tracking branch 'upstream/branch-22.06' into imp-22.06-l…
divyegala Mar 30, 2022
b6799f5
specializations for is_mdspan
divyegala Mar 30, 2022
3451b07
flatten and tests
divyegala Mar 30, 2022
1eb7565
checking for derived mdspan
divyegala Apr 1, 2022
fb2de54
working flatten for mdarray and mdspan
divyegala Apr 6, 2022
6539936
add copyright
divyegala Apr 6, 2022
eb400c6
working through reshape
divyegala Apr 19, 2022
27c5c19
finishing up host/device flatten with tests
divyegala Apr 19, 2022
ff282f3
working host reshape with tests
divyegala Apr 19, 2022
b8aec4c
working reshape for device arrays
divyegala Apr 19, 2022
ceb7dd4
adding docstrings
divyegala Apr 19, 2022
fa88e23
Apply suggestions from code review
divyegala Apr 20, 2022
e3f52ce
static extents tests, some variable renaming
divyegala Apr 21, 2022
95e66b9
Merge remote-tracking branch 'origin/fea-22.06-mdspan_utils' into fea…
divyegala Apr 21, 2022
de6bea6
merging upstream
divyegala Apr 21, 2022
4fc4a5d
working array_interface
divyegala Apr 21, 2022
f87d02f
small fix to docstring
divyegala Apr 21, 2022
30d878c
removing unneeded aliases from array_interface
divyegala Apr 21, 2022
65e54a5
using array_interface with CRTP
divyegala Apr 27, 2022
8659b39
Merge remote-tracking branch 'upstream/branch-22.06' into fea-22.06-m…
divyegala Apr 27, 2022
66ade4b
remove implict operator converters of mdspan from mdarray
divyegala Apr 27, 2022
a407aee
remove flatten overloads
divyegala Apr 27, 2022
7848219
explicit cstddef include
divyegala Apr 28, 2022
30d0b8b
stddef.h
divyegala Apr 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 217 additions & 21 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,21 @@
* limitations under the License.
*/
#pragma once

#include <stddef.h>

#include <experimental/mdspan>
#include <raft/core/handle.hpp>
#include <raft/detail/mdarray.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace raft {
/**
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
*/
template <size_t... ExtentsPack>
Copy link
Contributor

@wphicks wphicks Apr 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include stddef.h for size_t or include cstddef and use std::size_t where possible

wphicks marked this conversation as resolved.
Show resolved Hide resolved
using extents = std::experimental::extents<ExtentsPack...>;

/**
* @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory.
*/
Expand All @@ -37,6 +46,30 @@ using layout_c_contiguous = detail::stdex::layout_right;
*/
using layout_f_contiguous = detail::stdex::layout_left;

/**
* @\brief Template checks and helpers to determine if type T is an std::mdspan
* or a derived type
*/

template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void __takes_an_mdspan_ptr(
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*);

template <typename T, typename = void>
struct __is_mdspan : std::false_type {
};

template <typename T>
struct __is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::true_type {
};

template <typename T>
using __is_mdspan_t = __is_mdspan<std::remove_const_t<T>>;

template <typename T>
inline constexpr bool __is_mdspan_v = __is_mdspan_t<T>::value;

/**
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location.
*/
Expand All @@ -57,6 +90,85 @@ template <typename ElementType,
using host_mdspan =
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;

template <typename T, bool B>
struct __is_device_mdspan : std::false_type {
};

template <typename T>
struct __is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
};

/**
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_device_mdspan_v = __is_device_mdspan<T, __is_mdspan_v<T>>::value;

template <typename T, bool B>
struct __is_host_mdspan : std::false_type {
};

template <typename T>
struct __is_host_mdspan<T, true> : T::accessor_type::is_host_type {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_host_mdspan_v = __is_host_mdspan<T, __is_mdspan_v<T>>::value;

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan
* or their derived types
* This is structured such that it will short-circuit if the type is not std::mdspan
* or a derived type, and otherwise it will check whether it is a raft::device_mdspan
* or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_mdspan_v =
std::conjunction_v<__is_mdspan_t<T>,
std::disjunction<__is_device_mdspan<T, true>, __is_host_mdspan<T, true>>>;

/**
* @brief Interface to implement an owning multi-dimensional array
*
* raft::array_interace is an interface to owning container types for mdspan.
* Check implementation of raft::mdarray which implements raft::array_interface
* using Curiously Recurring Template Pattern.
* This interface calls into method `view()` whose implementation is provided by
* the implementing class. `view()` must return an object of type raft::host_mdspan
* or raft::device_mdspan or any types derived from the them.
*/
template <typename Base>
class array_interface {
/**
* @brief Get a mdspan that can be passed down to CUDA kernels.
*/
auto view() noexcept { return static_cast<Base*>(this)->view(); }
/**
* @brief Get a mdspan that can be passed down to CUDA kernels.
*/
auto view() const noexcept { return static_cast<Base*>(this)->view(); }
};

template <typename T, typename = void>
struct __is_array_interface : std::false_type {
};

template <typename T>
struct __is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>>
: std::bool_constant<is_mdspan_v<decltype(std::declval<T>().view())>> {
};

/**
* @\brief Boolean to determine if template type T is raft::array_interface or derived type
* or any type that has a member function `view()` that returns either
* raft::host_mdspan or raft::device_mdspan
*/
template <typename T>
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value;
Copy link
Member Author

@divyegala divyegala Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose that we give users two options here:

  1. They directly implement array_interface
  2. For users with existing owning types, we just ask that they add a view() method that returns a host_mdspan or device_mdspan or a derived type

Point 1 is already implemented, and point 2 allows users to update their existing their code minimally.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Point 2 is solved with recent updates


/**
* @brief Modified from the c++ mdarray proposal
*
Expand Down Expand Up @@ -87,7 +199,8 @@ using host_mdspan =
* removed.
*/
template <typename ElementType, typename Extents, typename LayoutPolicy, typename ContainerPolicy>
class mdarray {
class mdarray
: public array_interface<mdarray<ElementType, Extents, LayoutPolicy, ContainerPolicy>> {
static_assert(!std::is_const<ElementType>::value,
"Element type for container must not be const.");

Expand Down Expand Up @@ -340,15 +453,15 @@ using device_scalar = device_mdarray<ElementType, detail::scalar_extent>;
* @brief Shorthand for 1-dim host mdarray.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using host_vector = host_mdarray<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_vector = host_mdarray<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for 1-dim device mdarray.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using device_vector = device_mdarray<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_vector = device_mdarray<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous host matrix.
Expand Down Expand Up @@ -384,15 +497,15 @@ using device_scalar_view = device_mdspan<ElementType, detail::scalar_extent>;
* @brief Shorthand for 1-dim host mdspan.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using host_vector_view = host_mdspan<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_vector_view = host_mdspan<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for 1-dim device mdspan.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using device_vector_view = device_mdspan<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_vector_view = device_mdspan<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous host matrix view.
Expand Down Expand Up @@ -478,11 +591,11 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols)
* @param[in] n number of elements in pointer
* @return raft::host_vector_view
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector_view(ElementType* ptr, size_t n)
{
detail::vector_extent extents{n};
return host_vector_view<ElementType>{ptr, extents};
return host_vector_view<ElementType, LayoutPolicy>{ptr, extents};
}

/**
Expand All @@ -492,11 +605,11 @@ auto make_host_vector_view(ElementType* ptr, size_t n)
* @param[in] n number of elements in pointer
* @return raft::device_vector_view
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, size_t n)
{
detail::vector_extent extents{n};
return device_vector_view<ElementType>{ptr, extents};
return device_vector_view<ElementType, LayoutPolicy>{ptr, extents};
}

/**
Expand Down Expand Up @@ -610,13 +723,13 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v)
* @param[in] n number of elements in vector
* @return raft::host_vector
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector(size_t n)
{
detail::vector_extent extents{n};
using policy_t = typename host_vector<ElementType>::container_policy_type;
using policy_t = typename host_vector<ElementType, LayoutPolicy>::container_policy_type;
policy_t policy;
return host_vector<ElementType>{extents, policy};
return host_vector<ElementType, LayoutPolicy>{extents, policy};
}

/**
Expand All @@ -626,13 +739,13 @@ auto make_host_vector(size_t n)
* @param[in] stream the cuda stream for ordering events
* @return raft::device_vector
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector(size_t n, rmm::cuda_stream_view stream)
{
detail::vector_extent extents{n};
using policy_t = typename device_vector<ElementType>::container_policy_type;
using policy_t = typename device_vector<ElementType, LayoutPolicy>::container_policy_type;
policy_t policy{stream};
return device_vector<ElementType>{extents, policy};
return device_vector<ElementType, LayoutPolicy>{extents, policy};
}

/**
Expand All @@ -642,9 +755,92 @@ auto make_device_vector(size_t n, rmm::cuda_stream_view stream)
* @param[in] n number of elements in vector
* @return raft::device_vector
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector(raft::handle_t const& handle, size_t n)
{
return make_device_vector<ElementType>(n, handle.get_stream());
return make_device_vector<ElementType, LayoutPolicy>(n, handle.get_stream());
}

/**
* @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view
*
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan
* @param mds raft::host_mdspan or raft::device_mdspan object
* @return raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on AccessoryPolicy
*/
template <typename mdspan_type, std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
auto flatten(mdspan_type mds)
{
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous.");

detail::vector_extent ext{mds.size()};

return detail::stdex::mdspan<typename mdspan_type::element_type,
detail::vector_extent,
typename mdspan_type::layout_type,
typename mdspan_type::accessor_type>(mds.data(), ext);
}

/**
* @brief Flatten object implementing raft::array_interface into a 1-dim array view
*
* @tparam array_interface_type Expected type implementing raft::array_interface
* @param mda raft::array_interace implementing object
* @return Either raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on the underlying ContainerPolicy
*/
template <typename array_interface_type,
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr>
auto flatten(const array_interface_type& mda)
{
return flatten(mda.view());
}

/**
* @brief Reshape raft::host_mdspan or raft::device_mdspan
*
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan
* @tparam Extents raft::extents for dimensions
* @param mds raft::host_mdspan or raft::device_mdspan object
* @param new_shape Desired new shape of the input
* @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy
*/
template <typename mdspan_type,
size_t... Extents,
std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
auto reshape(mdspan_type mds, extents<Extents...> new_shape)
{
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous.");

size_t new_size = 1;
for (size_t i = 0; i < new_shape.rank(); ++i) {
new_size *= new_shape.extent(i);
}
RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch");

return detail::stdex::mdspan<typename mdspan_type::element_type,
decltype(new_shape),
typename mdspan_type::layout_type,
typename mdspan_type::accessor_type>(mds.data(), new_shape);
}

/**
* @brief Reshape object implementing raft::array_interface
*
* @tparam array_interface_type Expected type implementing raft::array_interface
* @tparam Extents raft::extents for dimensions
* @param mda raft::array_interace implementing object
* @param new_shape Desired new shape of the input
* @return raft::host_mdspan or raft::device_mdspan, depending on the underlying
* ContainerPolicy
*/
template <typename array_interface_type,
size_t... Extents,
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr>
auto reshape(const array_interface_type& mda, extents<Extents...> new_shape)
Copy link
Member

@trivialfis trivialfis Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this function accepts (or will potentially accept in the future) an owning type, and might change the size of the underlying memory buffer, one might want to accept a raft handle as a parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are going to be explicit about size changes with a different function called resize. Either way, wouldn't that have to be a member function to be able to access the underlying buffer?

{
return reshape(mda.view(), new_shape);
}

} // namespace raft
3 changes: 2 additions & 1 deletion cpp/include/raft/detail/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#pragma once
#include <experimental/mdspan>
#include <raft/cudart_utils.h>
#include <raft/detail/span.hpp> // dynamic_extent
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -58,7 +59,7 @@ class device_reference {
auto operator=(T const& other) -> device_reference&
{
auto* raw = ptr_.get();
update_device(raw, &other, 1, stream_);
raft::update_device(raw, &other, 1, stream_);
return *this;
}
};
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void distance(raft::handle_t const& handle,
RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous.");
RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous.");

auto is_rowmajor = std::is_same<layout, layout_c_contiguous>::value;
constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;

distance<distanceType, InType, AccType, OutType, Index_>(x.data(),
y.data(),
Expand Down Expand Up @@ -433,7 +433,7 @@ void pairwise_distance(raft::handle_t const& handle,
RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous.");
RAFT_EXPECTS(dist.is_contiguous(), "Output must be contiguous.");

bool rowmajor = x.stride(0) == 0;
constexpr auto rowmajor = std::is_same_v<layout, layout_c_contiguous>;

rmm::device_uvector<char> workspace(0, handle.get_stream());

Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ add_executable(test_raft
test/matrix/columnSort.cu
test/matrix/linewise_op.cu
test/mdarray.cu
test/mdspan_utils.cu
test/mst.cu
test/random/make_blobs.cu
test/random/make_regression.cu
Expand Down
Loading