Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] RNG API fixes #630

Merged
merged 3 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 84 additions & 92 deletions cpp/include/raft/random/detail/rng_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,12 @@ namespace raft {
namespace random {
namespace detail {

#if defined(__RNG_H_INCLUSION_DEPRECATED)
#define POTENTIAL_DEPR \
[[deprecated("Include raft/random/rng_device.cuh to use the device-only API")]]
#else
#define POTENTIAL_DEPR
#endif

/**
* The device state used to communicate RNG state from host to device.
* As of now, it is just a templated version of `RngState`.
*/
template <typename GenType>
struct POTENTIAL_DEPR DeviceState {
struct DeviceState {
using gen_t = GenType;
static constexpr auto GEN_TYPE = gen_t::GEN_TYPE;

Expand All @@ -55,37 +48,37 @@ struct POTENTIAL_DEPR DeviceState {
};

template <typename OutType>
struct POTENTIAL_DEPR InvariantDistParams {
struct InvariantDistParams {
OutType const_val;
};

template <typename OutType>
struct POTENTIAL_DEPR UniformDistParams {
struct UniformDistParams {
OutType start;
OutType end;
};

template <typename OutType, typename DiffType>
struct POTENTIAL_DEPR UniformIntDistParams {
struct UniformIntDistParams {
OutType start;
OutType end;
DiffType diff;
};

template <typename OutType>
struct POTENTIAL_DEPR NormalDistParams {
struct NormalDistParams {
OutType mu;
OutType sigma;
};

template <typename IntType>
struct POTENTIAL_DEPR NormalIntDistParams {
struct NormalIntDistParams {
IntType mu;
IntType sigma;
};

template <typename OutType, typename LenType>
struct POTENTIAL_DEPR NormalTableDistParams {
struct NormalTableDistParams {
LenType n_rows;
LenType n_cols;
const OutType* mu_vec;
Expand All @@ -94,60 +87,59 @@ struct POTENTIAL_DEPR NormalTableDistParams {
};

template <typename OutType>
struct POTENTIAL_DEPR BernoulliDistParams {
struct BernoulliDistParams {
OutType prob;
};

template <typename OutType>
struct POTENTIAL_DEPR ScaledBernoulliDistParams {
struct ScaledBernoulliDistParams {
OutType prob;
OutType scale;
};

template <typename OutType>
struct POTENTIAL_DEPR GumbelDistParams {
struct GumbelDistParams {
OutType mu;
OutType beta;
};

template <typename OutType>
struct POTENTIAL_DEPR LogNormalDistParams {
struct LogNormalDistParams {
OutType mu;
OutType sigma;
};

template <typename OutType>
struct POTENTIAL_DEPR LogisticDistParams {
struct LogisticDistParams {
OutType mu;
OutType scale;
};

template <typename OutType>
struct POTENTIAL_DEPR ExponentialDistParams {
struct ExponentialDistParams {
OutType lambda;
};

template <typename OutType>
struct POTENTIAL_DEPR RayleighDistParams {
struct RayleighDistParams {
OutType sigma;
};

template <typename OutType>
struct POTENTIAL_DEPR LaplaceDistParams {
struct LaplaceDistParams {
OutType mu;
OutType scale;
};

// Not really a distro, useful for sample without replacement function
template <typename WeightsT, typename IdxT>
struct POTENTIAL_DEPR SamplingParams {
struct SamplingParams {
IdxT* inIdxPtr;
const WeightsT* wts;
};

template <typename Type>
POTENTIAL_DEPR DI void box_muller_transform(
Type& val1, Type& val2, Type sigma1, Type mu1, Type sigma2, Type mu2)
DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1, Type sigma2, Type mu2)
{
constexpr Type twoPi = Type(2.0) * Type(3.141592654);
constexpr Type minus2 = -Type(2.0);
Expand All @@ -160,39 +152,39 @@ POTENTIAL_DEPR DI void box_muller_transform(
}

template <typename Type>
POTENTIAL_DEPR DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1)
DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1)
{
box_muller_transform<Type>(val1, val2, sigma1, mu1, sigma1, mu1);
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
InvariantDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
InvariantDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
*val = params.const_val;
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
UniformDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
UniformDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
OutType res;
gen.next(res);
*val = (res * (params.end - params.start)) + params.start;
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
UniformIntDistParams<OutType, uint32_t> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
UniformIntDistParams<OutType, uint32_t> params,
LenType idx = 0,
LenType stride = 0)
{
uint32_t x = 0;
uint32_t s = params.diff;
Expand All @@ -211,11 +203,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
UniformIntDistParams<OutType, uint64_t> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
UniformIntDistParams<OutType, uint64_t> params,
LenType idx = 0,
LenType stride = 0)
{
uint64_t x = 0;
gen.next(x);
Expand All @@ -236,7 +228,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(
DI void custom_next(
GenType& gen, OutType* val, NormalDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res1, res2;
Expand All @@ -253,11 +245,11 @@ POTENTIAL_DEPR DI void custom_next(
}

template <typename GenType, typename IntType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
IntType* val,
NormalIntDistParams<IntType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
IntType* val,
NormalIntDistParams<IntType> params,
LenType idx = 0,
LenType stride = 0)
{
IntType res1_int, res2_int;

Expand All @@ -276,11 +268,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
NormalTableDistParams<OutType, LenType> params,
LenType idx,
LenType stride)
DI void custom_next(GenType& gen,
OutType* val,
NormalTableDistParams<OutType, LenType> params,
LenType idx,
LenType stride)
{
OutType res1, res2;

Expand All @@ -301,7 +293,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename Type, typename LenType>
POTENTIAL_DEPR DI void custom_next(
DI void custom_next(
GenType& gen, OutType* val, BernoulliDistParams<Type> params, LenType idx = 0, LenType stride = 0)
{
Type res = 0;
Expand All @@ -310,19 +302,19 @@ POTENTIAL_DEPR DI void custom_next(
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
ScaledBernoulliDistParams<OutType> params,
LenType idx,
LenType stride)
DI void custom_next(GenType& gen,
OutType* val,
ScaledBernoulliDistParams<OutType> params,
LenType idx,
LenType stride)
{
OutType res = 0;
gen.next(res);
*val = res < params.prob ? -params.scale : params.scale;
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(
DI void custom_next(
GenType& gen, OutType* val, GumbelDistParams<OutType> params, LenType idx = 0, LenType stride = 0)
{
OutType res = 0;
Expand All @@ -335,11 +327,11 @@ POTENTIAL_DEPR DI void custom_next(
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
LogNormalDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
LogNormalDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
OutType res1 = 0, res2 = 0;
do {
Expand All @@ -353,11 +345,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
LogisticDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
LogisticDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
OutType res;

Expand All @@ -370,11 +362,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
ExponentialDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
ExponentialDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
OutType res;
gen.next(res);
Expand All @@ -383,11 +375,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
RayleighDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
RayleighDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
OutType res;
gen.next(res);
Expand All @@ -398,11 +390,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(GenType& gen,
OutType* val,
LaplaceDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
DI void custom_next(GenType& gen,
OutType* val,
LaplaceDistParams<OutType> params,
LenType idx = 0,
LenType stride = 0)
{
OutType res, out;

Expand All @@ -425,7 +417,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen,
}

template <typename GenType, typename OutType, typename LenType>
POTENTIAL_DEPR DI void custom_next(
DI void custom_next(
GenType& gen, OutType* val, SamplingParams<OutType, LenType> params, LenType idx, LenType stride)
{
OutType res;
Expand All @@ -442,7 +434,7 @@ POTENTIAL_DEPR DI void custom_next(

/** Philox-based random number generator */
// Courtesy: Jakub Szuppe
struct POTENTIAL_DEPR PhiloxGenerator {
struct PhiloxGenerator {
static constexpr auto GEN_TYPE = GeneratorType::GenPhilox;

/**
Expand Down Expand Up @@ -540,7 +532,7 @@ struct POTENTIAL_DEPR PhiloxGenerator {
};

/** PCG random number generator */
struct POTENTIAL_DEPR PCGenerator {
struct PCGenerator {
static constexpr auto GEN_TYPE = GeneratorType::GenPC;

/**
Expand Down
Loading