Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating raft::linalg APIs to use mdspan #809

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/include/raft/core/cudart_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cuda.h>
#include <cuda_runtime.h>

#include <chrono>
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ using device_scalar_view = device_mdspan<ElementType, scalar_extent<IndexType>>;
* @brief Shorthand for 1-dim host mdspan.
* @tparam ElementType the data type of the vector elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
Expand Down
144 changes: 136 additions & 8 deletions cpp/include/raft/linalg/add.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

#include "detail/add.cuh"

#include <raft/core/mdarray.hpp>

namespace raft {
namespace linalg {

Expand All @@ -46,7 +48,7 @@ using detail::adds_scalar;
* @param stream cuda stream where to launch work
*/
template <typename InT, typename OutT = InT, typename IdxType = int>
void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream)
void addScalar(OutT* out, const InT* in, const InT scalar, IdxType len, cudaStream_t stream)
{
detail::addScalar(out, in, scalar, len, stream);
}
Expand All @@ -72,24 +74,150 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st

/** Substract single value pointed by singleScalarDev parameter in device memory from inDev[i] and
* write result to outDev[i]
* @tparam math_t data-type upon which the math operation will be performed
* @tparam InT input data-type. Also the data-type upon which the math ops
* will be performed
* @tparam OutT output data-type
* @tparam IdxType Integer type used to for addressing
* @param outDev the output buffer
* @param inDev the input buffer
* @param singleScalarDev pointer to the scalar located in device memory
* @param len number of elements in the input and output buffer
* @param stream cuda stream
*/
template <typename math_t, typename IdxType = int>
void addDevScalar(math_t* outDev,
const math_t* inDev,
const math_t* singleScalarDev,
IdxType len,
cudaStream_t stream)
template <typename InT, typename OutT = InT, typename IdxType = int>
void addDevScalar(
OutT* outDev, const InT* inDev, const InT* singleScalarDev, IdxType len, cudaStream_t stream)
{
detail::addDevScalar(outDev, inDev, singleScalarDev, len, stream);
}

/**
* @defgroup add Addition Arithmetic
* @{
*/

/**
* @brief Elementwise add operation on the input buffers
* @tparam InType Input Type raft::device_mdspan
* @tparam OutType Output Type raft::device_mdspan
* @param handle raft::handle_t
* @param out Output
* @param in1 First Input
* @param in2 Second Input
*/
template <typename InType,
typename OutType = InType,
divyegala marked this conversation as resolved.
Show resolved Hide resolved
typename = raft::enable_if_device_mdspan<OutType, InType>>
void add(const raft::handle_t& handle, OutType out, const InType in1, const InType in2)
{
using in_element_t = typename InType::element_type;
using out_element_t = typename OutType::element_type;
divyegala marked this conversation as resolved.
Show resolved Hide resolved

RAFT_EXPECTS(out.is_exhaustive(), "Output must be contiguous");
divyegala marked this conversation as resolved.
Show resolved Hide resolved
RAFT_EXPECTS(in1.is_exhaustive(), "Input 1 must be contiguous");
RAFT_EXPECTS(in2.is_exhaustive(), "Input 2 must be contiguous");
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
"Size mismatch between Output and Inputs");

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
Copy link
Contributor

@mhoemmen mhoemmen Sep 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the goal here is to use 32-bit indices if possible, when inverting the layout mapping to use a 1-D loop index. This can be done, but there are two correctness issues with your approach.

  1. The right quantity to test here is out.required_span_size(), not out.size(). The layout mapping maps the input multidimensional index to the half-open interval of offsets [0, out.required_span_size()).

  2. The layout_{left, right, stride}::mapping constructors generally have as a precondition that the required span size of the input extents (and strides, if applicable) be representable as a value of type index_type.

Here is an approach that would address these issues.

template<class T>
constexpr bool is_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template<class T>
constexpr bool is_greater_than_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) > sizeof(std::uint32_t);

if constexpr (is_32_bit_integral_v<typename OutType::index_type>) {
  // ... always call 32-bit version ...
} else if constexpr (is_greater_than_32_bit_integral_v<typename OutType::index_type>) {
  // ... test the value of `required_span_size()`; dispatch to 32-bit or index_type (64 or more bits) as needed ...
} else {
  // ... always use index_type, which is 16 bits or less here ...
}

