diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index ba6254bfc3..106881fa1a 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -30,6 +30,28 @@ namespace raft::random { /** * @brief Generate uniformly distributed numbers in the given range * + * @tparam OutputValueType Data type of output random number + * @tparam Index Data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] start start of the range + * @param[in] end end of the range + */ +template +void uniform(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType start, + OutputValueType end) +{ + detail::uniform(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); +} + +/** + * @brief Legacy overload of `uniform` taking raw pointers + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -75,6 +97,29 @@ void uniformInt(const raft::handle_t& handle, /** * @brief Generate normal distributed numbers + * with a given mean and standard deviation + * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] mu mean of the distribution + * @param[in] sigma std-dev of the distribution + */ +template +void normal(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType sigma) +{ + detail::normal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `normal`. * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays @@ -217,6 +262,29 @@ void scaled_bernoulli(const raft::handle_t& handle, /** * @brief Generate Gumbel distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] mu mean value + * @param[in] beta scale value + * @note https://en.wikipedia.org/wiki/Gumbel_distribution + */ +template +void gumbel(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType beta) +{ + detail::gumbel(rng_state, out.data_handle(), out.extent(0), mu, beta, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `gumbel`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -241,6 +309,28 @@ void gumbel(const raft::handle_t& handle, /** * @brief Generate lognormal distributed numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] mu mean of the distribution + * @param[in] sigma standard deviation of the distribution + */ +template +void lognormal(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType sigma) +{ + detail::lognormal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `lognormal`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -248,7 +338,7 @@ void gumbel(const raft::handle_t& handle, * @param[out] ptr the output array * @param[in] len the number of elements in the output * @param[in] mu mean of the distribution - * @param[in] sigma std-dev of the distribution + * @param[in] sigma standard deviation of the distribution */ template void lognormal(const raft::handle_t& handle, @@ -264,6 +354,28 @@ void lognormal(const raft::handle_t& handle, /** * @brief Generate logistic distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] mu mean value + * @param[in] scale scale value + */ +template +void logistic(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType scale) +{ + detail::logistic(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `logistic`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -287,13 +399,33 @@ void logistic(const raft::handle_t& handle, /** * @brief Generate exponentially distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] lambda the exponential distribution's lambda parameter + */ +template +void exponential(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType lambda) +{ + detail::exponential(rng_state, out.data_handle(), out.extent(0), lambda, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `exponential`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state * @param[out] ptr output array * @param[in] len number of elements in the output array - * @param[in] lambda the lambda + * @param[in] lambda the exponential distribution's lambda parameter */ template void exponential( @@ -305,13 +437,33 @@ void exponential( /** * @brief Generate rayleigh distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] sigma the distribution's sigma parameter + */ +template +void rayleigh(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType sigma) +{ + detail::rayleigh(rng_state, out.data_handle(), out.extent(0), sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `rayleigh`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state * @param[out] ptr output array * @param[in] len number of elements in the output array - * @param[in] sigma the sigma + * @param[in] sigma the distribution's sigma parameter */ template void rayleigh( @@ -323,6 +475,28 @@ void rayleigh( /** * @brief Generate laplace distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] mu the mean + * @param[in] scale the scale + */ +template +void laplace(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType scale) +{ + detail::laplace(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `laplace`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index d778555076..8b32742f34 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "../test_utils.h" @@ -187,7 +188,100 @@ class RngTest : public ::testing::TestWithParam> { T h_stats[2]; // mean, var }; -typedef RngTest RngTestF; +template +class RngMdspanTest : public ::testing::TestWithParam> { + public: + RngMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(0, stream), + stats(2, stream) + { + data.resize(params.len, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, 2 * sizeof(T), stream)); + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + + raft::device_vector_view data_view(data.data(), data.size()); + const auto len = data_view.extent(0); + + switch (params.type) { + case RNG_Normal: normal(handle, r, data_view, params.start, params.end); break; + case RNG_LogNormal: lognormal(handle, r, data_view, params.start, params.end); break; + case RNG_Uniform: uniform(handle, r, data_view, params.start, params.end); break; + case RNG_Gumbel: gumbel(handle, r, data_view, params.start, params.end); break; + case RNG_Logistic: logistic(handle, r, data_view, params.start, params.end); break; + case RNG_Exp: exponential(handle, r, data_view, params.start); break; + case RNG_Rayleigh: rayleigh(handle, r, data_view, params.start); break; + case RNG_Laplace: laplace(handle, r, data_view, params.start, params.end); break; + }; + static const int threads = 128; + meanKernel<<>>( + stats.data(), data.data(), params.len); + update_host(h_stats, stats.data(), 2, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h_stats[0] /= params.len; + h_stats[1] = (h_stats[1] / params.len) - (h_stats[0] * h_stats[0]); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void getExpectedMeanVar(T meanvar[2]) + { + switch (params.type) { + case RNG_Normal: + meanvar[0] = params.start; + meanvar[1] = params.end * params.end; + break; + case RNG_LogNormal: { + auto var = params.end * params.end; + auto mu = params.start; + meanvar[0] = raft::myExp(mu + var * T(0.5)); + meanvar[1] = (raft::myExp(var) - T(1.0)) * raft::myExp(T(2.0) * mu + var); + break; + } + case RNG_Uniform: + meanvar[0] = (params.start + params.end) * T(0.5); + meanvar[1] = params.end - params.start; + meanvar[1] = meanvar[1] * meanvar[1] / T(12.0); + break; + case RNG_Gumbel: { + auto gamma = T(0.577215664901532); + meanvar[0] = params.start + params.end * gamma; + meanvar[1] = T(3.1415) * T(3.1415) * params.end * params.end / T(6.0); + break; + } + case RNG_Logistic: + meanvar[0] = params.start; + meanvar[1] = T(3.1415) * T(3.1415) * params.end * params.end / T(3.0); + break; + case RNG_Exp: + meanvar[0] = T(1.0) / params.start; + meanvar[1] = meanvar[0] * meanvar[0]; + break; + case RNG_Rayleigh: + meanvar[0] = params.start * raft::mySqrt(T(3.1415 / 2.0)); + meanvar[1] = ((T(4.0) - T(3.1415)) / T(2.0)) * params.start * params.start; + break; + case RNG_Laplace: + meanvar[0] = params.start; + meanvar[1] = T(2.0) * params.end * params.end; + break; + }; + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngInputs params; + rmm::device_uvector data, stats; + T h_stats[2]; // mean, var +}; + const std::vector> inputsf = { // Test with Philox {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL}, @@ -206,16 +300,22 @@ const std::vector> inputsf = { {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL}, {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}}; -TEST_P(RngTestF, Result) -{ - float meanvar[2]; - getExpectedMeanVar(meanvar); - ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); - ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); -} +#define _RAFT_RNG_TEST_BODY(VALUE_TYPE) \ + do { \ + VALUE_TYPE meanvar[2]; \ + getExpectedMeanVar(meanvar); \ + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); \ + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); \ + } while (false) + +using RngTestF = RngTest; +TEST_P(RngTestF, Result) { _RAFT_RNG_TEST_BODY(float); } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestF, ::testing::ValuesIn(inputsf)); -typedef RngTest RngTestD; +using RngMdspanTestF = RngMdspanTest; +TEST_P(RngMdspanTestF, Result) { _RAFT_RNG_TEST_BODY(float); } +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = { // Test with Philox {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL}, @@ -234,15 +334,14 @@ const std::vector> inputsd = { {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL}, {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}}; -TEST_P(RngTestD, Result) -{ - double meanvar[2]; - getExpectedMeanVar(meanvar); - ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); - ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); -} +using RngTestD = RngTest; +TEST_P(RngTestD, Result) { _RAFT_RNG_TEST_BODY(double); } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestD, ::testing::ValuesIn(inputsd)); +using RngMdspanTestD = RngMdspanTest; +TEST_P(RngMdspanTestD, Result) { _RAFT_RNG_TEST_BODY(double); } +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestD, ::testing::ValuesIn(inputsd)); + // ---------------------------------------------------------------------- // // Test for expected variance in mean calculations @@ -353,11 +452,10 @@ class ScaledBernoulliTest : public ::testing::Test { void rangeCheck() { - T* h_data = new T[len]; - update_host(h_data, data.data(), len, stream); - ASSERT_TRUE( - std::none_of(h_data, h_data + len, [](const T& a) { return a < -scale || a > scale; })); - delete[] h_data; + auto h_data = std::make_unique(len); + update_host(h_data.get(), data.data(), len, stream); + ASSERT_TRUE(std::none_of( + h_data.get(), h_data.get() + len, [](const T& a) { return a < -scale || a > scale; })); } raft::handle_t handle;