Skip to content

Commit

Permalink
mdspanify raft::random functions uniformInt, normalTable, fill, berno…
Browse files Browse the repository at this point in the history
…ulli, and scaled_bernoulli (rapidsai#897)

Add mdspan overloads of `raft::random` functions `uniformInt`, `normalInt`, `normalTable`, `fill`, `bernoulli`, and `scaled_bernoulli`.  Improve `normalTable` documentation to explain that the output table has a row-major layout.

This is rebased atop PR rapidsai#896, which I've since closed.

This should complete all the `raft::random` mdspan overloads.

Authors:
  - Mark Hoemmen (https://github.com/mhoemmen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#897
  • Loading branch information
mhoemmen authored Oct 6, 2022
1 parent b3d5103 commit 7bbae13
Show file tree
Hide file tree
Showing 4 changed files with 493 additions and 12 deletions.
190 changes: 188 additions & 2 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <type_traits>
#include <variant>

namespace raft::random {

Expand Down Expand Up @@ -75,6 +76,34 @@ void uniform(const raft::handle_t& handle,
/**
* @brief Generate uniformly distributed integers in the given range
*
* @tparam OutputValueType Integral type; value type of the output vector
* @tparam IndexType Type used to represent length of the output vector
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output vector of random numbers
* @param[in] start start of the range
* @param[in] end end of the range
*/
template <typename OutputValueType, typename IndexType>
void uniformInt(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType start,
OutputValueType end)
{
static_assert(
std::is_same<OutputValueType, typename std::remove_cv<OutputValueType>::type>::value,
"uniformInt: The output vector must be a view of nonconst, "
"so that we can write to it.");
static_assert(std::is_integral<OutputValueType>::value,
"uniformInt: The elements of the output vector must have integral type.");
detail::uniformInt(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `uniformInt`
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand Down Expand Up @@ -144,6 +173,35 @@ void normal(const raft::handle_t& handle,
/**
* @brief Generate normal distributed integers
*
* @tparam OutputValueType Integral type; value type of the output vector
* @tparam IndexType Integral type of the output vector's length
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output array
* @param[in] mu mean of the distribution
* @param[in] sigma standard deviation of the distribution
*/
template <typename OutputValueType, typename IndexType>
void normalInt(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType mu,
OutputValueType sigma)
{
static_assert(
std::is_same<OutputValueType, typename std::remove_cv<OutputValueType>::type>::value,
"normalInt: The output vector must be a view of nonconst, "
"so that we can write to it.");
static_assert(std::is_integral<OutputValueType>::value,
"normalInt: The output vector's value type must be an integer.");

detail::normalInt(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `normalInt`
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand All @@ -170,7 +228,70 @@ void normalInt(const raft::handle_t& handle,
*
* Each row in this table conforms to a normally distributed n-dim vector
* whose mean is the input vector and standard deviation is the corresponding
* vector or scalar. Correlations among the dimensions itself is assumed to
* vector or scalar. Correlations among the dimensions itself are assumed to
* be absent.
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[in] mu_vec mean vector (of length `out.extent(1)`)
* @param[in] sigma Either the standard-deviation vector
* (of length `out.extent(1)`) of each component,
* or a scalar standard deviation for all components.
* @param[out] out the output table
*/
template <typename OutputValueType, typename IndexType>
void normalTable(
const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<const OutputValueType, IndexType> mu_vec,
std::variant<raft::device_vector_view<const OutputValueType, IndexType>, OutputValueType> sigma,
raft::device_matrix_view<OutputValueType, IndexType, raft::row_major> out)
{
const OutputValueType* sigma_vec_ptr = nullptr;
OutputValueType sigma_value{};

using sigma_vec_type = raft::device_vector_view<const OutputValueType, IndexType>;
if (std::holds_alternative<sigma_vec_type>(sigma)) {
auto sigma_vec = std::get<sigma_vec_type>(sigma);
RAFT_EXPECTS(sigma_vec.extent(0) == out.extent(1),
"normalTable: The sigma vector "
"has length %zu, which does not equal the number of columns "
"in the output table %zu.",
static_cast<size_t>(sigma_vec.extent(0)),
static_cast<size_t>(out.extent(1)));
// The extra length check makes this work even if sigma_vec views a std::vector,
// where .data() need not return nullptr even if .size() is zero.
sigma_vec_ptr = sigma_vec.extent(0) == 0 ? nullptr : sigma_vec.data_handle();
} else {
sigma_value = std::get<OutputValueType>(sigma);
}

RAFT_EXPECTS(mu_vec.extent(0) == out.extent(1),
"normalTable: The mu vector "
"has length %zu, which does not equal the number of columns "
"in the output table %zu.",
static_cast<size_t>(mu_vec.extent(0)),
static_cast<size_t>(out.extent(1)));

detail::normalTable(rng_state,
out.data_handle(),
out.extent(0),
out.extent(1),
mu_vec.data_handle(),
sigma_vec_ptr,
sigma_value,
handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `normalTable`.
*
* Each row in this table conforms to a normally distributed n-dim vector
* whose mean is the input vector and standard deviation is the corresponding
* vector or scalar. Correlations among the dimensions itself are assumed to
* be absent.
*
* @tparam OutType data type of output random number
Expand Down Expand Up @@ -200,7 +321,27 @@ void normalTable(const raft::handle_t& handle,
}

/**
* @brief Fill an array with the given value
* @brief Fill a vector with the given value
*
* @tparam OutputValueType Value type of the output vector
* @tparam IndexType Integral type used to represent length of the output vector
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[in] val value with which to fill the output vector
* @param[out] out the output vector
*/
template <typename OutputValueType, typename IndexType>
void fill(const raft::handle_t& handle,
RngState& rng_state,
OutputValueType val,
raft::device_vector_view<OutputValueType, IndexType> out)
{
detail::fill(rng_state, out.data_handle(), out.extent(0), val, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `fill`
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
Expand All @@ -219,6 +360,28 @@ void fill(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenTy
/**
* @brief Generate bernoulli distributed boolean array
*
* @tparam OutputValueType Type of each element of the output vector;
* must be able to represent boolean values (e.g., `bool`)
* @tparam IndexType Integral type of the output vector's length
* @tparam Type Data type in which to compute the probabilities
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output vector
* @param[in] prob coin-toss probability for heads
*/
template <typename OutputValueType, typename IndexType, typename Type>
void bernoulli(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
Type prob)
{
detail::bernoulli(rng_state, out.data_handle(), out.extent(0), prob, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `bernoulli`
*
* @tparam Type data type in which to compute the probabilities
* @tparam OutType output data type
* @tparam LenType data type used to represent length of the arrays
Expand All @@ -239,6 +402,29 @@ void bernoulli(
/**
* @brief Generate bernoulli distributed array and applies scale
*
* @tparam OutputValueType Data type in which to compute the probabilities
* @tparam IndexType Integral type of the output vector's length
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output vector
* @param[in] prob coin-toss probability for heads
* @param[in] scale scaling factor
*/
template <typename OutputValueType, typename IndexType>
void scaled_bernoulli(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType prob,
OutputValueType scale)
{
detail::scaled_bernoulli(
rng_state, out.data_handle(), out.extent(0), prob, scale, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `scaled_bernoulli`
*
* @tparam OutType data type in which to compute the probabilities
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand Down
Loading

0 comments on commit 7bbae13

Please sign in to comment.