You'll also want to check the index_type and required_span_size() of the other mdspan. The above approach has the advantage that it only compiles an inner kernel for index types that you actually use.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In point 2, what happens in extreme cases? Consider index_type=uint32_t with extents {2^32, 2}. In this case, will required_span_size() be representable by index_type or will it cause an overflow?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@divyegala required_span_size() is not representable by index_type in this case. For layout_left and layout_right, required_span_size() and size() are the same mathematically. The only difference is the return type (index_type resp. size_t). For layout_stride, though, required_span_size() can be greater than the size(). For other layouts (e.g., the "matrix of a single value" layout that maps all multidimensional indices to the offset zero), required_span_size() can be less than size().

Note that while it's UB for users to violate preconditions, implementations aren't required to check preconditions. The reference implementation of layout_left does not currently check preconditions, as you can see here, for instance. This means two things.

  1. If someone gives you a layout_{left,right,stride}::mapping instance (e.g., in an mdspan), then you can assume that the precondition is satisfied.

  2. If you are constructing a layout_{left,right,stride}::mapping instance (e.g., by constructing an mdspan with a pointer and extents), then you are responsible for ensuring that the precondition is satisfied.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@divyegala wrote:

In this case, will required_span_size() be representable by index_type or will it cause an overflow?

Those are two separate questions, actually! : - )

  1. required_span_size() is not representable by index_type in this case.
  2. Giving this extents object to layout_{left,right,stride}::mapping's constructor violates the constructor's precondition. It could overflow, or it could open a portal to the Awesome Dimension and let loose a swarm of nasal demons who search out precondition violators and boop them gently on the nose.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mhoemmen thanks for the explanations! How do we really represent such edge cases and safely obtain the product of the extents? Sounds like size() is the safe way to obtain the product without violating any pre-conditions since it's representable by size_t?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@divyegala Gladly! : - )

How do we really represent such edge cases and safely obtain the product of the extents?

By the time the user has created a layout mapping, it's already too late. What I mean by that is that if required_span_size() doesn't fit index_type, then the user will likely get the wrong answer when they try to index into the mdspan.

In what follows in my comment, I'll distinguish between "the Preconditions in the spec" and "what the reference implementation does." The reference implementation currently does not check this precondition in the layout mapping. This means that it's possible for users to construct extents for which the mapping's required_span_size() can overflow.

We can prevent this by wrapping mdspan creation to check the extents object for potential overflow, before it goes into a layout mapping's constructor. It's not UB to construct, e.g., dextents<uint16_t, 2>( 2^{15} , 2^{15} ). We just need to intercept that naughty extents value before it goes into a layout mapping's constructor. Otherwise, the layout mapping has the freedom to do whatever it likes, including calling abort().

Our mdarray implementation's conversion to mdspan can also check, but again, we're probably better off making the wrapper explicit and not part of the mdarray proposal. WG21 likes Preconditions and wants violating them to be UB. If we want some specified behavior (e.g., throwing a particular exception, or calling terminate() after printing a helpful error message), then we'll have to implement that ourselves.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like out.required_span_size() does not work. How do I access this from the layout?

add<in_element_t, out_element_t, std::uint32_t>(out.data_handle(),
divyegala marked this conversation as resolved.
Show resolved Hide resolved
in1.data_handle(),
in2.data_handle(),
static_cast<std::uint32_t>(out.size()),
handle.get_stream());
} else {
add<in_element_t, out_element_t, std::uint64_t>(out.data_handle(),
in1.data_handle(),
in2.data_handle(),
static_cast<std::uint64_t>(out.size()),
handle.get_stream());
}
}

/**
 * @brief Add a scalar held in a raft::device_scalar_view to every element of the input
 *
 * Both views must be exhaustive (contiguous) and have the same number of elements.
 * The element count decides whether the 32-bit or the 64-bit index instantiation of
 * addDevScalar is launched.
 *
 * @tparam InType Input Type raft::device_mdspan
 * @tparam OutType Output Type raft::device_mdspan
 * @tparam ScalarIdxType Index Type of scalar
 * @param handle raft::handle_t, provides the stream the work is launched on
 * @param out Output
 * @param in Input
 * @param scalar raft::device_scalar_view
 */
template <typename InType,
          typename OutType       = InType,
          typename ScalarIdxType = std::uint32_t,
          typename               = raft::enable_if_device_mdspan<OutType, InType>>
