Skip to content

Commit

Permalink
mdspan/mdarray template functions and utilities (#601)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - William Hicks (https://github.com/wphicks)
  - Jiaming Yuan (https://github.com/trivialfis)

URL: #601
  • Loading branch information
divyegala authored Apr 28, 2022
1 parent 2a5934f commit 924c245
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 24 deletions.
238 changes: 217 additions & 21 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,21 @@
* limitations under the License.
*/
#pragma once

#include <stddef.h>

#include <experimental/mdspan>
#include <raft/core/handle.hpp>
#include <raft/detail/mdarray.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace raft {
/**
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
*/
template <size_t... ExtentsPack>
using extents = std::experimental::extents<ExtentsPack...>;

/**
* @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory.
*/
Expand All @@ -37,6 +46,30 @@ using layout_c_contiguous = detail::stdex::layout_right;
*/
using layout_f_contiguous = detail::stdex::layout_left;

/**
* @\brief Template checks and helpers to determine if type T is an std::mdspan
* or a derived type
*/

template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void __takes_an_mdspan_ptr(
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*);

template <typename T, typename = void>
struct __is_mdspan : std::false_type {
};

template <typename T>
struct __is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::true_type {
};

template <typename T>
using __is_mdspan_t = __is_mdspan<std::remove_const_t<T>>;

template <typename T>
inline constexpr bool __is_mdspan_v = __is_mdspan_t<T>::value;

/**
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location.
*/
Expand All @@ -57,6 +90,85 @@ template <typename ElementType,
using host_mdspan =
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;

template <typename T, bool B>
struct __is_device_mdspan : std::false_type {
};

template <typename T>
struct __is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
};

/**
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_device_mdspan_v = __is_device_mdspan<T, __is_mdspan_v<T>>::value;

template <typename T, bool B>
struct __is_host_mdspan : std::false_type {
};

template <typename T>
struct __is_host_mdspan<T, true> : T::accessor_type::is_host_type {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_host_mdspan_v = __is_host_mdspan<T, __is_mdspan_v<T>>::value;

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan
* or their derived types
* This is structured such that it will short-circuit if the type is not std::mdspan
* or a derived type, and otherwise it will check whether it is a raft::device_mdspan
* or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_mdspan_v =
std::conjunction_v<__is_mdspan_t<T>,
std::disjunction<__is_device_mdspan<T, true>, __is_host_mdspan<T, true>>>;

/**
* @brief Interface to implement an owning multi-dimensional array
*
* raft::array_interace is an interface to owning container types for mdspan.
* Check implementation of raft::mdarray which implements raft::array_interface
* using Curiously Recurring Template Pattern.
* This interface calls into method `view()` whose implementation is provided by
* the implementing class. `view()` must return an object of type raft::host_mdspan
* or raft::device_mdspan or any types derived from the them.
*/
template <typename Base>
class array_interface {
/**
* @brief Get a mdspan that can be passed down to CUDA kernels.
*/
auto view() noexcept { return static_cast<Base*>(this)->view(); }
/**
* @brief Get a mdspan that can be passed down to CUDA kernels.
*/
auto view() const noexcept { return static_cast<Base*>(this)->view(); }
};

template <typename T, typename = void>
struct __is_array_interface : std::false_type {
};

template <typename T>
struct __is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>>
: std::bool_constant<is_mdspan_v<decltype(std::declval<T>().view())>> {
};

/**
* @\brief Boolean to determine if template type T is raft::array_interface or derived type
* or any type that has a member function `view()` that returns either
* raft::host_mdspan or raft::device_mdspan
*/
template <typename T>
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value;

/**
* @brief Modified from the c++ mdarray proposal
*
Expand Down Expand Up @@ -87,7 +199,8 @@ using host_mdspan =
* removed.
*/
template <typename ElementType, typename Extents, typename LayoutPolicy, typename ContainerPolicy>
class mdarray {
class mdarray
: public array_interface<mdarray<ElementType, Extents, LayoutPolicy, ContainerPolicy>> {
static_assert(!std::is_const<ElementType>::value,
"Element type for container must not be const.");

Expand Down Expand Up @@ -340,15 +453,15 @@ using device_scalar = device_mdarray<ElementType, detail::scalar_extent>;
* @brief Shorthand for 1-dim host mdarray.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using host_vector = host_mdarray<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_vector = host_mdarray<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for 1-dim device mdarray.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using device_vector = device_mdarray<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_vector = device_mdarray<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous host matrix.
Expand Down Expand Up @@ -384,15 +497,15 @@ using device_scalar_view = device_mdspan<ElementType, detail::scalar_extent>;
* @brief Shorthand for 1-dim host mdspan.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using host_vector_view = host_mdspan<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_vector_view = host_mdspan<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for 1-dim device mdspan.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType>
using device_vector_view = device_mdspan<ElementType, detail::vector_extent>;
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_vector_view = device_mdspan<ElementType, detail::vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous host matrix view.
Expand Down Expand Up @@ -478,11 +591,11 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols)
* @param[in] n number of elements in pointer
* @return raft::host_vector_view
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector_view(ElementType* ptr, size_t n)
{
detail::vector_extent extents{n};
return host_vector_view<ElementType>{ptr, extents};
return host_vector_view<ElementType, LayoutPolicy>{ptr, extents};
}

/**
Expand All @@ -492,11 +605,11 @@ auto make_host_vector_view(ElementType* ptr, size_t n)
* @param[in] n number of elements in pointer
* @return raft::device_vector_view
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, size_t n)
{
detail::vector_extent extents{n};
return device_vector_view<ElementType>{ptr, extents};
return device_vector_view<ElementType, LayoutPolicy>{ptr, extents};
}

/**
Expand Down Expand Up @@ -610,13 +723,13 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v)
* @param[in] n number of elements in vector
* @return raft::host_vector
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector(size_t n)
{
detail::vector_extent extents{n};
using policy_t = typename host_vector<ElementType>::container_policy_type;
using policy_t = typename host_vector<ElementType, LayoutPolicy>::container_policy_type;
policy_t policy;
return host_vector<ElementType>{extents, policy};
return host_vector<ElementType, LayoutPolicy>{extents, policy};
}

/**
Expand All @@ -626,13 +739,13 @@ auto make_host_vector(size_t n)
* @param[in] stream the cuda stream for ordering events
* @return raft::device_vector
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector(size_t n, rmm::cuda_stream_view stream)
{
detail::vector_extent extents{n};
using policy_t = typename device_vector<ElementType>::container_policy_type;
using policy_t = typename device_vector<ElementType, LayoutPolicy>::container_policy_type;
policy_t policy{stream};
return device_vector<ElementType>{extents, policy};
return device_vector<ElementType, LayoutPolicy>{extents, policy};
}

/**
Expand All @@ -642,9 +755,92 @@ auto make_device_vector(size_t n, rmm::cuda_stream_view stream)
* @param[in] n number of elements in vector
* @return raft::device_vector
*/
template <typename ElementType>
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector(raft::handle_t const& handle, size_t n)
{
return make_device_vector<ElementType>(n, handle.get_stream());
return make_device_vector<ElementType, LayoutPolicy>(n, handle.get_stream());
}

/**
* @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view
*
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan
* @param mds raft::host_mdspan or raft::device_mdspan object
* @return raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on AccessoryPolicy
*/
template <typename mdspan_type, std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
auto flatten(mdspan_type mds)
{
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous.");

detail::vector_extent ext{mds.size()};

return detail::stdex::mdspan<typename mdspan_type::element_type,
detail::vector_extent,
typename mdspan_type::layout_type,
typename mdspan_type::accessor_type>(mds.data(), ext);
}

/**
* @brief Flatten object implementing raft::array_interface into a 1-dim array view
*
* @tparam array_interface_type Expected type implementing raft::array_interface
* @param mda raft::array_interace implementing object
* @return Either raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on the underlying ContainerPolicy
*/
template <typename array_interface_type,
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr>
auto flatten(const array_interface_type& mda)
{
return flatten(mda.view());
}

/**
* @brief Reshape raft::host_mdspan or raft::device_mdspan
*
* @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan
* @tparam Extents raft::extents for dimensions
* @param mds raft::host_mdspan or raft::device_mdspan object
* @param new_shape Desired new shape of the input
* @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy
*/
template <typename mdspan_type,
size_t... Extents,
std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
auto reshape(mdspan_type mds, extents<Extents...> new_shape)
{
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous.");

size_t new_size = 1;
for (size_t i = 0; i < new_shape.rank(); ++i) {
new_size *= new_shape.extent(i);
}
RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch");

return detail::stdex::mdspan<typename mdspan_type::element_type,
decltype(new_shape),
typename mdspan_type::layout_type,
typename mdspan_type::accessor_type>(mds.data(), new_shape);
}

/**
* @brief Reshape object implementing raft::array_interface
*
* @tparam array_interface_type Expected type implementing raft::array_interface
* @tparam Extents raft::extents for dimensions
* @param mda raft::array_interace implementing object
* @param new_shape Desired new shape of the input
* @return raft::host_mdspan or raft::device_mdspan, depending on the underlying
* ContainerPolicy
*/
template <typename array_interface_type,
size_t... Extents,
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr>
auto reshape(const array_interface_type& mda, extents<Extents...> new_shape)
{
return reshape(mda.view(), new_shape);
}

} // namespace raft
3 changes: 2 additions & 1 deletion cpp/include/raft/detail/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#pragma once
#include <experimental/mdspan>
#include <raft/cudart_utils.h>
#include <raft/detail/span.hpp> // dynamic_extent
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -58,7 +59,7 @@ class device_reference {
auto operator=(T const& other) -> device_reference&
{
auto* raw = ptr_.get();
update_device(raw, &other, 1, stream_);
raft::update_device(raw, &other, 1, stream_);
return *this;
}
};
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void distance(raft::handle_t const& handle,
RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous.");
RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous.");

auto is_rowmajor = std::is_same<layout, layout_c_contiguous>::value;
constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;

distance<distanceType, InType, AccType, OutType, Index_>(x.data(),
y.data(),
Expand Down Expand Up @@ -433,7 +433,7 @@ void pairwise_distance(raft::handle_t const& handle,
RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous.");
RAFT_EXPECTS(dist.is_contiguous(), "Output must be contiguous.");

bool rowmajor = x.stride(0) == 0;
constexpr auto rowmajor = std::is_same_v<layout, layout_c_contiguous>;

rmm::device_uvector<char> workspace(0, handle.get_stream());

Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ add_executable(test_raft
test/matrix/columnSort.cu
test/matrix/linewise_op.cu
test/mdarray.cu
test/mdspan_utils.cu
test/mst.cu
test/random/make_blobs.cu
test/random/make_regression.cu
Expand Down
Loading

0 comments on commit 924c245

Please sign in to comment.