diff --git a/cpp/include/raft/random/detail/rng_device.cuh b/cpp/include/raft/random/detail/rng_device.cuh index 0fd36473c5..7e2863e030 100644 --- a/cpp/include/raft/random/detail/rng_device.cuh +++ b/cpp/include/raft/random/detail/rng_device.cuh @@ -29,19 +29,12 @@ namespace raft { namespace random { namespace detail { -#if defined(__RNG_H_INCLUSION_DEPRECATED) -#define POTENTIAL_DEPR \ - [[deprecated("Include raft/random/rng_device.cuh to use the device-only API")]] -#else -#define POTENTIAL_DEPR -#endif - /** * The device state used to communicate RNG state from host to device. * As of now, it is just a templated version of `RngState`. */ template -struct POTENTIAL_DEPR DeviceState { +struct DeviceState { using gen_t = GenType; static constexpr auto GEN_TYPE = gen_t::GEN_TYPE; @@ -55,37 +48,37 @@ struct POTENTIAL_DEPR DeviceState { }; template -struct POTENTIAL_DEPR InvariantDistParams { +struct InvariantDistParams { OutType const_val; }; template -struct POTENTIAL_DEPR UniformDistParams { +struct UniformDistParams { OutType start; OutType end; }; template -struct POTENTIAL_DEPR UniformIntDistParams { +struct UniformIntDistParams { OutType start; OutType end; DiffType diff; }; template -struct POTENTIAL_DEPR NormalDistParams { +struct NormalDistParams { OutType mu; OutType sigma; }; template -struct POTENTIAL_DEPR NormalIntDistParams { +struct NormalIntDistParams { IntType mu; IntType sigma; }; template -struct POTENTIAL_DEPR NormalTableDistParams { +struct NormalTableDistParams { LenType n_rows; LenType n_cols; const OutType* mu_vec; @@ -94,60 +87,59 @@ struct POTENTIAL_DEPR NormalTableDistParams { }; template -struct POTENTIAL_DEPR BernoulliDistParams { +struct BernoulliDistParams { OutType prob; }; template -struct POTENTIAL_DEPR ScaledBernoulliDistParams { +struct ScaledBernoulliDistParams { OutType prob; OutType scale; }; template -struct POTENTIAL_DEPR GumbelDistParams { +struct GumbelDistParams { OutType mu; OutType beta; }; template -struct POTENTIAL_DEPR LogNormalDistParams { +struct LogNormalDistParams { OutType mu; OutType sigma; }; template -struct POTENTIAL_DEPR LogisticDistParams { +struct LogisticDistParams { OutType mu; OutType scale; }; template -struct POTENTIAL_DEPR ExponentialDistParams { +struct ExponentialDistParams { OutType lambda; }; template -struct POTENTIAL_DEPR RayleighDistParams { +struct RayleighDistParams { OutType sigma; }; template -struct POTENTIAL_DEPR LaplaceDistParams { +struct LaplaceDistParams { OutType mu; OutType scale; }; // Not really a distro, useful for sample without replacement function template -struct POTENTIAL_DEPR SamplingParams { +struct SamplingParams { IdxT* inIdxPtr; const WeightsT* wts; }; template -POTENTIAL_DEPR DI void box_muller_transform( - Type& val1, Type& val2, Type sigma1, Type mu1, Type sigma2, Type mu2) +DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1, Type sigma2, Type mu2) { constexpr Type twoPi = Type(2.0) * Type(3.141592654); constexpr Type minus2 = -Type(2.0); @@ -160,27 +152,27 @@ POTENTIAL_DEPR DI void box_muller_transform( } template -POTENTIAL_DEPR DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1) +DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1) { box_muller_transform(val1, val2, sigma1, mu1, sigma1, mu1); } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - InvariantDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + InvariantDistParams params, + LenType idx = 0, + LenType stride = 0) { *val = params.const_val; } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - UniformDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + UniformDistParams params, + LenType idx = 0, + LenType stride = 0) { OutType res; gen.next(res); @@ -188,11 +180,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - UniformIntDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + UniformIntDistParams params, + LenType idx = 0, + LenType stride = 0) { uint32_t x = 0; uint32_t s = params.diff; @@ -211,11 +203,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - UniformIntDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + UniformIntDistParams params, + LenType idx = 0, + LenType stride = 0) { uint64_t x = 0; gen.next(x); @@ -236,7 +228,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next( +DI void custom_next( GenType& gen, OutType* val, NormalDistParams params, LenType idx = 0, LenType stride = 0) { OutType res1, res2; @@ -253,11 +245,11 @@ POTENTIAL_DEPR DI void custom_next( } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - IntType* val, - NormalIntDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + IntType* val, + NormalIntDistParams params, + LenType idx = 0, + LenType stride = 0) { IntType res1_int, res2_int; @@ -276,11 +268,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - NormalTableDistParams params, - LenType idx, - LenType stride) +DI void custom_next(GenType& gen, + OutType* val, + NormalTableDistParams params, + LenType idx, + LenType stride) { OutType res1, res2; @@ -301,7 +293,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next( +DI void custom_next( GenType& gen, OutType* val, BernoulliDistParams params, LenType idx = 0, LenType stride = 0) { Type res = 0; @@ -310,11 +302,11 @@ POTENTIAL_DEPR DI void custom_next( } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - ScaledBernoulliDistParams params, - LenType idx, - LenType stride) +DI void custom_next(GenType& gen, + OutType* val, + ScaledBernoulliDistParams params, + LenType idx, + LenType stride) { OutType res = 0; gen.next(res); @@ -322,7 +314,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next( +DI void custom_next( GenType& gen, OutType* val, GumbelDistParams params, LenType idx = 0, LenType stride = 0) { OutType res = 0; @@ -335,11 +327,11 @@ POTENTIAL_DEPR DI void custom_next( } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - LogNormalDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + LogNormalDistParams params, + LenType idx = 0, + LenType stride = 0) { OutType res1 = 0, res2 = 0; do { @@ -353,11 +345,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - LogisticDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + LogisticDistParams params, + LenType idx = 0, + LenType stride = 0) { OutType res; @@ -370,11 +362,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - ExponentialDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + ExponentialDistParams params, + LenType idx = 0, + LenType stride = 0) { OutType res; gen.next(res); @@ -383,11 +375,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - RayleighDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + RayleighDistParams params, + LenType idx = 0, + LenType stride = 0) { OutType res; gen.next(res); @@ -398,11 +390,11 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next(GenType& gen, - OutType* val, - LaplaceDistParams params, - LenType idx = 0, - LenType stride = 0) +DI void custom_next(GenType& gen, + OutType* val, + LaplaceDistParams params, + LenType idx = 0, + LenType stride = 0) { OutType res, out; @@ -425,7 +417,7 @@ POTENTIAL_DEPR DI void custom_next(GenType& gen, } template -POTENTIAL_DEPR DI void custom_next( +DI void custom_next( GenType& gen, OutType* val, SamplingParams params, LenType idx, LenType stride) { OutType res; @@ -442,7 +434,7 @@ POTENTIAL_DEPR DI void custom_next( /** Philox-based random number generator */ // Courtesy: Jakub Szuppe -struct POTENTIAL_DEPR PhiloxGenerator { +struct PhiloxGenerator { static constexpr auto GEN_TYPE = GeneratorType::GenPhilox; /** @@ -540,7 +532,7 @@ struct POTENTIAL_DEPR PhiloxGenerator { }; /** PCG random number generator */ -struct POTENTIAL_DEPR PCGenerator { +struct PCGenerator { static constexpr auto GEN_TYPE = GeneratorType::GenPC; /** diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 62d9ccc3ab..26ce93c068 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -16,11 +16,10 @@ #pragma once -#include "rng_device.cuh" - #include #include #include +#include #include namespace raft { @@ -56,44 +55,21 @@ namespace detail { * RAFT_CALL_RNG_FUNC(rng_state, (my_kernel<1, 2><<<1, 1>>>), 5); * @endcode */ -#define RAFT_CALL_RNG_FUNC(rng_state, func, ...) \ - switch ((rng_state).type) { \ - case GeneratorType::GenPhilox: { \ - DeviceState dev_state_philox{(rng_state)}; \ - RAFT_DEPAREN(func)(dev_state_philox, ##__VA_ARGS__); \ - break; \ - } \ - case GeneratorType::GenPC: { \ - DeviceState dev_state_pc{(rng_state)}; \ - RAFT_DEPAREN(func)(dev_state_pc, ##__VA_ARGS__); \ - break; \ - } \ - default: RAFT_FAIL("Unepxected generator type '%d'", int((rng_state).type)); \ +#define RAFT_CALL_RNG_FUNC(rng_state, func, ...) \ + switch ((rng_state).type) { \ + case raft::random::GeneratorType::GenPhilox: { \ + raft::random::DeviceState r_phil{(rng_state)}; \ + RAFT_DEPAREN(func)(r_phil, ##__VA_ARGS__); \ + break; \ + } \ + case raft::random::GeneratorType::GenPC: { \ + raft::random::DeviceState r_pc{(rng_state)}; \ + RAFT_DEPAREN(func)(r_pc, ##__VA_ARGS__); \ + break; \ + } \ + default: RAFT_FAIL("Unepxected generator type '%d'", int((rng_state).type)); \ } -/** - * This function is useful if all template arguments to `func` can be inferred - * by the compiler. Otherwise use the MACRO `RAFT_CALL_RNG_FUNC` which - * can accept incomplete template specializations (or kernel calls) as the function - */ -template -void call_rng_func(RngState const& rng_state, FuncT func, ArgsT... args) -{ - switch (rng_state.type) { - case GeneratorType::GenPhilox: { - DeviceState dev_state_philox{rng_state}; - func(dev_state_philox, args...); - break; - } - case GeneratorType::GenPC: { - DeviceState dev_state_pc{rng_state}; - func(dev_state_pc, args...); - break; - } - default: RAFT_FAIL("Unepxected generator type '%d'", int(rng_state.type)); - } -} - template void call_rng_kernel(DeviceState const& dev_state, RngState& rng_state, diff --git a/cpp/include/raft/random/detail/rng_impl_deprecated.cuh b/cpp/include/raft/random/detail/rng_impl_deprecated.cuh index 45fe4eba3a..29af59d502 100644 --- a/cpp/include/raft/random/detail/rng_impl_deprecated.cuh +++ b/cpp/include/raft/random/detail/rng_impl_deprecated.cuh @@ -41,6 +41,7 @@ class RngImpl { public: RngImpl(uint64_t seed, GeneratorType _t = GenPhilox) : state{seed, 0, _t}, + type(_t), // simple heuristic to make sure all SMs will be occupied properly // and also not too many initialization calls will be made by each thread nBlocks(4 * getMultiProcessorCount()) @@ -293,6 +294,7 @@ class RngImpl { } RngState state; + GeneratorType type; /** number of blocks to launch */ int nBlocks; static const int nThreads = 256; diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 2810d54fa3..33d712ac15 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -23,16 +23,9 @@ #include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed) #include "rng_state.hpp" -#define __RNG_H_INCLUSION_DEPRECATED -// we include this for now, but only for backward compatibility -#include "rng_device.cuh" -#undef __RNG_H_INCLUSION_DEPRECATED - namespace raft { namespace random { -using detail::call_rng_func; - using detail::bernoulli; using detail::exponential; using detail::fill; diff --git a/cpp/include/raft/random/rng_device.cuh b/cpp/include/raft/random/rng_device.cuh index b91bd28b6e..7d017fe4a9 100644 --- a/cpp/include/raft/random/rng_device.cuh +++ b/cpp/include/raft/random/rng_device.cuh @@ -20,6 +20,7 @@ #pragma once #include "detail/rng_device.cuh" +#include "rng.cuh" #include "rng_state.hpp" namespace raft {