void add_scalar(const raft::handle_t& handle,
                OutType out,
                const InType in,
                const raft::device_scalar_view<typename InType::element_type, ScalarIdxType> scalar)
{
  RAFT_EXPECTS(out.is_exhaustive(), "Output must be contiguous");
  RAFT_EXPECTS(in.is_exhaustive(), "Input must be contiguous");
  RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

  using value_in_t  = typename InType::element_type;
  using value_out_t = typename OutType::element_type;

  const auto n_elems = out.size();

  // Element counts that do not fit in 32 bits take the 64-bit index path.
  if (n_elems > std::numeric_limits<std::uint32_t>::max()) {
    addDevScalar<value_in_t, value_out_t, std::uint64_t>(out.data_handle(),
                                                         in.data_handle(),
                                                         scalar.data_handle(),
                                                         static_cast<std::uint64_t>(n_elems),
                                                         handle.get_stream());
    return;
  }

  addDevScalar<value_in_t, value_out_t, std::uint32_t>(out.data_handle(),
                                                       in.data_handle(),
                                                       scalar.data_handle(),
                                                       static_cast<std::uint32_t>(n_elems),
                                                       handle.get_stream());
}

/**
 * @brief Add a scalar held in a raft::host_scalar_view to every element of the input
 *
 * Both views must be exhaustive (contiguous) and have the same number of elements.
 * The scalar is read through the view's data handle at call time and passed by value
 * to addScalar. The element count decides whether the 32-bit or the 64-bit index
 * instantiation is launched.
 *
 * @tparam InType Input Type raft::device_mdspan
 * @tparam OutType Output Type raft::device_mdspan
 * @tparam ScalarIdxType Index Type of scalar
 * @param handle raft::handle_t, provides the stream the work is launched on
 * @param out Output
 * @param in Input
 * @param scalar raft::host_scalar_view
 */
template <typename InType,
          typename OutType       = InType,
          typename ScalarIdxType = std::uint32_t,
          typename               = raft::enable_if_device_mdspan<OutType, InType>>
void add_scalar(const raft::handle_t& handle,
                OutType out,
                const InType in,
                const raft::host_scalar_view<typename InType::element_type, ScalarIdxType> scalar)
{
  RAFT_EXPECTS(out.is_exhaustive(), "Output must be contiguous");
  RAFT_EXPECTS(in.is_exhaustive(), "Input must be contiguous");
  RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

  using value_in_t  = typename InType::element_type;
  using value_out_t = typename OutType::element_type;

  const auto n_elems = out.size();

  // Element counts that do not fit in 32 bits take the 64-bit index path.
  if (n_elems > std::numeric_limits<std::uint32_t>::max()) {
    addScalar<value_in_t, value_out_t, std::uint64_t>(out.data_handle(),
                                                      in.data_handle(),
                                                      *scalar.data_handle(),
                                                      static_cast<std::uint64_t>(n_elems),
                                                      handle.get_stream());
    return;
  }

  addScalar<value_in_t, value_out_t, std::uint32_t>(out.data_handle(),
                                                    in.data_handle(),
                                                    *scalar.data_handle(),
                                                    static_cast<std::uint32_t>(n_elems),
                                                    handle.get_stream());
}

/** @} */ // end of group add

}; // end namespace linalg
}; // end namespace raft

Expand Down
28 changes: 28 additions & 0 deletions cpp/include/raft/linalg/apply.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace raft::linalg {

/**
* @brief Enum for reduction/broadcast where an operation is to be performed along
* a matrix's rows or columns
*
*/
enum class Apply {
  ALONG_ROWS,    ///< perform the operation along the matrix's rows
  ALONG_COLUMNS  ///< perform the operation along the matrix's columns
};
divyegala marked this conversation as resolved.
Show resolved Hide resolved

} // end namespace raft::linalg
73 changes: 73 additions & 0 deletions cpp/include/raft/linalg/axpy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,79 @@ void axpy(const raft::handle_t& handle,
detail::axpy<T, DevicePointerMode>(handle, n, alpha, x, incx, y, incy, stream);
}

/**
* @defgroup axpy cuBLAS axpy
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* @{
*/

/**
 * @brief the wrapper of cublas axpy function
 *  It computes the following equation: y = alpha * x + y
 *
 * x and y must have the same number of elements. alpha is passed on as a pointer with the
 * device-pointer mode enabled.
 *
 * @tparam MdspanType Type raft::device_mdspan
 * @tparam ScalarIdxType Index Type of scalar (deduced from alpha)
 * @param [in] handle raft::handle_t
 * @param [in] alpha raft::device_scalar_view
 * @param [in] x Input vector
 * @param [in] incx stride between consecutive elements of x
 * @param [inout] y Output vector
 * @param [in] incy stride between consecutive elements of y
 */
