Skip to content

Commit

Permalink
mdspan-ify several rng functions
Browse files Browse the repository at this point in the history
Add overloads taking the output vector as mdspan,
of the following raft::random functions:

* normal
* lognormal
* uniform
* gumbel
* logistic
* exponential
* rayleigh
* laplace
  • Loading branch information
mhoemmen committed Sep 29, 2022
1 parent 894a7fe commit e855198
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 24 deletions.
180 changes: 177 additions & 3 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@ namespace raft::random {
/**
* @brief Generate uniformly distributed numbers in the given range
*
* @tparam OutputValueType Data type of output random number
* @tparam Index Data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output array
* @param[in] start start of the range
* @param[in] end end of the range
*/
template <typename OutputValueType, typename IndexType>
void uniform(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType start,
OutputValueType end)
{
detail::uniform(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream());
}

/**
* @brief Legacy overload of `uniform` taking raw pointers
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand Down Expand Up @@ -75,6 +97,29 @@ void uniformInt(const raft::handle_t& handle,

/**
* @brief Generate normal distributed numbers
* with a given mean and standard deviation
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output array
* @param[in] mu mean of the distribution
* @param[in] sigma std-dev of the distribution
*/
template <typename OutputValueType, typename IndexType>
void normal(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType mu,
OutputValueType sigma)
{
detail::normal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `normal`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
Expand Down Expand Up @@ -217,6 +262,29 @@ void scaled_bernoulli(const raft::handle_t& handle,
/**
* @brief Generate Gumbel distributed random numbers
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output array
* @param[in] mu mean value
* @param[in] beta scale value
* @note https://en.wikipedia.org/wiki/Gumbel_distribution
*/
template <typename OutputValueType, typename IndexType = int>
void gumbel(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType mu,
OutputValueType beta)
{
detail::gumbel(rng_state, out.data_handle(), out.extent(0), mu, beta, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `gumbel`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand All @@ -241,14 +309,36 @@ void gumbel(const raft::handle_t& handle,
/**
* @brief Generate lognormal distributed numbers
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out the output array
* @param[in] mu mean of the distribution
* @param[in] sigma standard deviation of the distribution
*/
template <typename OutputValueType, typename IndexType>
void lognormal(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType mu,
OutputValueType sigma)
{
detail::lognormal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `lognormal`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] ptr the output array
* @param[in] len the number of elements in the output
* @param[in] mu mean of the distribution
* @param[in] sigma std-dev of the distribution
* @param[in] sigma standard deviation of the distribution
*/
template <typename OutType, typename LenType = int>
void lognormal(const raft::handle_t& handle,
Expand All @@ -264,6 +354,28 @@ void lognormal(const raft::handle_t& handle,
/**
* @brief Generate logistic distributed random numbers
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output array
* @param[in] mu mean value
* @param[in] scale scale value
*/
template <typename OutputValueType, typename IndexType = int>
void logistic(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType mu,
OutputValueType scale)
{
detail::logistic(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `logistic`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand All @@ -287,13 +399,33 @@ void logistic(const raft::handle_t& handle,
/**
* @brief Generate exponentially distributed random numbers
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output array
* @param[in] lambda the exponential distribution's lambda parameter
*/
template <typename OutputValueType, typename IndexType>
void exponential(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType lambda)
{
detail::exponential(rng_state, out.data_handle(), out.extent(0), lambda, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `exponential`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] ptr output array
* @param[in] len number of elements in the output array
* @param[in] lambda the lambda
* @param[in] lambda the exponential distribution's lambda parameter
*/
template <typename OutType, typename LenType = int>
void exponential(
Expand All @@ -305,13 +437,33 @@ void exponential(
/**
* @brief Generate rayleigh distributed random numbers
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output array
* @param[in] sigma the distribution's sigma parameter
*/
template <typename OutputValueType, typename IndexType>
void rayleigh(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType sigma)
{
detail::rayleigh(rng_state, out.data_handle(), out.extent(0), sigma, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `rayleigh`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] ptr output array
* @param[in] len number of elements in the output array
* @param[in] sigma the sigma
* @param[in] sigma the distribution's sigma parameter
*/
template <typename OutType, typename LenType = int>
void rayleigh(
Expand All @@ -323,6 +475,28 @@ void rayleigh(
/**
* @brief Generate laplace distributed random numbers
*
* @tparam OutputValueType data type of output random number
* @tparam IndexType data type used to represent length of the arrays
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output array
* @param[in] mu the mean
* @param[in] scale the scale
*/
template <typename OutputValueType, typename IndexType>
void laplace(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<OutputValueType, IndexType> out,
OutputValueType mu,
OutputValueType scale)
{
detail::laplace(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream());
}

/**
* @brief Legacy raw pointer overload of `laplace`.
*
* @tparam OutType data type of output random number
* @tparam LenType data type used to represent length of the arrays
* @param[in] handle raft handle for resource management
Expand Down
Loading

0 comments on commit e855198

Please sign in to comment.