Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] mdspanify raft::random functions uniformInt, normalTable, fill, bernoulli, and scaled_bernoulli #897

161 changes: 159 additions & 2 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <type_traits>
#include <variant>

namespace raft::random {

Expand Down Expand Up @@ -75,6 +76,34 @@ void uniform(const raft::handle_t& handle,
/**
 * @brief Fill a device vector with uniformly distributed random integers
 *
 * @tparam OutputValueType Integral, non-cv-qualified element type of the output vector
 * @tparam IndexType       Type used to represent the length of the output vector
 *
 * @param[in]  handle    raft handle for resource management
 * @param[in]  rng_state random number generator state
 * @param[out] out       the output vector of random numbers
 * @param[in]  start     start of the range
 * @param[in]  end       end of the range
 */
template <typename OutputValueType, typename IndexType>
void uniformInt(const raft::handle_t& handle,
                RngState& rng_state,
                raft::device_vector_view<OutputValueType, IndexType> out,
                OutputValueType start,
                OutputValueType end)
{
  // Compile-time contract: the view must be writable and hold integers.
  using unqualified_type = std::remove_cv_t<OutputValueType>;
  static_assert(std::is_same_v<OutputValueType, unqualified_type>,
                "uniformInt: The output vector must be a view of nonconst, "
                "so that we can write to it.");
  static_assert(std::is_integral_v<OutputValueType>,
                "uniformInt: The elements of the output vector must have integral type.");

  // Unpack the view and forward to the implementation on the handle's stream.
  auto out_ptr       = out.data_handle();
  const auto out_len = out.extent(0);
  detail::uniformInt(rng_state, out_ptr, out_len, start, end, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `uniformInt`
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand Down Expand Up @@ -170,7 +199,70 @@ void normalInt(const raft::handle_t& handle,
*
* Each row in this table conforms to a normally distributed n-dim vector
* whose mean is the input vector and standard deviation is the corresponding
* vector or scalar. Correlations among the dimensions itself is assumed to
* vector or scalar. Correlations among the dimensions themselves are assumed to
* be absent.
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[in] mu_vec mean vector (of length `out.extent(1)`)
* @param[in] sigma Either the standard-deviation vector
* (of length `out.extent(1)`) of each component,
* or a scalar standard deviation for all components.
* @param[out] out the output table
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
*/
template <typename OutputValueType, typename IndexType>
void normalTable(
  const raft::handle_t& handle,
  RngState& rng_state,
  raft::device_vector_view<const OutputValueType, IndexType> mu_vec,
  std::variant<raft::device_vector_view<const OutputValueType, IndexType>, OutputValueType> sigma,
  raft::device_matrix_view<OutputValueType, IndexType, raft::row_major> out)
{
  using sigma_view_type = raft::device_vector_view<const OutputValueType, IndexType>;

  const auto n_rows = out.extent(0);
  const auto n_cols = out.extent(1);

  // `sigma` holds either a per-column vector or a single scalar. Both a pointer
  // and a scalar are forwarded below; the one that was not supplied keeps its
  // default (nullptr / value-initialized).
  const OutputValueType* per_column_sigma = nullptr;
  OutputValueType scalar_sigma{};

  if (const auto* sigma_view = std::get_if<sigma_view_type>(&sigma)) {
    const auto sigma_len = sigma_view->extent(0);
    RAFT_EXPECTS(sigma_len == n_cols,
                 "normalTable: The sigma vector "
                 "has length %zu, which does not equal the number of columns "
                 "in the output table %zu.",
                 static_cast<size_t>(sigma_len),
                 static_cast<size_t>(n_cols));
    // A zero-length view is normalized to nullptr: the handle of an empty view
    // (e.g. one created over an empty container) is not guaranteed to be null.
    if (sigma_len != 0) { per_column_sigma = sigma_view->data_handle(); }
  } else {
    scalar_sigma = std::get<OutputValueType>(sigma);
  }

  // The mean vector must supply exactly one entry per output column.
  RAFT_EXPECTS(mu_vec.extent(0) == n_cols,
               "normalTable: The mu vector "
               "has length %zu, which does not equal the number of columns "
               "in the output table %zu.",
               static_cast<size_t>(mu_vec.extent(0)),
               static_cast<size_t>(n_cols));

  detail::normalTable(rng_state,
                      out.data_handle(),
                      n_rows,
                      n_cols,
                      mu_vec.data_handle(),
                      per_column_sigma,
                      scalar_sigma,
                      handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `normalTable`.
*
* Each row in this table conforms to a normally distributed n-dim vector
* whose mean is the input vector and standard deviation is the corresponding
* vector or scalar. Correlations among the dimensions themselves are assumed to
* be absent.
*
* @tparam OutType data type of output random number
Expand Down Expand Up @@ -200,7 +292,27 @@ void normalTable(const raft::handle_t& handle,
}

/**
 * @brief Fill a vector with the given value
 *
 * @tparam OutputValueType Value type of the output vector
 * @tparam IndexType       Integral type used to represent the length of the output vector
 *
 * @param[in]  handle    raft handle for resource management
 * @param[in]  rng_state random number generator state
 * @param[in]  val       value with which to fill the output vector
 * @param[out] out       the output vector
 */
template <typename OutputValueType, typename IndexType>
void fill(const raft::handle_t& handle,
          RngState& rng_state,
          OutputValueType val,
          raft::device_vector_view<OutputValueType, IndexType> out)
{
  // Unpack the view into pointer + length and forward on the handle's stream.
  auto out_ptr       = out.data_handle();
  const auto out_len = out.extent(0);
  detail::fill(rng_state, out_ptr, out_len, val, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `fill`
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
Expand All @@ -219,6 +331,28 @@ void fill(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenTy
/**
 * @brief Generate bernoulli distributed boolean array
 *
 * @tparam OutputValueType Type of each element of the output vector;
 *                         must be non-const and able to represent boolean
 *                         values (e.g., `bool`)
 * @tparam IndexType       Integral type of the output vector's length
 * @tparam Type            Data type in which to compute the probabilities
 *
 * @param[in]  handle    raft handle for resource management
 * @param[in]  rng_state random number generator state
 * @param[out] out       the output vector
 * @param[in]  prob      coin-toss probability for heads
 */
template <typename OutputValueType, typename IndexType, typename Type>
void bernoulli(const raft::handle_t& handle,
               RngState& rng_state,
               raft::device_vector_view<OutputValueType, IndexType> out,
               Type prob)
{
  // OutputValueType is deduced solely from `out`, so a view of const elements
  // would otherwise be accepted here. Reject it up front, as `uniformInt` does,
  // so the diagnostic points at the call site.
  static_assert(!std::is_const<OutputValueType>::value,
                "bernoulli: The output vector must be a view of nonconst, "
                "so that we can write to it.");
  detail::bernoulli(rng_state, out.data_handle(), out.extent(0), prob, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `bernoulli`
*
* @tparam Type data type in which to compute the probabilities
* @tparam OutType output data type
* @tparam LenType data type used to represent length of the arrays
Expand All @@ -239,6 +373,29 @@ void bernoulli(
/**
 * @brief Generate bernoulli distributed array and applies scale
 *
 * @tparam OutputValueType Data type in which to compute the probabilities
 * @tparam IndexType       Integral type of the output vector's length
 *
 * @param[in]  handle    raft handle for resource management
 * @param[in]  rng_state random number generator state
 * @param[out] out       the output vector
 * @param[in]  prob      coin-toss probability for heads
 * @param[in]  scale     scaling factor
 */
template <typename OutputValueType, typename IndexType>
void scaled_bernoulli(const raft::handle_t& handle,
                      RngState& rng_state,
                      raft::device_vector_view<OutputValueType, IndexType> out,
                      OutputValueType prob,
                      OutputValueType scale)
{
  // Unpack the view into pointer + length and forward on the handle's stream.
  auto out_ptr       = out.data_handle();
  const auto out_len = out.extent(0);
  detail::scaled_bernoulli(rng_state, out_ptr, out_len, prob, scale, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `scaled_bernoulli`
*
* @tparam OutType data type in which to compute the probabilities
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand Down
Loading