template <typename MdspanType,
          typename ScalarIdxType,
          typename = raft::enable_if_device_mdspan<MdspanType>>
void axpy(const raft::handle_t& handle,
          raft::device_scalar_view<typename MdspanType::element_type, ScalarIdxType> alpha,
          const MdspanType x,
          const int incx,
          MdspanType y,
          const int incy)
{
  RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

  // `true` selects the device-pointer mode of the pointer-based overload.
  axpy<typename MdspanType::element_type, true>(handle,
                                                y.size(),
                                                alpha.data_handle(),
                                                x.data_handle(),
                                                incx,
                                                y.data_handle(),
                                                incy,
                                                handle.get_stream());
}

/**
 * @brief the wrapper of cublas axpy function
 *  It computes the following equation: y = alpha * x + y
 *
 * x and y must have the same number of elements. alpha is passed on as a pointer with the
 * device-pointer mode disabled.
 *
 * @tparam MdspanType Type raft::device_mdspan
 * @tparam ScalarIdxType Index Type of scalar (deduced from alpha)
 * @param [in] handle raft::handle_t
 * @param [in] alpha raft::host_scalar_view
 * @param [in] x Input vector
 * @param [in] incx stride between consecutive elements of x
 * @param [inout] y Output vector
 * @param [in] incy stride between consecutive elements of y
 */
template <typename MdspanType,
          typename ScalarIdxType,
          typename = raft::enable_if_device_mdspan<MdspanType>>
void axpy(const raft::handle_t& handle,
          raft::host_scalar_view<typename MdspanType::element_type, ScalarIdxType> alpha,
          const MdspanType x,
          const int incx,
          MdspanType y,
          const int incy)
{
  RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

  // `false` selects the host-pointer mode of the pointer-based overload.
  axpy<typename MdspanType::element_type, false>(handle,
                                                 y.size(),
                                                 alpha.data_handle(),
                                                 x.data_handle(),
                                                 incx,
                                                 y.data_handle(),
                                                 incy,
                                                 handle.get_stream());
}

/** @} */ // end of group axpy

} // namespace raft::linalg

#endif
48 changes: 48 additions & 0 deletions cpp/include/raft/linalg/binary_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "detail/binary_op.cuh"

#include <raft/core/mdarray.hpp>
#include <raft/cuda_utils.cuh>

namespace raft {
Expand Down Expand Up @@ -52,6 +53,53 @@ void binaryOp(
detail::binaryOp(out, in1, in2, len, op, stream);
}

/**
* @defgroup binary_op Element-Wise Binary Operation
* @{
*/

/**
* @brief perform element-wise binary operation on the input arrays
* @tparam InType Input Type raft::device_mdspan
* @tparam Lambda the device-lambda performing the actual operation
* @tparam OutType Output Type raft::device_mdspan
* @tparam TPB threads-per-block in the final kernel launched
* @param handle raft::handle_t
* @param out Output
* @param in1 First input
* @param in2 Second input
* @param op the device-lambda
* @note Lambda must be a functor with the following signature:
* `OutType func(const InType& val1, const InType& val2);`
*/
template <typename InType,
typename Lambda,
typename OutType = InType,
divyegala marked this conversation as resolved.
Show resolved Hide resolved
int TPB = 256,
typename = raft::enable_if_device_mdspan<InType, OutType>>
void binary_op(
const raft::handle_t& handle, OutType out, const InType in1, const InType in2, Lambda op)
{
RAFT_EXPECTS(out.is_exhaustive(), "Output must be contiguous");
divyegala marked this conversation as resolved.
Show resolved Hide resolved
RAFT_EXPECTS(in1.is_exhaustive(), "Input 1 must be contiguous");
RAFT_EXPECTS(in2.is_exhaustive(), "Input 2 must be contiguous");
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
"Size mismatch between Output and Inputs");

using in_element_t = typename InType::element_type;
using out_element_t = typename OutType::element_type;

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
binaryOp<in_element_t, Lambda, out_element_t, std::uint32_t, TPB>(
out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream());
} else {
binaryOp<in_element_t, Lambda, out_element_t, std::uint64_t, TPB>(
out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream());
}
}

/** @} */ // end of group binary_op

}; // end namespace linalg
}; // end namespace raft

Expand Down
Loading