Updating raft::linalg APIs to use mdspan #809

Merged
52 changes: 45 additions & 7 deletions cpp/include/raft/core/device_mdspan.hpp
@@ -44,18 +44,23 @@ using managed_mdspan = mdspan<ElementType, Extents, LayoutPolicy, managed_accessor<AccessorPolicy>>;

namespace detail {
template <typename T, bool B>
-struct is_device_accessible_mdspan : std::false_type {
+struct is_device_mdspan : std::false_type {
};
template <typename T>
-struct is_device_accessible_mdspan<T, true>
-  : std::bool_constant<T::accessor_type::is_device_accessible> {
+struct is_device_mdspan<T, true> : std::bool_constant<T::accessor_type::is_device_accessible> {
};

/**
 * @brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
-using is_device_accessible_mdspan_t = is_device_accessible_mdspan<T, is_mdspan_v<T>>;
+using is_device_mdspan_t = is_device_mdspan<T, is_mdspan_v<T>>;

template <typename T>
using is_input_device_mdspan_t = is_device_mdspan<T, is_input_mdspan_v<T>>;

template <typename T>
using is_output_device_mdspan_t = is_device_mdspan<T, is_output_mdspan_v<T>>;

template <typename T, bool B>
struct is_managed_mdspan : std::false_type {
@@ -70,18 +75,37 @@ struct is_managed_mdspan<T, true> : std::bool_constant<T::accessor_type::is_managed_accessible> {
template <typename T>
using is_managed_mdspan_t = is_managed_mdspan<T, is_mdspan_v<T>>;

template <typename T>
using is_input_managed_mdspan_t = is_managed_mdspan<T, is_input_mdspan_v<T>>;

template <typename T>
using is_output_managed_mdspan_t = is_managed_mdspan<T, is_output_mdspan_v<T>>;

} // end namespace detail

/**
 * @brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
*/
template <typename... Tn>
-inline constexpr bool is_device_accessible_mdspan_v =
-  std::conjunction_v<detail::is_device_accessible_mdspan_t<Tn>...>;
+inline constexpr bool is_device_mdspan_v = std::conjunction_v<detail::is_device_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_input_device_mdspan_v =
std::conjunction_v<detail::is_input_device_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_output_device_mdspan_v =
std::conjunction_v<detail::is_output_device_mdspan_t<Tn>...>;

template <typename... Tn>
-using enable_if_device_mdspan = std::enable_if_t<is_device_accessible_mdspan_v<Tn...>>;
+using enable_if_device_mdspan = std::enable_if_t<is_device_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_device_mdspan = std::enable_if_t<is_input_device_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_device_mdspan = std::enable_if_t<is_output_device_mdspan_v<Tn...>>;

/**
 * @brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a
@@ -90,9 +114,23 @@ using enable_if_device_mdspan = std::enable_if_t<is_device_accessible_mdspan_v<Tn...>>;
template <typename... Tn>
inline constexpr bool is_managed_mdspan_v = std::conjunction_v<detail::is_managed_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_input_managed_mdspan_v =
std::conjunction_v<detail::is_input_managed_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_output_managed_mdspan_v =
std::conjunction_v<detail::is_output_managed_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_managed_mdspan = std::enable_if_t<is_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_managed_mdspan = std::enable_if_t<is_input_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_managed_mdspan = std::enable_if_t<is_output_managed_mdspan_v<Tn...>>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
* @tparam ElementType the data type of the scalar element
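Together, the new aliases let a linalg API constrain itself at compile time to device-accessible views with the right const-ness. A minimal sketch of the intended SFINAE pattern follows (the function itself is hypothetical, for illustration only):

#include <raft/core/device_mdspan.hpp>

// Participates in overload resolution only when InType is a device mdspan
// over const elements and OutType a device mdspan over mutable elements;
// any other argument types silently discard this template.
template <typename InType,
          typename OutType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void scale(InType in, OutType out);  // hypothetical API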
32 changes: 25 additions & 7 deletions cpp/include/raft/core/host_mdspan.hpp
@@ -37,18 +37,23 @@ using host_mdspan = mdspan<ElementType, Extents, LayoutPolicy, host_accessor<AccessorPolicy>>;
namespace detail {

template <typename T, bool B>
-struct is_host_accessible_mdspan : std::false_type {
+struct is_host_mdspan : std::false_type {
};
template <typename T>
-struct is_host_accessible_mdspan<T, true>
-  : std::bool_constant<T::accessor_type::is_host_accessible> {
+struct is_host_mdspan<T, true> : std::bool_constant<T::accessor_type::is_host_accessible> {
};

/**
 * @brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
-using is_host_accessible_mdspan_t = is_host_accessible_mdspan<T, is_mdspan_v<T>>;
+using is_host_mdspan_t = is_host_mdspan<T, is_mdspan_v<T>>;

template <typename T>
using is_input_host_mdspan_t = is_host_mdspan<T, is_input_mdspan_v<T>>;

template <typename T>
using is_output_host_mdspan_t = is_host_mdspan<T, is_output_mdspan_v<T>>;

} // namespace detail

@@ -57,11 +62,24 @@ using is_host_accessible_mdspan_t = is_host_accessible_mdspan<T, is_mdspan_v<T>>;
* derived type
*/
template <typename... Tn>
-inline constexpr bool is_host_accessible_mdspan_v =
-  std::conjunction_v<detail::is_host_accessible_mdspan_t<Tn>...>;
+inline constexpr bool is_host_mdspan_v = std::conjunction_v<detail::is_host_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_input_host_mdspan_v =
std::conjunction_v<detail::is_input_host_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_output_host_mdspan_v =
std::conjunction_v<detail::is_output_host_mdspan_t<Tn>...>;

+template <typename... Tn>
+using enable_if_host_mdspan = std::enable_if_t<is_host_mdspan_v<Tn...>>;
+
+template <typename... Tn>
+using enable_if_input_host_mdspan = std::enable_if_t<is_input_host_mdspan_v<Tn...>>;
+
template <typename... Tn>
-using enable_if_host_mdspan = std::enable_if_t<is_host_accessible_mdspan_v<Tn...>>;
+using enable_if_output_host_mdspan = std::enable_if_t<is_output_host_mdspan_v<Tn...>>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
34 changes: 34 additions & 0 deletions cpp/include/raft/core/mdspan.hpp
@@ -57,9 +57,31 @@ struct is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::true_type {
};

template <typename T, typename = void>
struct is_input_mdspan : std::false_type {
};
template <typename T>
struct is_input_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::bool_constant<std::is_const_v<typename T::element_type>> {
};

template <typename T, typename = void>
struct is_output_mdspan : std::false_type {
};
template <typename T>
struct is_output_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::bool_constant<not std::is_const_v<typename T::element_type>> {
};

template <typename T>
using is_mdspan_t = is_mdspan<std::remove_const_t<T>>;

template <typename T>
using is_input_mdspan_t = is_input_mdspan<T>;

template <typename T>
using is_output_mdspan_t = is_output_mdspan<T>;

/**
 * @brief Boolean to determine if variadic template types Tn are either
* raft::host_mdspan/raft::device_mdspan or their derived types
@@ -70,6 +92,18 @@ inline constexpr bool is_mdspan_v = std::conjunction_v<is_mdspan_t<Tn>...>;
template <typename... Tn>
using enable_if_mdspan = std::enable_if_t<is_mdspan_v<Tn...>>;

template <typename... Tn>
inline constexpr bool is_input_mdspan_v = std::conjunction_v<is_input_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_input_mdspan = std::enable_if_t<is_input_mdspan_v<Tn...>>;

template <typename... Tn>
inline constexpr bool is_output_mdspan_v = std::conjunction_v<is_output_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_output_mdspan = std::enable_if_t<is_output_mdspan_v<Tn...>>;
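The split above keys entirely off element_type const-ness: a view of const elements counts as an input, a view of mutable elements as an output. A quick sketch of how the traits resolve (assuming raft::device_vector_view from device_mdspan.hpp; illustrative, not part of the diff):

#include <raft/core/device_mdspan.hpp>

using in_view  = raft::device_vector_view<const float, int>;  // const elements
using out_view = raft::device_vector_view<float, int>;        // mutable elements

static_assert(raft::is_input_mdspan_v<in_view>);         // const     => input
static_assert(raft::is_output_mdspan_v<out_view>);       // non-const => output
static_assert(!raft::is_input_mdspan_v<out_view>);       // a mutable view is not an input
static_assert(raft::is_input_device_mdspan_v<in_view>);  // input and device-accessible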

// uint division optimization inspired by the CIndexer in cupy. Division operation is
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
// bit when the index is smaller, then try to avoid division when it's exp of 2.
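As a standalone illustration of the two tricks this comment describes, assume a flattened index i and a divisor d (an extent): prefer 32-bit arithmetic when sizes allow, and replace division by a shift when d is a power of two. This is a sketch, not the hidden detail code:

#include <cstdint>

// Divide i by d, taking the shift fast path when d is a power of two.
// Assumes d > 0; a power of two has exactly one set bit, so d & (d - 1) == 0.
inline std::uint32_t fast_div(std::uint32_t i, std::uint32_t d)
{
  if ((d & (d - 1)) == 0) {
    std::uint32_t k = 0;
    while ((d >> k) > 1) { ++k; }  // find k with d == 2^k (precomputed in practice)
    return i >> k;
  }
  return i / d;  // general case: a real (slow) division
}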
149 changes: 141 additions & 8 deletions cpp/include/raft/linalg/add.cuh
@@ -25,6 +25,10 @@

#include "detail/add.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/util/input_validation.hpp>
divyegala marked this conversation as resolved.

namespace raft {
namespace linalg {

@@ -46,7 +50,7 @@ using detail::adds_scalar;
* @param stream cuda stream where to launch work
*/
template <typename InT, typename OutT = InT, typename IdxType = int>
-void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream)
+void addScalar(OutT* out, const InT* in, const InT scalar, IdxType len, cudaStream_t stream)
{
  detail::addScalar(out, in, scalar, len, stream);
}
@@ -72,24 +76,153 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream)

/** Subtract single value pointed to by singleScalarDev parameter in device memory from inDev[i] and
* write result to outDev[i]
- * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam InT input data-type. Also the data-type upon which the math ops
+ *   will be performed
+ * @tparam OutT output data-type
+ * @tparam IdxType Integer type used for addressing
* @param outDev the output buffer
* @param inDev the input buffer
* @param singleScalarDev pointer to the scalar located in device memory
* @param len number of elements in the input and output buffer
* @param stream cuda stream
*/
-template <typename math_t, typename IdxType = int>
-void addDevScalar(math_t* outDev,
-                  const math_t* inDev,
-                  const math_t* singleScalarDev,
-                  IdxType len,
-                  cudaStream_t stream)
+template <typename InT, typename OutT = InT, typename IdxType = int>
+void addDevScalar(
+  OutT* outDev, const InT* inDev, const InT* singleScalarDev, IdxType len, cudaStream_t stream)
{
  detail::addDevScalar(outDev, inDev, singleScalarDev, len, stream);
}

/**
* @defgroup add Addition Arithmetic
* @{
*/

/**
* @brief Elementwise add operation
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @param[in] handle raft::handle_t
* @param[in] in1 First Input
* @param[in] in2 Second Input
* @param[out] out Output
*/
template <typename InType,
          typename OutType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void add(const raft::handle_t& handle, InType in1, InType in2, OutType out)
{
  using in_value_t  = typename InType::value_type;
  using out_value_t = typename OutType::value_type;

  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous");
  RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
               "Size mismatch between Output and Inputs");

  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
[Review thread on the line above]

mhoemmen (Contributor) commented on Sep 12, 2022:

It looks like the goal here is to use 32-bit indices if possible, when inverting the layout mapping to use a 1-D loop index. This can be done, but there are two correctness issues with your approach.

1. The right quantity to test here is out.required_span_size(), not out.size(). The layout mapping maps the input multidimensional index to the half-open interval of offsets [0, out.required_span_size()).

2. The layout_{left, right, stride}::mapping constructors generally have as a precondition that the required span size of the input extents (and strides, if applicable) be representable as a value of type index_type.

Here is an approach that would address these issues.

template <class T>
constexpr bool is_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template <class T>
constexpr bool is_greater_than_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) > sizeof(std::uint32_t);

if constexpr (is_32_bit_integral_v<typename OutType::index_type>) {
  // ... always call 32-bit version ...
} else if constexpr (is_greater_than_32_bit_integral_v<typename OutType::index_type>) {
  // ... test the value of `required_span_size()`; dispatch to 32-bit or index_type (64 or more bits) as needed ...
} else {
  // ... always use index_type, which is 16 bits or less here ...
}

You'll also want to check the index_type and required_span_size() of the other mdspan. The above approach has the advantage that it only compiles an inner kernel for index types that you actually use.

divyegala (Member, author):

In point 2, what happens in extreme cases? Consider index_type=uint32_t with extents {2^32, 2}. In this case, will required_span_size() be representable by index_type or will it cause an overflow?

mhoemmen (Contributor):

@divyegala required_span_size() is not representable by index_type in this case. For layout_left and layout_right, required_span_size() and size() are the same mathematically. The only difference is the return type (index_type resp. size_t). For layout_stride, though, required_span_size() can be greater than the size(). For other layouts (e.g., the "matrix of a single value" layout that maps all multidimensional indices to the offset zero), required_span_size() can be less than size().

Note that while it's UB for users to violate preconditions, implementations aren't required to check preconditions. The reference implementation of layout_left does not currently check preconditions, as you can see here, for instance. This means two things.

1. If someone gives you a layout_{left,right,stride}::mapping instance (e.g., in an mdspan), then you can assume that the precondition is satisfied.

2. If you are constructing a layout_{left,right,stride}::mapping instance (e.g., by constructing an mdspan with a pointer and extents), then you are responsible for ensuring that the precondition is satisfied.

mhoemmen (Contributor):

@divyegala wrote:

    In this case, will required_span_size() be representable by index_type or will it cause an overflow?

Those are two separate questions, actually! : - )

1. required_span_size() is not representable by index_type in this case.

2. Giving this extents object to layout_{left,right,stride}::mapping's constructor violates the constructor's precondition. It could overflow, or it could open a portal to the Awesome Dimension and let loose a swarm of nasal demons who search out precondition violators and boop them gently on the nose.

divyegala (Member, author):

@mhoemmen thanks for the explanations! How do we really represent such edge cases and safely obtain the product of the extents? Sounds like size() is the safe way to obtain the product without violating any pre-conditions, since it's representable by size_t?

mhoemmen (Contributor):

@divyegala Gladly! : - )

    How do we really represent such edge cases and safely obtain the product of the extents?

By the time the user has created a layout mapping, it's already too late. What I mean by that is that if required_span_size() doesn't fit index_type, then the user will likely get the wrong answer when they try to index into the mdspan.

In what follows in my comment, I'll distinguish between "the Preconditions in the spec" and "what the reference implementation does." The reference implementation currently does not check this precondition in the layout mapping. This means that it's possible for users to construct extents for which the mapping's required_span_size() can overflow.

We can prevent this by wrapping mdspan creation to check the extents object for potential overflow, before it goes into a layout mapping's constructor. It's not UB to construct, e.g., dextents<uint16_t, 2>(2^15, 2^15). We just need to intercept that naughty extents value before it goes into a layout mapping's constructor. Otherwise, the layout mapping has the freedom to do whatever it likes, including calling abort().

Our mdarray implementation's conversion to mdspan can also check, but again, we're probably better off making the wrapper explicit and not part of the mdarray proposal. WG21 likes Preconditions and wants violating them to be UB. If we want some specified behavior (e.g., throwing a particular exception, or calling terminate() after printing a helpful error message), then we'll have to implement that ourselves.

divyegala (Member, author):

Looks like out.required_span_size() does not work. How do I access this from the layout?

[End of review thread]
    add<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
                                                in1.data_handle(),
                                                in2.data_handle(),
                                                static_cast<std::uint32_t>(out.size()),
                                                handle.get_stream());
  } else {
    add<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
                                                in1.data_handle(),
                                                in2.data_handle(),
                                                static_cast<std::uint64_t>(out.size()),
                                                handle.get_stream());
  }
}
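For reference, a hedged sketch of the index-type dispatch mhoemmen outlines in the thread above. It assumes the reference mdspan implementation, where the layout mapping (and thus required_span_size()) is reachable through the view's mapping() accessor; the helper names are illustrative:

#include <cstdint>
#include <limits>
#include <type_traits>

template <class T>
inline constexpr bool is_32_bit_integral_v =
  std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);

// Invokes run(n), where n is the span of offsets the output view can produce,
// cast to the narrowest index type that can represent it.
template <typename OutType, typename RunFn>
void dispatch_index_type(OutType out, RunFn run)
{
  using index_type = typename OutType::index_type;
  if constexpr (is_32_bit_integral_v<index_type>) {
    // The mapping's precondition guarantees required_span_size() fits the
    // 32-bit index_type, so the 32-bit path is always safe.
    run(static_cast<std::uint32_t>(out.mapping().required_span_size()));
  } else if constexpr (sizeof(index_type) > sizeof(std::uint32_t)) {
    // Wider index type: narrow to 32 bits only when the actual offset range
    // fits, otherwise keep the wide type.
    auto const n = out.mapping().required_span_size();
    if (n <= std::numeric_limits<std::uint32_t>::max()) {
      run(static_cast<std::uint32_t>(n));
    } else {
      run(n);
    }
  } else {
    run(out.mapping().required_span_size());  // 16 bits or fewer: use as-is
  }
}

Because run is typically a generic lambda wrapping the kernel launch, only the index widths actually reachable for a given mdspan type get compiled.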

/**
* @brief Elementwise addition of device scalar to input
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param[in] handle raft::handle_t
* @param[in] in Input
* @param[in] scalar raft::device_scalar_view
 * @param[out] out Output
*/
template <typename InType,
          typename OutType,
          typename ScalarIdxType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void add_scalar(const raft::handle_t& handle,
                InType in,
                OutType out,
                raft::device_scalar_view<const typename InType::value_type, ScalarIdxType> scalar)
{
  using in_value_t  = typename InType::value_type;
  using out_value_t = typename OutType::value_type;

  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
  RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
    addDevScalar<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
                                                         in.data_handle(),
                                                         scalar.data_handle(),
                                                         static_cast<std::uint32_t>(out.size()),
                                                         handle.get_stream());
  } else {
    addDevScalar<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
                                                         in.data_handle(),
                                                         scalar.data_handle(),
                                                         static_cast<std::uint64_t>(out.size()),
                                                         handle.get_stream());
  }
}

/**
* @brief Elementwise addition of host scalar to input
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param[in] handle raft::handle_t
* @param[in] in Input
* @param[in] scalar raft::host_scalar_view
 * @param[out] out Output
*/
template <typename InType,
          typename OutType,
          typename ScalarIdxType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void add_scalar(const raft::handle_t& handle,
                const InType in,
                OutType out,
                raft::host_scalar_view<const typename InType::value_type, ScalarIdxType> scalar)
{
  using in_value_t  = typename InType::value_type;
  using out_value_t = typename OutType::value_type;

  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
  RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
    addScalar<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
                                                      in.data_handle(),
                                                      *scalar.data_handle(),
                                                      static_cast<std::uint32_t>(out.size()),
                                                      handle.get_stream());
  } else {
    addScalar<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
                                                      in.data_handle(),
                                                      *scalar.data_handle(),
                                                      static_cast<std::uint64_t>(out.size()),
                                                      handle.get_stream());
  }
}
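A hedged usage sketch of the new overloads, assuming raft's mdarray factories (make_device_vector, make_host_scalar_view) and the usual mdspan mutable-to-const view conversion; none of these helpers appear in this diff:

#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/linalg/add.cuh>

void add_example(const raft::handle_t& handle)
{
  auto in1 = raft::make_device_vector<float, int>(handle, 64);
  auto in2 = raft::make_device_vector<float, int>(handle, 64);
  auto out = raft::make_device_vector<float, int>(handle, 64);

  // The input parameters are constrained to const-element views, so convert
  // the mutable views first.
  raft::device_vector_view<const float, int> in1_v = in1.view();
  raft::device_vector_view<const float, int> in2_v = in2.view();

  // out[i] = in1[i] + in2[i]
  raft::linalg::add(handle, in1_v, in2_v, out.view());

  // out[i] = in1[i] + 7.0f, with the scalar resident on the host
  float seven  = 7.0f;
  auto seven_v = raft::make_host_scalar_view<const float>(&seven);
  raft::linalg::add_scalar(handle, in1_v, out.view(), seven_v);
}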

/** @} */ // end of group add

}; // end namespace linalg
}; // end namespace raft
