Skip to content

Commit

Permalink
Add several type aliases and helpers for creating mdarrays (#726)
Browse files Browse the repository at this point in the history
A few small improvements to mdarrays migrated from #652 :

  - Expose more type aliases for extents and layouts
  - Add generic `make_device_mdarray` and `make_host_mdarray`
  - Allow passing an rmm memory resource when creating a container policy (e.g. for creating arrays in managed memory).

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #726
  • Loading branch information
achirkin authored Jul 5, 2022
1 parent fd17011 commit fba595d
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 44 deletions.
174 changes: 132 additions & 42 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
#include <raft/core/handle.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/detail/mdarray.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft {
/**
Expand All @@ -37,14 +39,40 @@ template <size_t... ExtentsPack>
using extents = std::experimental::extents<ExtentsPack...>;

/**
* @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory.
* @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory.
* @{
*/
using layout_c_contiguous = detail::stdex::layout_right;
using detail::stdex::layout_right;
using layout_c_contiguous = layout_right;
using row_major = layout_right;
/** @} */

/**
* @\brief F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory.
* @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory.
* @{
*/
using layout_f_contiguous = detail::stdex::layout_left;
using detail::stdex::layout_left;
using layout_f_contiguous = layout_left;
using col_major = layout_left;
/** @} */

/**
* @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension
* is known at run time (dynamic_extent in each dimension).
* @{
*/
using detail::matrix_extent;
using detail::scalar_extent;
using detail::vector_extent;

using extent_1d = vector_extent;
using extent_2d = matrix_extent;
using extent_3d = detail::stdex::extents<dynamic_extent, dynamic_extent, dynamic_extent>;
using extent_4d =
detail::stdex::extents<dynamic_extent, dynamic_extent, dynamic_extent, dynamic_extent>;
using extent_5d = detail::stdex::
extents<dynamic_extent, dynamic_extent, dynamic_extent, dynamic_extent, dynamic_extent>;
/** @} */

template <typename ElementType,
typename Extents,
Expand Down Expand Up @@ -511,72 +539,72 @@ using device_mdarray =
* @tparam ElementType the data type of the scalar element
*/
template <typename ElementType>
using host_scalar = host_mdarray<ElementType, detail::scalar_extent>;
using host_scalar = host_mdarray<ElementType, scalar_extent>;

/**
* @brief Shorthand for 0-dim host mdarray (scalar).
* @tparam ElementType the data type of the scalar element
*/
template <typename ElementType>
using device_scalar = device_mdarray<ElementType, detail::scalar_extent>;
using device_scalar = device_mdarray<ElementType, scalar_extent>;

/**
* @brief Shorthand for 1-dim host mdarray.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_vector = host_mdarray<ElementType, detail::vector_extent, LayoutPolicy>;
using host_vector = host_mdarray<ElementType, vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for 1-dim device mdarray.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_vector = device_mdarray<ElementType, detail::vector_extent, LayoutPolicy>;
using device_vector = device_mdarray<ElementType, vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous host matrix.
* @tparam ElementType the data type of the matrix elements
* @tparam LayoutPolicy policy for strides and layout ordering
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_matrix = host_mdarray<ElementType, detail::matrix_extent, LayoutPolicy>;
using host_matrix = host_mdarray<ElementType, matrix_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous device matrix.
* @tparam ElementType the data type of the matrix elements
* @tparam LayoutPolicy policy for strides and layout ordering
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_matrix = device_mdarray<ElementType, detail::matrix_extent, LayoutPolicy>;
using device_matrix = device_mdarray<ElementType, matrix_extent, LayoutPolicy>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
* @tparam ElementType the data type of the scalar element
*/
template <typename ElementType>
using host_scalar_view = host_mdspan<ElementType, detail::scalar_extent>;
using host_scalar_view = host_mdspan<ElementType, scalar_extent>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
* @tparam ElementType the data type of the scalar element
*/
template <typename ElementType>
using device_scalar_view = device_mdspan<ElementType, detail::scalar_extent>;
using device_scalar_view = device_mdspan<ElementType, scalar_extent>;

/**
* @brief Shorthand for 1-dim host mdspan.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_vector_view = host_mdspan<ElementType, detail::vector_extent, LayoutPolicy>;
using host_vector_view = host_mdspan<ElementType, vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for 1-dim device mdspan.
* @tparam ElementType the data type of the vector elements
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_vector_view = device_mdspan<ElementType, detail::vector_extent, LayoutPolicy>;
using device_vector_view = device_mdspan<ElementType, vector_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous host matrix view.
Expand All @@ -585,7 +613,7 @@ using device_vector_view = device_mdspan<ElementType, detail::vector_extent, Lay
*
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using host_matrix_view = host_mdspan<ElementType, detail::matrix_extent, LayoutPolicy>;
using host_matrix_view = host_mdspan<ElementType, matrix_extent, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous device matrix view.
Expand All @@ -594,7 +622,7 @@ using host_matrix_view = host_mdspan<ElementType, detail::matrix_extent, LayoutP
*
*/
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
using device_matrix_view = device_mdspan<ElementType, detail::matrix_extent, LayoutPolicy>;
using device_matrix_view = device_mdspan<ElementType, matrix_extent, LayoutPolicy>;

/**
* @brief Create a 0-dim (scalar) mdspan instance for host value.
Expand All @@ -605,7 +633,7 @@ using device_matrix_view = device_mdspan<ElementType, detail::matrix_extent, Lay
template <typename ElementType>
auto make_host_scalar_view(ElementType* ptr)
{
detail::scalar_extent extents;
scalar_extent extents;
return host_scalar_view<ElementType>{ptr, extents};
}

Expand All @@ -618,7 +646,7 @@ auto make_host_scalar_view(ElementType* ptr)
template <typename ElementType>
auto make_device_scalar_view(ElementType* ptr)
{
detail::scalar_extent extents;
scalar_extent extents;
return device_scalar_view<ElementType>{ptr, extents};
}

Expand All @@ -635,7 +663,7 @@ auto make_device_scalar_view(ElementType* ptr)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols)
{
detail::matrix_extent extents{n_rows, n_cols};
matrix_extent extents{n_rows, n_cols};
return host_matrix_view<ElementType, LayoutPolicy>{ptr, extents};
}
/**
Expand All @@ -651,7 +679,7 @@ auto make_host_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols)
{
detail::matrix_extent extents{n_rows, n_cols};
matrix_extent extents{n_rows, n_cols};
return device_matrix_view<ElementType, LayoutPolicy>{ptr, extents};
}

Expand All @@ -665,7 +693,7 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector_view(ElementType* ptr, size_t n)
{
detail::vector_extent extents{n};
vector_extent extents{n};
return host_vector_view<ElementType, LayoutPolicy>{ptr, extents};
}

Expand All @@ -679,10 +707,84 @@ auto make_host_vector_view(ElementType* ptr, size_t n)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, size_t n)
{
detail::vector_extent extents{n};
vector_extent extents{n};
return device_vector_view<ElementType, LayoutPolicy>{ptr, extents};
}

/**
* @brief Create a host mdarray.
* @tparam ElementType the data type of the matrix elements
* @tparam LayoutPolicy policy for strides and layout ordering
* @param exts dimensionality of the array (series of integers)
* @return raft::host_mdarray
*/
template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
typename... Extents,
typename = detail::ensure_integral_extents<Extents...>>
auto make_host_mdarray(Extents... exts)
{
using extent_t = extents<((void)exts, dynamic_extent)...>;
using mdarray_t = host_mdarray<ElementType, extent_t, LayoutPolicy>;

typename mdarray_t::extents_type extent{exts...};
typename mdarray_t::mapping_type layout{extent};
typename mdarray_t::container_policy_type policy;

return mdarray_t{layout, policy};
}

/**
* @brief Create a device mdarray.
* @tparam ElementType the data type of the matrix elements
* @tparam LayoutPolicy policy for strides and layout ordering
* @param stream cuda stream for ordering events
* @param exts dimensionality of the array (series of integers)
* @return raft::device_mdarray
*/
template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
typename... Extents,
typename = detail::ensure_integral_extents<Extents...>>
auto make_device_mdarray(rmm::cuda_stream_view stream, Extents... exts)
{
using extent_t = extents<((void)exts, dynamic_extent)...>;
using mdarray_t = device_mdarray<ElementType, extent_t, LayoutPolicy>;

typename mdarray_t::extents_type extent{exts...};
typename mdarray_t::mapping_type layout{extent};
typename mdarray_t::container_policy_type policy{stream};

return mdarray_t{layout, policy};
}

/**
* @brief Create a device mdarray.
* @tparam ElementType the data type of the matrix elements
* @tparam LayoutPolicy policy for strides and layout ordering
* @param stream cuda stream for ordering events
* @param mr rmm memory resource used for allocating the memory for the array
* @param exts dimensionality of the array (series of integers)
* @return raft::device_mdarray
*/
template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
typename... Extents,
typename = detail::ensure_integral_extents<Extents...>>
auto make_device_mdarray(rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr,
Extents... exts)
{
using extent_t = extents<((void)exts, dynamic_extent)...>;
using mdarray_t = device_mdarray<ElementType, extent_t, LayoutPolicy>;

typename mdarray_t::extents_type extent{exts...};
typename mdarray_t::mapping_type layout{extent};
typename mdarray_t::container_policy_type policy{stream, mr};

return mdarray_t{layout, policy};
}

/**
* @brief Create a 2-dim c-contiguous host mdarray.
* @tparam ElementType the data type of the matrix elements
Expand All @@ -694,10 +796,7 @@ auto make_device_vector_view(ElementType* ptr, size_t n)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_matrix(size_t n_rows, size_t n_cols)
{
detail::matrix_extent extents{n_rows, n_cols};
using policy_t = typename host_matrix<ElementType>::container_policy_type;
policy_t policy;
return host_matrix<ElementType, LayoutPolicy>{extents, policy};
return make_host_mdarray<ElementType, LayoutPolicy>(n_rows, n_cols);
}

/**
Expand All @@ -712,10 +811,7 @@ auto make_host_matrix(size_t n_rows, size_t n_cols)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_matrix(size_t n_rows, size_t n_cols, rmm::cuda_stream_view stream)
{
detail::matrix_extent extents{n_rows, n_cols};
using policy_t = typename device_matrix<ElementType>::container_policy_type;
policy_t policy{stream};
return device_matrix<ElementType, LayoutPolicy>{extents, policy};
return make_device_mdarray<ElementType, LayoutPolicy>(stream, n_rows, n_cols);
}

/**
Expand Down Expand Up @@ -747,7 +843,7 @@ auto make_host_scalar(ElementType const& v)
// FIXME(jiamingy): We can optimize this by using std::array as container policy, which
// requires some more compile time dispatching. This is enabled in the ref impl but
// hasn't been ported here yet.
detail::scalar_extent extents;
scalar_extent extents;
using policy_t = typename host_scalar<ElementType>::container_policy_type;
policy_t policy;
auto scalar = host_scalar<ElementType>{extents, policy};
Expand All @@ -766,7 +862,7 @@ auto make_host_scalar(ElementType const& v)
template <typename ElementType>
auto make_device_scalar(ElementType const& v, rmm::cuda_stream_view stream)
{
detail::scalar_extent extents;
scalar_extent extents;
using policy_t = typename device_scalar<ElementType>::container_policy_type;
policy_t policy{stream};
auto scalar = device_scalar<ElementType>{extents, policy};
Expand Down Expand Up @@ -797,10 +893,7 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_host_vector(size_t n)
{
detail::vector_extent extents{n};
using policy_t = typename host_vector<ElementType, LayoutPolicy>::container_policy_type;
policy_t policy;
return host_vector<ElementType, LayoutPolicy>{extents, policy};
return make_host_mdarray<ElementType, LayoutPolicy>(n);
}

/**
Expand All @@ -813,10 +906,7 @@ auto make_host_vector(size_t n)
template <typename ElementType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector(size_t n, rmm::cuda_stream_view stream)
{
detail::vector_extent extents{n};
using policy_t = typename device_vector<ElementType, LayoutPolicy>::container_policy_type;
policy_t policy{stream};
return device_vector<ElementType, LayoutPolicy>{extents, policy};
return make_device_mdarray<ElementType, LayoutPolicy>(stream, n);
}

/**
Expand Down Expand Up @@ -845,10 +935,10 @@ auto flatten(mdspan_type mds)
{
RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous.");

detail::vector_extent ext{mds.size()};
vector_extent ext{mds.size()};

return detail::stdex::mdspan<typename mdspan_type::element_type,
detail::vector_extent,
vector_extent,
typename mdspan_type::layout_type,
typename mdspan_type::accessor_type>(mds.data(), ext);
}
Expand Down
Loading

0 comments on commit fba595d

Please sign in to comment.