Updating raft::linalg APIs to use mdspan #809

Merged
Changes from 22 commits
146 changes: 138 additions & 8 deletions cpp/include/raft/linalg/add.cuh
@@ -25,6 +25,10 @@

#include "detail/add.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/util/input_validation.hpp>

namespace raft {
namespace linalg {

@@ -46,7 +50,7 @@ using detail::adds_scalar;
* @param stream cuda stream where to launch work
*/
template <typename InT, typename OutT = InT, typename IdxType = int>
void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream)
void addScalar(OutT* out, const InT* in, const InT scalar, IdxType len, cudaStream_t stream)
{
detail::addScalar(out, in, scalar, len, stream);
}
@@ -72,24 +76,150 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st

/** Subtract single value pointed to by singleScalarDev parameter in device memory from inDev[i] and
* write result to outDev[i]
* @tparam math_t data-type upon which the math operation will be performed
* @tparam InT input data-type. Also the data-type upon which the math ops
* will be performed
* @tparam OutT output data-type
* @tparam IdxType Integer type used for addressing
* @param outDev the output buffer
* @param inDev the input buffer
* @param singleScalarDev pointer to the scalar located in device memory
* @param len number of elements in the input and output buffer
* @param stream cuda stream
*/
template <typename math_t, typename IdxType = int>
void addDevScalar(math_t* outDev,
const math_t* inDev,
const math_t* singleScalarDev,
IdxType len,
cudaStream_t stream)
template <typename InT, typename OutT = InT, typename IdxType = int>
void addDevScalar(
OutT* outDev, const InT* inDev, const InT* singleScalarDev, IdxType len, cudaStream_t stream)
{
detail::addDevScalar(outDev, inDev, singleScalarDev, len, stream);
}

/**
* @defgroup add Addition Arithmetic
* @{
*/

/**
* @brief Elementwise add operation
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @param[in] handle raft::handle_t
* @param[in] in1 First Input
* @param[in] in2 Second Input
* @param[out] out Output
*/
template <typename InType,
typename OutType,
typename = raft::enable_if_device_mdspan<OutType, InType>>
void add(const raft::handle_t& handle, const InType in1, const InType in2, OutType out)
{
using in_value_t = typename InType::value_type;
using out_value_t = typename OutType::value_type;

RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous");
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
"Size mismatch between Output and Inputs");

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
mhoemmen (Contributor) commented on Sep 12, 2022:

It looks like the goal here is to use 32-bit indices if possible, when inverting the layout mapping to use a 1-D loop index. This can be done, but there are two correctness issues with your approach.

  1. The right quantity to test here is out.required_span_size(), not out.size(). The layout mapping maps the input multidimensional index to the half-open interval of offsets [0, out.required_span_size()).

  2. The layout_{left, right, stride}::mapping constructors generally have as a precondition that the required span size of the input extents (and strides, if applicable) be representable as a value of type index_type.

Here is an approach that would address these issues.

template<class T>
constexpr bool is_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template<class T>
constexpr bool is_greater_than_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) > sizeof(std::uint32_t);

if constexpr (is_32_bit_integral_v<typename OutType::index_type>) {
  // ... always call 32-bit version ...
} else if constexpr (is_greater_than_32_bit_integral_v<typename OutType::index_type>) {
  // ... test the value of `required_span_size()`; dispatch to 32-bit or index_type (64 or more bits) as needed ...
} else {
  // ... always use index_type, which is 16 bits or less here ...
}

You'll also want to check the index_type and required_span_size() of the other mdspan. The above approach has the advantage that it only compiles an inner kernel for index types that you actually use.
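For reference, a self-contained sketch of that dispatch, folded into a small helper (sketch only: the helper name, callback parameters, and structure are illustrative, not code from this PR):

#include <cstdint>
#include <limits>
#include <type_traits>

template <class T>
constexpr bool is_32_bit_integral_v =
  std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template <class T>
constexpr bool is_greater_than_32_bit_integral_v =
  std::is_integral_v<T> && (sizeof(T) > sizeof(std::uint32_t));

// Picks an index width for a contiguous mdspan `out`, then invokes the matching
// callback: `run_u32` launches the std::uint32_t kernel, `run_index_type`
// launches the kernel templated on OutType::index_type.
template <class OutType, class RunU32, class RunIndexType>
void dispatch_add_index_type(const OutType& out, RunU32 run_u32, RunIndexType run_index_type)
{
  using index_type = typename OutType::index_type;
  if constexpr (is_32_bit_integral_v<index_type>) {
    run_u32();  // a 32-bit index_type can never exceed the 32-bit kernel's range
  } else if constexpr (is_greater_than_32_bit_integral_v<index_type>) {
    // 64-bit (or wider) index_type: check the mapping's actual span size at run time
    if (out.mapping().required_span_size() <= std::numeric_limits<std::uint32_t>::max()) {
      run_u32();
    } else {
      run_index_type();
    }
  } else {
    run_index_type();  // 16 bits or less: index_type itself is always sufficient
  }
}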

divyegala (Member, Author) replied:

In point 2, what happens in extreme cases? Consider index_type=uint32_t with extents {2^32, 2}. In this case, will required_span_size() be representable by index_type or will it cause an overflow?

mhoemmen (Contributor) replied:

@divyegala required_span_size() is not representable by index_type in this case. For layout_left and layout_right, required_span_size() and size() are the same mathematically. The only difference is the return type (index_type resp. size_t). For layout_stride, though, required_span_size() can be greater than the size(). For other layouts (e.g., the "matrix of a single value" layout that maps all multidimensional indices to the offset zero), required_span_size() can be less than size().

Note that while it's UB for users to violate preconditions, implementations aren't required to check preconditions. The reference implementation of layout_left does not currently check preconditions, as you can see here, for instance. This means two things.

  1. If someone gives you a layout_{left,right,stride}::mapping instance (e.g., in an mdspan), then you can assume that the precondition is satisfied.

  2. If you are constructing a layout_{left,right,stride}::mapping instance (e.g., by constructing an mdspan with a pointer and extents), then you are responsible for ensuring that the precondition is satisfied.

mhoemmen (Contributor) added:

@divyegala wrote:

In this case, will required_span_size() be representable by index_type or will it cause an overflow?

Those are two separate questions, actually! : - )

  1. required_span_size() is not representable by index_type in this case.
  2. Giving this extents object to layout_{left,right,stride}::mapping's constructor violates the constructor's precondition. It could overflow, or it could open a portal to the Awesome Dimension and let loose a swarm of nasal demons who search out precondition violators and boop them gently on the nose.

divyegala (Member, Author) replied:

@mhoemmen thanks for the explanations! How do we really represent such edge cases and safely obtain the product of the extents? Sounds like size() is the safe way to obtain the product without violating any preconditions, since it's representable by size_t?

mhoemmen (Contributor) replied:

@divyegala Gladly! : - )

How do we really represent such edge cases and safely obtain the product of the extents?

By the time the user has created a layout mapping, it's already too late. What I mean by that is that if required_span_size() doesn't fit index_type, then the user will likely get the wrong answer when they try to index into the mdspan.

In what follows in my comment, I'll distinguish between "the Preconditions in the spec" and "what the reference implementation does." The reference implementation currently does not check this precondition in the layout mapping. This means that it's possible for users to construct extents for which the mapping's required_span_size() can overflow.

We can prevent this by wrapping mdspan creation to check the extents object for potential overflow, before it goes into a layout mapping's constructor. It's not UB to construct, e.g., dextents<uint16_t, 2>(2^15, 2^15). We just need to intercept that naughty extents value before it goes into a layout mapping's constructor. Otherwise, the layout mapping has the freedom to do whatever it likes, including calling abort().

Our mdarray implementation's conversion to mdspan can also check, but again, we're probably better off making the wrapper explicit and not part of the mdarray proposal. WG21 likes Preconditions and wants violating them to be UB. If we want some specified behavior (e.g., throwing a particular exception, or calling terminate() after printing a helpful error message), then we'll have to implement that ourselves.
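A minimal sketch of such a checked wrapper, assuming only the standard extents interface (the function name and the throwing policy are hypothetical, not part of RAFT or this PR):

#include <cstddef>
#include <limits>
#include <stdexcept>

// Rejects an extents object whose element count cannot be represented by its
// index_type, before any layout mapping (and hence any mdspan) is constructed.
template <class Extents>
void check_extents_fit_index_type(const Extents& ext)
{
  using index_type = typename Extents::index_type;
  std::size_t product = 1;
  for (std::size_t r = 0; r < Extents::rank(); ++r) {
    product *= static_cast<std::size_t>(ext.extent(r));
  }
  // For layout_left and layout_right, required_span_size() equals this product,
  // so this is exactly the value that must fit index_type.
  if (product > static_cast<std::size_t>(std::numeric_limits<index_type>::max())) {
    throw std::invalid_argument("extents overflow the mdspan index_type");
  }
  // The layout mapping precondition now holds; construction is safe.
}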

divyegala (Member, Author) replied:

Looks like out.required_span_size() does not work. How do I access this from the layout?

add<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
in1.data_handle(),
in2.data_handle(),
static_cast<std::uint32_t>(out.size()),
handle.get_stream());
} else {
add<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
in1.data_handle(),
in2.data_handle(),
static_cast<std::uint64_t>(out.size()),
handle.get_stream());
}
}
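
For context, a usage sketch of the new overload (the make_device_vector_view factory is assumed to be available from the raft/core/device_mdspan.hpp header this PR includes; exact helper names may differ across versions):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/add.cuh>
#include <rmm/device_uvector.hpp>

void add_example(const raft::handle_t& handle)
{
  int n = 1024;
  rmm::device_uvector<float> a(n, handle.get_stream());
  rmm::device_uvector<float> b(n, handle.get_stream());
  rmm::device_uvector<float> out(n, handle.get_stream());
  // ... fill a and b ...

  auto a_view   = raft::make_device_vector_view<const float, int>(a.data(), n);
  auto b_view   = raft::make_device_vector_view<const float, int>(b.data(), n);
  auto out_view = raft::make_device_vector_view<float, int>(out.data(), n);

  // out[i] = a[i] + b[i]; the wrapper dispatches to a 32- or 64-bit kernel by size
  raft::linalg::add(handle, a_view, b_view, out_view);
}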

/**
* @brief Elementwise addition of device scalar to input
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param[in] handle raft::handle_t
* @param[in] in Input
* @param[in] scalar raft::device_scalar_view
* @param[out] out Output
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_device_mdspan<OutType, InType>>
void add_scalar(const raft::handle_t& handle,
InType in,
OutType out,
raft::device_scalar_view<const typename InType::value_type, ScalarIdxType> scalar)
{
using in_value_t = typename InType::value_type;
using out_value_t = typename OutType::value_type;

RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
addDevScalar<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
in.data_handle(),
scalar.data_handle(),
static_cast<std::uint32_t>(out.size()),
handle.get_stream());
} else {
addDevScalar<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
in.data_handle(),
scalar.data_handle(),
static_cast<std::uint64_t>(out.size()),
handle.get_stream());
}
}

/**
* @brief Elementwise addition of host scalar to input
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param[in] handle raft::handle_t
* @param[in] in Input
* @param[in] scalar raft::host_scalar_view
* @param[out] out Output
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_device_mdspan<OutType, InType>>
void add_scalar(const raft::handle_t& handle,
const InType in,
OutType out,
raft::host_scalar_view<const typename InType::value_type, ScalarIdxType> scalar)
{
using in_value_t = typename InType::value_type;
using out_value_t = typename OutType::value_type;

RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
addScalar<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
in.data_handle(),
*scalar.data_handle(),
static_cast<std::uint32_t>(out.size()),
handle.get_stream());
} else {
addScalar<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
in.data_handle(),
*scalar.data_handle(),
static_cast<std::uint64_t>(out.size()),
handle.get_stream());
}
}
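
A companion usage sketch for the host-scalar overload (again assuming the make_*_view factories from the headers this PR includes):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/linalg/add.cuh>
#include <rmm/device_uvector.hpp>

void add_scalar_example(const raft::handle_t& handle)
{
  int n = 256;
  rmm::device_uvector<float> in(n, handle.get_stream());
  rmm::device_uvector<float> out(n, handle.get_stream());
  // ... fill in ...

  float scalar = 2.0f;  // lives on the host, so the host_scalar_view overload is selected
  auto in_view     = raft::make_device_vector_view<const float, int>(in.data(), n);
  auto out_view    = raft::make_device_vector_view<float, int>(out.data(), n);
  auto scalar_view = raft::make_host_scalar_view<const float>(&scalar);

  // out[i] = in[i] + scalar
  raft::linalg::add_scalar(handle, in_view, out_view, scalar_view);
}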

/** @} */ // end of group add

}; // end namespace linalg
}; // end namespace raft

73 changes: 73 additions & 0 deletions cpp/include/raft/linalg/axpy.cuh
@@ -50,6 +50,79 @@ void axpy(const raft::handle_t& handle,
detail::axpy<T, DevicePointerMode>(handle, n, alpha, x, incx, y, incy, stream);
}

/**
* @defgroup axpy axpy
* @{
*/

/**
* @brief axpy function
* It computes the following equation: y = alpha * x + y
*
* @tparam MdspanType Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param [in] handle raft::handle_t
* @param [in] alpha raft::device_scalar_view
* @param [in] x Input vector
* @param [inout] y Output vector
* @param [in] incx stride between consecutive elements of x
* @param [in] incy stride between consecutive elements of y
*/
template <typename MdspanType, typename ScalarIdxType, typename = raft::enable_if_device_mdspan<MdspanType>>
void axpy(const raft::handle_t& handle,
raft::device_scalar_view<const typename MdspanType::element_type, ScalarIdxType> alpha,
MdspanType x,
MdspanType y,
const int incx,
const int incy)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename MdspanType::element_type, true>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
incx,
y.data_handle(),
incy,
handle.get_stream());
}

/**
* @brief axpy function
* It computes the following equation: y = alpha * x + y
*
* @tparam MdspanType Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param [in] handle raft::handle_t
* @param [in] alpha raft::host_scalar_view
* @param [in] x Input vector
* @param [inout] y Output vector
* @param [in] incx stride between consecutive elements of x
* @param [in] incy stride between consecutive elements of y
*/
template <typename MdspanType, typename ScalarIdxType, typename = raft::enable_if_device_mdspan<MdspanType>>
void axpy(const raft::handle_t& handle,
raft::host_scalar_view<const typename MdspanType::value_type, ScalarIdxType> alpha,
MdspanType x,
MdspanType y,
const int incx,
const int incy)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename MdspanType::value_type, false>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
incx,
y.data_handle(),
incy,
handle.get_stream());
}
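
A usage sketch for the host-scalar axpy overload (factory helpers assumed as above; alpha stays on the host, matching the host pointer mode used by this overload):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/linalg/axpy.cuh>
#include <rmm/device_uvector.hpp>

void axpy_example(const raft::handle_t& handle)
{
  int n = 512;
  rmm::device_uvector<float> x(n, handle.get_stream());
  rmm::device_uvector<float> y(n, handle.get_stream());
  // ... fill x and y ...

  float alpha = 0.5f;
  auto x_view     = raft::make_device_vector_view<float, int>(x.data(), n);
  auto y_view     = raft::make_device_vector_view<float, int>(y.data(), n);
  auto alpha_view = raft::make_host_scalar_view<const float>(&alpha);

  // y = alpha * x + y, with unit strides for both vectors
  raft::linalg::axpy(handle, alpha_view, x_view, y_view, 1, 1);
}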

/** @} */ // end of group axpy

} // namespace raft::linalg

#endif
47 changes: 47 additions & 0 deletions cpp/include/raft/linalg/binary_op.cuh
@@ -20,7 +20,10 @@

#include "detail/binary_op.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/input_validation.hpp>

namespace raft {
namespace linalg {
@@ -52,6 +55,50 @@ void binaryOp(
detail::binaryOp(out, in1, in2, len, op, stream);
}

/**
* @defgroup binary_op Element-Wise Binary Operation
* @{
*/

/**
* @brief perform element-wise binary operation on the input arrays
* @tparam InType Input Type raft::device_mdspan
* @tparam Lambda the device-lambda performing the actual operation
* @tparam OutType Output Type raft::device_mdspan
* @param[in] handle raft::handle_t
* @param[in] in1 First input
* @param[in] in2 Second input
* @param[out] out Output
* @param[in] op the device-lambda
* @note Lambda must be a functor with the following signature:
* `OutType::value_type func(const InType::value_type& val1, const InType::value_type& val2);`
*/
template <typename InType,
typename Lambda,
typename OutType,
typename = raft::enable_if_device_mdspan<InType, OutType>>
void binary_op(const raft::handle_t& handle, InType in1, InType in2, OutType out, Lambda op)
{
RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous");
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
"Size mismatch between Output and Inputs");

using in_value_t = typename InType::value_type;
using out_value_t = typename OutType::value_type;

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
binaryOp<in_value_t, Lambda, out_value_t, std::uint32_t>(
out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream());
} else {
binaryOp<in_value_t, Lambda, out_value_t, std::uint64_t>(
out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream());
}
}
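
A usage sketch with a device lambda (requires nvcc's extended lambdas; factory helpers assumed as above):

#include <raft/core/device_mdspan.hpp>
#include <raft/linalg/binary_op.cuh>
#include <rmm/device_uvector.hpp>

void binary_op_example(const raft::handle_t& handle)
{
  int n = 128;
  rmm::device_uvector<float> a(n, handle.get_stream());
  rmm::device_uvector<float> b(n, handle.get_stream());
  rmm::device_uvector<float> out(n, handle.get_stream());
  // ... fill a and b ...

  auto a_view   = raft::make_device_vector_view<const float, int>(a.data(), n);
  auto b_view   = raft::make_device_vector_view<const float, int>(b.data(), n);
  auto out_view = raft::make_device_vector_view<float, int>(out.data(), n);

  // out[i] = a[i] * b[i]; the lambda runs on the device, one element pair per call
  raft::linalg::binary_op(handle, a_view, b_view, out_view,
                          [] __device__(float val1, float val2) { return val1 * val2; });
}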

/** @} */ // end of group binary_op

}; // end namespace linalg
}; // end namespace raft

4 changes: 3 additions & 1 deletion cpp/include/raft/linalg/cholesky_r1_update.cuh
@@ -25,6 +25,7 @@ namespace linalg {

/**
* @brief Rank 1 update of Cholesky decomposition.
* NOTE: The new mdspan-based API will not be provided for this function.
*
* This method is useful if an algorithm iteratively builds up matrix A, and
* the Cholesky decomposition of A is required at each step.
@@ -109,7 +110,7 @@ namespace linalg {
* @param L device array for to store the triangular matrix L, and the new
* column of A in column major format, size [n*n]
* @param n number of elements in the new row.
* @param ld stride of colums in L
* @param ld stride of columns in L
* @param workspace device pointer to workspace; shall be nullptr or an array
* of size [n_bytes].
* @param n_bytes size of workspace is returned here if workspace==nullptr.
@@ -132,6 +133,7 @@ void choleskyRank1Update(const raft::handle_t& handle,
{
detail::choleskyRank1Update(handle, L, n, ld, workspace, n_bytes, uplo, stream, eps);
}

}; // namespace linalg
}; // namespace raft
