diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 4465de21e7..a4f6ca67b1 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -27,7 +27,9 @@ #include #include #include + #include +#include namespace raft { /** @@ -37,14 +39,40 @@ template using extents = std::experimental::extents; /** - * @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. + * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. + * @{ */ -using layout_c_contiguous = detail::stdex::layout_right; +using detail::stdex::layout_right; +using layout_c_contiguous = layout_right; +using row_major = layout_right; +/** @} */ /** - * @\brief F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. + * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. + * @{ */ -using layout_f_contiguous = detail::stdex::layout_left; +using detail::stdex::layout_left; +using layout_f_contiguous = layout_left; +using col_major = layout_left; +/** @} */ + +/** + * @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension + * is known at run time (dynamic_extent in each dimension). + * @{ + */ +using detail::matrix_extent; +using detail::scalar_extent; +using detail::vector_extent; + +using extent_1d = vector_extent; +using extent_2d = matrix_extent; +using extent_3d = detail::stdex::extents; +using extent_4d = + detail::stdex::extents; +using extent_5d = detail::stdex:: + extents; +/** @} */ template -using host_scalar = host_mdarray; +using host_scalar = host_mdarray; /** * @brief Shorthand for 0-dim host mdarray (scalar). * @tparam ElementType the data type of the scalar element */ template -using device_scalar = device_mdarray; +using device_scalar = device_mdarray; /** * @brief Shorthand for 1-dim host mdarray. * @tparam ElementType the data type of the vector elements */ template -using host_vector = host_mdarray; +using host_vector = host_mdarray; /** * @brief Shorthand for 1-dim device mdarray. * @tparam ElementType the data type of the vector elements */ template -using device_vector = device_mdarray; +using device_vector = device_mdarray; /** * @brief Shorthand for c-contiguous host matrix. @@ -540,7 +568,7 @@ using device_vector = device_mdarray -using host_matrix = host_mdarray; +using host_matrix = host_mdarray; /** * @brief Shorthand for c-contiguous device matrix. @@ -548,35 +576,35 @@ using host_matrix = host_mdarray -using device_matrix = device_mdarray; +using device_matrix = device_mdarray; /** * @brief Shorthand for 0-dim host mdspan (scalar). * @tparam ElementType the data type of the scalar element */ template -using host_scalar_view = host_mdspan; +using host_scalar_view = host_mdspan; /** * @brief Shorthand for 0-dim host mdspan (scalar). * @tparam ElementType the data type of the scalar element */ template -using device_scalar_view = device_mdspan; +using device_scalar_view = device_mdspan; /** * @brief Shorthand for 1-dim host mdspan. * @tparam ElementType the data type of the vector elements */ template -using host_vector_view = host_mdspan; +using host_vector_view = host_mdspan; /** * @brief Shorthand for 1-dim device mdspan. * @tparam ElementType the data type of the vector elements */ template -using device_vector_view = device_mdspan; +using device_vector_view = device_mdspan; /** * @brief Shorthand for c-contiguous host matrix view. @@ -585,7 +613,7 @@ using device_vector_view = device_mdspan -using host_matrix_view = host_mdspan; +using host_matrix_view = host_mdspan; /** * @brief Shorthand for c-contiguous device matrix view. @@ -594,7 +622,7 @@ using host_matrix_view = host_mdspan -using device_matrix_view = device_mdspan; +using device_matrix_view = device_mdspan; /** * @brief Create a 0-dim (scalar) mdspan instance for host value. @@ -605,7 +633,7 @@ using device_matrix_view = device_mdspan auto make_host_scalar_view(ElementType* ptr) { - detail::scalar_extent extents; + scalar_extent extents; return host_scalar_view{ptr, extents}; } @@ -618,7 +646,7 @@ auto make_host_scalar_view(ElementType* ptr) template auto make_device_scalar_view(ElementType* ptr) { - detail::scalar_extent extents; + scalar_extent extents; return device_scalar_view{ptr, extents}; } @@ -635,7 +663,7 @@ auto make_device_scalar_view(ElementType* ptr) template auto make_host_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) { - detail::matrix_extent extents{n_rows, n_cols}; + matrix_extent extents{n_rows, n_cols}; return host_matrix_view{ptr, extents}; } /** @@ -651,7 +679,7 @@ auto make_host_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) template auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) { - detail::matrix_extent extents{n_rows, n_cols}; + matrix_extent extents{n_rows, n_cols}; return device_matrix_view{ptr, extents}; } @@ -665,7 +693,7 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) template auto make_host_vector_view(ElementType* ptr, size_t n) { - detail::vector_extent extents{n}; + vector_extent extents{n}; return host_vector_view{ptr, extents}; } @@ -679,10 +707,84 @@ auto make_host_vector_view(ElementType* ptr, size_t n) template auto make_device_vector_view(ElementType* ptr, size_t n) { - detail::vector_extent extents{n}; + vector_extent extents{n}; return device_vector_view{ptr, extents}; } +/** + * @brief Create a host mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param exts dimensionality of the array (series of integers) + * @return raft::host_mdarray + */ +template > +auto make_host_mdarray(Extents... exts) +{ + using extent_t = extents<((void)exts, dynamic_extent)...>; + using mdarray_t = host_mdarray; + + typename mdarray_t::extents_type extent{exts...}; + typename mdarray_t::mapping_type layout{extent}; + typename mdarray_t::container_policy_type policy; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a device mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param stream cuda stream for ordering events + * @param exts dimensionality of the array (series of integers) + * @return raft::device_mdarray + */ +template > +auto make_device_mdarray(rmm::cuda_stream_view stream, Extents... exts) +{ + using extent_t = extents<((void)exts, dynamic_extent)...>; + using mdarray_t = device_mdarray; + + typename mdarray_t::extents_type extent{exts...}; + typename mdarray_t::mapping_type layout{extent}; + typename mdarray_t::container_policy_type policy{stream}; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a device mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param stream cuda stream for ordering events + * @param mr rmm memory resource used for allocating the memory for the array + * @param exts dimensionality of the array (series of integers) + * @return raft::device_mdarray + */ +template > +auto make_device_mdarray(rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr, + Extents... exts) +{ + using extent_t = extents<((void)exts, dynamic_extent)...>; + using mdarray_t = device_mdarray; + + typename mdarray_t::extents_type extent{exts...}; + typename mdarray_t::mapping_type layout{extent}; + typename mdarray_t::container_policy_type policy{stream, mr}; + + return mdarray_t{layout, policy}; +} + /** * @brief Create a 2-dim c-contiguous host mdarray. * @tparam ElementType the data type of the matrix elements @@ -694,10 +796,7 @@ auto make_device_vector_view(ElementType* ptr, size_t n) template auto make_host_matrix(size_t n_rows, size_t n_cols) { - detail::matrix_extent extents{n_rows, n_cols}; - using policy_t = typename host_matrix::container_policy_type; - policy_t policy; - return host_matrix{extents, policy}; + return make_host_mdarray(n_rows, n_cols); } /** @@ -712,10 +811,7 @@ auto make_host_matrix(size_t n_rows, size_t n_cols) template auto make_device_matrix(size_t n_rows, size_t n_cols, rmm::cuda_stream_view stream) { - detail::matrix_extent extents{n_rows, n_cols}; - using policy_t = typename device_matrix::container_policy_type; - policy_t policy{stream}; - return device_matrix{extents, policy}; + return make_device_mdarray(stream, n_rows, n_cols); } /** @@ -747,7 +843,7 @@ auto make_host_scalar(ElementType const& v) // FIXME(jiamingy): We can optimize this by using std::array as container policy, which // requires some more compile time dispatching. This is enabled in the ref impl but // hasn't been ported here yet. - detail::scalar_extent extents; + scalar_extent extents; using policy_t = typename host_scalar::container_policy_type; policy_t policy; auto scalar = host_scalar{extents, policy}; @@ -766,7 +862,7 @@ auto make_host_scalar(ElementType const& v) template auto make_device_scalar(ElementType const& v, rmm::cuda_stream_view stream) { - detail::scalar_extent extents; + scalar_extent extents; using policy_t = typename device_scalar::container_policy_type; policy_t policy{stream}; auto scalar = device_scalar{extents, policy}; @@ -797,10 +893,7 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) template auto make_host_vector(size_t n) { - detail::vector_extent extents{n}; - using policy_t = typename host_vector::container_policy_type; - policy_t policy; - return host_vector{extents, policy}; + return make_host_mdarray(n); } /** @@ -813,10 +906,7 @@ auto make_host_vector(size_t n) template auto make_device_vector(size_t n, rmm::cuda_stream_view stream) { - detail::vector_extent extents{n}; - using policy_t = typename device_vector::container_policy_type; - policy_t policy{stream}; - return device_vector{extents, policy}; + return make_device_mdarray(stream, n); } /** @@ -845,10 +935,10 @@ auto flatten(mdspan_type mds) { RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); - detail::vector_extent ext{mds.size()}; + vector_extent ext{mds.size()}; return detail::stdex::mdspan(mds.data(), ext); } diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index c4557245ae..215487c82f 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -24,8 +24,11 @@ #include #include #include // dynamic_extent + #include #include +#include + #include namespace raft::detail { @@ -138,6 +141,7 @@ class device_uvector { template class device_uvector_policy { rmm::cuda_stream_view stream_; + rmm::mr::device_memory_resource* mr_; public: using element_type = ElementType; @@ -152,12 +156,21 @@ class device_uvector_policy { using const_accessor_policy = std::experimental::default_accessor; public: - auto create(size_t n) -> container_type { return container_type(n, stream_); } + auto create(size_t n) -> container_type + { + return mr_ ? container_type(n, stream_, mr_) : container_type(n, stream_); + } device_uvector_policy() = delete; explicit device_uvector_policy(rmm::cuda_stream_view stream) noexcept( std::is_nothrow_copy_constructible_v) - : stream_{stream} + : stream_{stream}, mr_(nullptr) + { + } + + device_uvector_policy(rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) noexcept( + std::is_nothrow_copy_constructible_v) + : stream_{stream}, mr_(mr) { } @@ -309,4 +322,15 @@ MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents index[0] = idx; return arr_to_tup(index); } + +/** + * Ensure all types listed in the parameter pack `Extents` are integral types. + * Usage: + * put it as the last nameless template parameter of a function: + * `typename = ensure_integral_extents` + */ +template +using ensure_integral_extents = + std::enable_if_t<(true && ... && std::is_integral_v), void>; + } // namespace raft::detail