From 597a16296ff89c75535bc4f85778e7a40ae33e13 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 09:23:27 -0700 Subject: [PATCH 1/9] Add mdspan overload of uniformInt and a unit test for it. --- cpp/include/raft/random/rng.cuh | 27 ++++++++ cpp/test/random/rng_int.cu | 107 ++++++++++++++++++++++++++++++-- 2 files changed, 130 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 106881fa1a..63b816020e 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -75,6 +75,33 @@ void uniform(const raft::handle_t& handle, /** * @brief Generate uniformly distributed integers in the given range * + * @tparam OutputValueType Integral type; value type of the output vector + * @tparam IndexType Type used to represent length of the output vector + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output vector of random numbers + * @param[in] start start of the range + * @param[in] end end of the range + */ +template +void uniformInt(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType start, + OutputValueType end) +{ + static_assert(std::is_same::type>::value, + "uniformInt: The output vector must be a view of nonconst, " + "so that we can write to it."); + static_assert(std::is_integral::value, + "uniformInt: The elements of the output vector must have integral type."); + detail::uniformInt(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `uniformInt` + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu index 8efd9cd5af..8e1d3bd79c 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -121,13 +121,69 @@ class RngTest : public ::testing::TestWithParam> { float h_stats[2]; // mean, var }; -typedef RngTest RngTestU32; +template +class RngMdspanTest : public ::testing::TestWithParam> { + public: + RngMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(0, stream), + stats(2, stream) + { + data.resize(params.len, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, 2 * sizeof(float), stream)); + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + raft::device_vector_view data_view(data.data(), data.size()); + + switch (params.type) { + case RNG_Uniform: + uniformInt(handle, r, data_view, params.start, params.end); + break; + }; + static const int threads = 128; + meanKernel<<>>( + stats.data(), data.data(), params.len); + update_host(h_stats, stats.data(), 2, stream); + handle.sync_stream(stream); + h_stats[0] /= params.len; + h_stats[1] = (h_stats[1] / params.len) - (h_stats[0] * h_stats[0]); + handle.sync_stream(stream); + } + + void getExpectedMeanVar(float meanvar[2]) + { + switch (params.type) { + case RNG_Uniform: + meanvar[0] = (params.start + params.end) * 0.5f; + meanvar[1] = params.end - params.start; + meanvar[1] = meanvar[1] * meanvar[1] / 12.f; + break; + }; + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngInputs params; + rmm::device_uvector data; + rmm::device_uvector stats; + float h_stats[2]; // mean, var +}; + const std::vector> inputs_u32 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestU32 = RngTest; TEST_P(RngTestU32, Result) { float meanvar[2]; @@ -137,13 +193,24 @@ TEST_P(RngTestU32, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestU32, ::testing::ValuesIn(inputs_u32)); -typedef RngTest RngTestU64; +using RngMdspanTestU32 = RngMdspanTest; +TEST_P(RngMdspanTestU32, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestU32, ::testing::ValuesIn(inputs_u32)); + const std::vector> inputs_u64 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestU64 = RngTest; TEST_P(RngTestU64, Result) { float meanvar[2]; @@ -153,13 +220,24 @@ TEST_P(RngTestU64, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestU64, ::testing::ValuesIn(inputs_u64)); -typedef RngTest RngTestS32; +using RngMdspanTestU64 = RngMdspanTest; +TEST_P(RngMdspanTestU64, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestU64, ::testing::ValuesIn(inputs_u64)); + const std::vector> inputs_s32 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestS32 = RngTest; TEST_P(RngTestS32, Result) { float meanvar[2]; @@ -169,13 +247,24 @@ TEST_P(RngTestS32, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestS32, ::testing::ValuesIn(inputs_s32)); -typedef RngTest RngTestS64; +using RngMdspanTestS32 = RngMdspanTest; +TEST_P(RngMdspanTestS32, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestS32, ::testing::ValuesIn(inputs_s32)); + const std::vector> inputs_s64 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestS64 = RngTest; TEST_P(RngTestS64, Result) { float meanvar[2]; @@ -185,5 +274,15 @@ TEST_P(RngTestS64, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestS64, ::testing::ValuesIn(inputs_s64)); +using RngMdspanTestS64 = RngMdspanTest; +TEST_P(RngMdspanTestS64, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestS64, ::testing::ValuesIn(inputs_s64)); + } // namespace random } // namespace raft From 5f86bd0128743f01de19069092b58393967918b9 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 10:51:26 -0700 Subject: [PATCH 2/9] Add an mdspan overload of normalTable and a unit test for it. --- cpp/include/raft/random/rng.cuh | 56 ++++++++++++++++++++++- cpp/test/random/rng.cu | 80 ++++++++++++++++++++++++++++++++- 2 files changed, 133 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 63b816020e..84825cfd17 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -21,6 +21,7 @@ #include "rng_state.hpp" #include #include +#include #include #include #include @@ -197,7 +198,60 @@ void normalInt(const raft::handle_t& handle, * * Each row in this table conforms to a normally distributed n-dim vector * whose mean is the input vector and standard deviation is the corresponding - * vector or scalar. Correlations among the dimensions itself is assumed to + * vector or scalar. Correlations among the dimensions itself are assumed to + * be absent. + * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[in] mu_vec mean vector (of length `out.extent(1)`) + * @param[in] sigma_vec Either the standard-deviation vector + * (of length `out.extent(1)`) of each component, + * or a scalar standard deviation for all components. + * @param[out] ptr the output table + */ +template +void normalTable(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view mu_vec, + std::variant, OutputValueType> sigma, + raft::device_matrix_view out) +{ + const OutputValueType* sigma_vec_ptr = nullptr; + OutputValueType sigma_value{}; + + using sigma_vec_type = raft::device_vector_view; + if (std::holds_alternative(sigma)) { + auto sigma_vec = std::get(sigma); + RAFT_EXPECTS( sigma_vec.extent(0) == out.extent(1), "normalTable: The sigma vector " + "has length %zu, which does not equal the number of columns " + "in the output table %zu.", static_cast(sigma_vec.extent(0)), + static_cast(out.extent(1)) ); + // The extra length check makes this work even if sigma_vec views a std::vector, + // where .data() need not return nullptr even if .size() is zero. + sigma_vec_ptr = sigma_vec.extent(0) == 0 ? nullptr : sigma_vec.data_handle(); + } else { + sigma_value = std::get(sigma); + } + + RAFT_EXPECTS( mu_vec.extent(0) == out.extent(1), "normalTable: The mu vector " + "has length %zu, which does not equal the number of columns " + "in the output table %zu.", static_cast(mu_vec.extent(0)), + static_cast(out.extent(1)) ); + + detail::normalTable( + rng_state, out.data_handle(), out.extent(0), out.extent(1), + mu_vec.data_handle(), sigma_vec_ptr, sigma_value, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `normalTable`. + * + * Each row in this table conforms to a normally distributed n-dim vector + * whose mean is the input vector and standard deviation is the corresponding + * vector or scalar. Correlations among the dimensions itself are assumed to * be absent. * * @tparam OutType data type of output random number diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 8b32742f34..c4a8289e8d 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -572,7 +572,61 @@ class RngNormalTableTest : public ::testing::TestWithParam RngNormalTableTestF; +template +class RngNormalTableMdspanTest : public ::testing::TestWithParam> { + public: + RngNormalTableMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(params.rows * params.cols, stream), + stats(2, stream), + mu_vec(params.cols, stream) + { + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, 2 * sizeof(T), stream)); + } + + protected: + void SetUp() override + { + // Tests are configured with their expected test-values sigma. For example, + // 4 x sigma indicates the test shouldn't fail 99.9% of the time. + num_sigma = 10; + int len = params.rows * params.cols; + RngState r(params.seed, params.gtype); + + fill(handle, r, mu_vec.data(), params.cols, params.mu); + + raft::device_matrix_view data_view(data.data(), params.rows, params.cols); + raft::device_vector_view mu_vec_view(mu_vec.data(), params.cols); + std::variant, T> sigma_var(params.sigma); + + normalTable(handle, r, mu_vec_view, sigma_var, data_view); + static const int threads = 128; + meanKernel + <<>>(stats.data(), data.data(), len); + update_host(h_stats, stats.data(), 2, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h_stats[0] /= len; + h_stats[1] = (h_stats[1] / len) - (h_stats[0] * h_stats[0]); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void getExpectedMeanVar(T meanvar[2]) + { + meanvar[0] = params.mu; + meanvar[1] = params.sigma * params.sigma; + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngNormalTableInputs params; + rmm::device_uvector data, stats, mu_vec; + T h_stats[2]; // mean, var + int num_sigma; +}; + const std::vector> inputsf_t = { {0.0055, 32, 1024, 1.f, 1.f, GenPhilox, 1234ULL}, {0.011, 8, 1024, 1.f, 1.f, GenPhilox, 1234ULL}, @@ -580,6 +634,7 @@ const std::vector> inputsf_t = { {0.0055, 32, 1024, 1.f, 1.f, GenPC, 1234ULL}, {0.011, 8, 1024, 1.f, 1.f, GenPC, 1234ULL}}; +using RngNormalTableTestF = RngNormalTableTest; TEST_P(RngNormalTableTestF, Result) { float meanvar[2]; @@ -589,13 +644,24 @@ TEST_P(RngNormalTableTestF, Result) } INSTANTIATE_TEST_SUITE_P(RngNormalTableTests, RngNormalTableTestF, ::testing::ValuesIn(inputsf_t)); -typedef RngNormalTableTest RngNormalTableTestD; +using RngNormalTableMdspanTestF = RngNormalTableMdspanTest; +TEST_P(RngNormalTableMdspanTestF, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, RngNormalTableMdspanTestF, ::testing::ValuesIn(inputsf_t)); + const std::vector> inputsd_t = { {0.0055, 32, 1024, 1.0, 1.0, GenPhilox, 1234ULL}, {0.011, 8, 1024, 1.0, 1.0, GenPhilox, 1234ULL}, {0.0055, 32, 1024, 1.0, 1.0, GenPC, 1234ULL}, {0.011, 8, 1024, 1.0, 1.0, GenPC, 1234ULL}}; + +using RngNormalTableTestD = RngNormalTableTest; TEST_P(RngNormalTableTestD, Result) { double meanvar[2]; @@ -605,6 +671,16 @@ TEST_P(RngNormalTableTestD, Result) } INSTANTIATE_TEST_SUITE_P(RngNormalTableTests, RngNormalTableTestD, ::testing::ValuesIn(inputsd_t)); +using RngNormalTableMdspanTestD = RngNormalTableMdspanTest; +TEST_P(RngNormalTableMdspanTestD, Result) +{ + double meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, RngNormalTableMdspanTestD, ::testing::ValuesIn(inputsd_t)); + struct RngAffineInputs { int n; unsigned long long seed; From a35eb086dd2e68bce278f981351da3a218093506 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 10:59:58 -0700 Subject: [PATCH 3/9] Fix formatting --- cpp/include/raft/random/rng.cuh | 53 ++++++++++++++++++++------------- cpp/test/random/rng.cu | 11 +++++-- cpp/test/random/rng_int.cu | 4 +-- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 84825cfd17..6bd1eb3342 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -21,10 +21,10 @@ #include "rng_state.hpp" #include #include -#include #include #include #include +#include namespace raft::random { @@ -92,9 +92,10 @@ void uniformInt(const raft::handle_t& handle, OutputValueType start, OutputValueType end) { - static_assert(std::is_same::type>::value, - "uniformInt: The output vector must be a view of nonconst, " - "so that we can write to it."); + static_assert( + std::is_same::type>::value, + "uniformInt: The output vector must be a view of nonconst, " + "so that we can write to it."); static_assert(std::is_integral::value, "uniformInt: The elements of the output vector must have integral type."); detail::uniformInt(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); @@ -213,11 +214,12 @@ void normalInt(const raft::handle_t& handle, * @param[out] ptr the output table */ template -void normalTable(const raft::handle_t& handle, - RngState& rng_state, - raft::device_vector_view mu_vec, - std::variant, OutputValueType> sigma, - raft::device_matrix_view out) +void normalTable( + const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view mu_vec, + std::variant, OutputValueType> sigma, + raft::device_matrix_view out) { const OutputValueType* sigma_vec_ptr = nullptr; OutputValueType sigma_value{}; @@ -225,10 +227,12 @@ void normalTable(const raft::handle_t& handle, using sigma_vec_type = raft::device_vector_view; if (std::holds_alternative(sigma)) { auto sigma_vec = std::get(sigma); - RAFT_EXPECTS( sigma_vec.extent(0) == out.extent(1), "normalTable: The sigma vector " - "has length %zu, which does not equal the number of columns " - "in the output table %zu.", static_cast(sigma_vec.extent(0)), - static_cast(out.extent(1)) ); + RAFT_EXPECTS(sigma_vec.extent(0) == out.extent(1), + "normalTable: The sigma vector " + "has length %zu, which does not equal the number of columns " + "in the output table %zu.", + static_cast(sigma_vec.extent(0)), + static_cast(out.extent(1))); // The extra length check makes this work even if sigma_vec views a std::vector, // where .data() need not return nullptr even if .size() is zero. sigma_vec_ptr = sigma_vec.extent(0) == 0 ? nullptr : sigma_vec.data_handle(); @@ -236,14 +240,21 @@ void normalTable(const raft::handle_t& handle, sigma_value = std::get(sigma); } - RAFT_EXPECTS( mu_vec.extent(0) == out.extent(1), "normalTable: The mu vector " - "has length %zu, which does not equal the number of columns " - "in the output table %zu.", static_cast(mu_vec.extent(0)), - static_cast(out.extent(1)) ); - - detail::normalTable( - rng_state, out.data_handle(), out.extent(0), out.extent(1), - mu_vec.data_handle(), sigma_vec_ptr, sigma_value, handle.get_stream()); + RAFT_EXPECTS(mu_vec.extent(0) == out.extent(1), + "normalTable: The mu vector " + "has length %zu, which does not equal the number of columns " + "in the output table %zu.", + static_cast(mu_vec.extent(0)), + static_cast(out.extent(1))); + + detail::normalTable(rng_state, + out.data_handle(), + out.extent(0), + out.extent(1), + mu_vec.data_handle(), + sigma_vec_ptr, + sigma_value, + handle.get_stream()); } /** diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index c4a8289e8d..f81f36d1e6 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -596,7 +596,8 @@ class RngNormalTableMdspanTest : public ::testing::TestWithParam data_view(data.data(), params.rows, params.cols); + raft::device_matrix_view data_view( + data.data(), params.rows, params.cols); raft::device_vector_view mu_vec_view(mu_vec.data(), params.cols); std::variant, T> sigma_var(params.sigma); @@ -652,7 +653,9 @@ TEST_P(RngNormalTableMdspanTestF, Result) ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); } -INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, RngNormalTableMdspanTestF, ::testing::ValuesIn(inputsf_t)); +INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, + RngNormalTableMdspanTestF, + ::testing::ValuesIn(inputsf_t)); const std::vector> inputsd_t = { {0.0055, 32, 1024, 1.0, 1.0, GenPhilox, 1234ULL}, @@ -679,7 +682,9 @@ TEST_P(RngNormalTableMdspanTestD, Result) ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); } -INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, RngNormalTableMdspanTestD, ::testing::ValuesIn(inputsd_t)); +INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, + RngNormalTableMdspanTestD, + ::testing::ValuesIn(inputsd_t)); struct RngAffineInputs { int n; diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu index 8e1d3bd79c..d5270c456e 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -141,9 +141,7 @@ class RngMdspanTest : public ::testing::TestWithParam> { raft::device_vector_view data_view(data.data(), data.size()); switch (params.type) { - case RNG_Uniform: - uniformInt(handle, r, data_view, params.start, params.end); - break; + case RNG_Uniform: uniformInt(handle, r, data_view, params.start, params.end); break; }; static const int threads = 128; meanKernel<<>>( From 67fa742ffe8e092dbcce2985d352f9e7547075a6 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 11:45:36 -0700 Subject: [PATCH 4/9] Add an mdspan overload of raft::random::fill There's no unit test specifically for this function, but it is used in the normalTable test, so I've made the mdspan version of the normalTable test call this function. --- cpp/include/raft/random/rng.cuh | 22 +++++++++++++++++++++- cpp/test/random/rng.cu | 4 ++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 6bd1eb3342..43c77a22bc 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -292,7 +292,27 @@ void normalTable(const raft::handle_t& handle, } /** - * @brief Fill an array with the given value + * @brief Fill a vector with the given value + * + * @tparam OutputValueType Value type of the output vector + * @tparam IndexType Integral type used to represent length of the output vector + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[in] val value with which to fill the output vector + * @param[out] out the output vector + */ +template +void fill(const raft::handle_t& handle, + RngState& rng_state, + OutputValueType val, + raft::device_vector_view out) +{ + detail::fill(rng_state, out.data_handle(), out.extent(0), val, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `fill` * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index f81f36d1e6..08acd1b5bd 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -594,13 +594,13 @@ class RngNormalTableMdspanTest : public ::testing::TestWithParam data_view( data.data(), params.rows, params.cols); raft::device_vector_view mu_vec_view(mu_vec.data(), params.cols); + raft::device_vector_view mu_vec_nc_view(mu_vec.data(), params.cols); std::variant, T> sigma_var(params.sigma); + fill(handle, r, params.mu, mu_vec_nc_view); normalTable(handle, r, mu_vec_view, sigma_var, data_view); static const int threads = 128; meanKernel From 9fc5ef234c1e143a7e8e1bf63941b5b4b81d29a0 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 12:32:08 -0700 Subject: [PATCH 5/9] Add an mdspan overload of raft::random::bernoulli --- cpp/include/raft/random/rng.cuh | 22 ++++++++++++++++++ cpp/test/random/rng.cu | 41 +++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 43c77a22bc..3a7c65769b 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -331,6 +331,28 @@ void fill(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenTy /** * @brief Generate bernoulli distributed boolean array * + * @tparam OutputValueType Type of each element of the output vector; + * must be able to represent boolean values (e.g., `bool`) + * @tparam IndexType Integral type of the output vector's length + * @tparam Type Data type in which to compute the probabilities + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output vector + * @param[in] prob coin-toss probability for heads + */ +template +void bernoulli(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + Type prob) +{ + detail::bernoulli(rng_state, out.data_handle(), out.extent(0), prob, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `bernoulli` + * * @tparam Type data type in which to compute the probabilities * @tparam OutType output data type * @tparam LenType data type used to represent length of the arrays diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 08acd1b5bd..7a3b4222b4 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -499,12 +499,49 @@ class BernoulliTest : public ::testing::Test { rmm::device_uvector data; }; -typedef BernoulliTest BernoulliTest1; +template +class BernoulliMdspanTest : public ::testing::Test { + public: + BernoulliMdspanTest() : stream(handle.get_stream()), data(len, stream) {} + + protected: + void SetUp() override + { + RngState r(42); + + raft::device_vector_view data_view(data.data(), data.size()); + + bernoulli(handle, r, data_view, T(0.5)); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void trueFalseCheck() + { + // both true and false values must be present + auto h_data = std::make_unique(len); + update_host(h_data.get(), data.data(), len, stream); + ASSERT_TRUE(std::any_of(h_data.get(), h_data.get() + len, [](bool a) { return a; })); + ASSERT_TRUE(std::any_of(h_data.get(), h_data.get() + len, [](bool a) { return !a; })); + } + + raft::handle_t handle; + cudaStream_t stream; + + rmm::device_uvector data; +}; + +using BernoulliTest1 = BernoulliTest; TEST_F(BernoulliTest1, TrueFalseCheck) { trueFalseCheck(); } -typedef BernoulliTest BernoulliTest2; +using BernoulliMdspanTest1 = BernoulliMdspanTest; +TEST_F(BernoulliMdspanTest1, TrueFalseCheck) { trueFalseCheck(); } + +using BernoulliTest2 = BernoulliTest; TEST_F(BernoulliTest2, TrueFalseCheck) { trueFalseCheck(); } +using BernoulliMdspanTest2 = BernoulliMdspanTest; +TEST_F(BernoulliMdspanTest2, TrueFalseCheck) { trueFalseCheck(); } + /** Rng::normalTable tests */ template struct RngNormalTableInputs { From fa8acad17d48924f1d302fd61cc90700c24744ad Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 13:07:40 -0700 Subject: [PATCH 6/9] Add an mdspan overload of raft::random::scaled_bernoulli --- cpp/include/raft/random/rng.cuh | 23 +++++++++++++++++++ cpp/test/random/rng.cu | 39 +++++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 3a7c65769b..55b60a5c91 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -373,6 +373,29 @@ void bernoulli( /** * @brief Generate bernoulli distributed array and applies scale * + * @tparam OutputValueType Data type in which to compute the probabilities + * @tparam IndexType Integral type of the output vector's length + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output vector + * @param[in] prob coin-toss probability for heads + * @param[in] scale scaling factor + */ +template +void scaled_bernoulli(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType prob, + OutputValueType scale) +{ + detail::scaled_bernoulli( + rng_state, out.data_handle(), out.extent(0), prob, scale, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `scaled_bernoulli` + * * @tparam OutType data type in which to compute the probabilities * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 7a3b4222b4..82f6e0e247 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -464,12 +464,47 @@ class ScaledBernoulliTest : public ::testing::Test { rmm::device_uvector data; }; -typedef ScaledBernoulliTest ScaledBernoulliTest1; +template +class ScaledBernoulliMdspanTest : public ::testing::Test { + public: + ScaledBernoulliMdspanTest() : stream(handle.get_stream()), data(len, stream) {} + + protected: + void SetUp() override + { + RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + RngState r(42); + + raft::device_vector_view data_view(data.data(), data.size()); + scaled_bernoulli(handle, r, data_view, T(0.5), T(scale)); + } + + void rangeCheck() + { + auto h_data = std::make_unique(len); + update_host(h_data.get(), data.data(), len, stream); + ASSERT_TRUE(std::none_of( + h_data.get(), h_data.get() + len, [](const T& a) { return a < -scale || a > scale; })); + } + + raft::handle_t handle; + cudaStream_t stream; + + rmm::device_uvector data; +}; + +using ScaledBernoulliTest1 = ScaledBernoulliTest; TEST_F(ScaledBernoulliTest1, RangeCheck) { rangeCheck(); } -typedef ScaledBernoulliTest ScaledBernoulliTest2; +using ScaledBernoulliMdspanTest1 = ScaledBernoulliMdspanTest; +TEST_F(ScaledBernoulliMdspanTest1, RangeCheck) { rangeCheck(); } + +using ScaledBernoulliTest2 = ScaledBernoulliTest; TEST_F(ScaledBernoulliTest2, RangeCheck) { rangeCheck(); } +using ScaledBernoulliMdspanTest2 = ScaledBernoulliMdspanTest; +TEST_F(ScaledBernoulliMdspanTest2, RangeCheck) { rangeCheck(); } + template class BernoulliTest : public ::testing::Test { public: From d6889529c977228de2dd27e98ff76ab7fd63aef2 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 5 Oct 2022 16:50:53 -0400 Subject: [PATCH 7/9] Small doxygen fix in rng.cuh --- cpp/include/raft/random/rng.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 55b60a5c91..0a4eff72d4 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -211,7 +211,7 @@ void normalInt(const raft::handle_t& handle, * @param[in] sigma_vec Either the standard-deviation vector * (of length `out.extent(1)`) of each component, * or a scalar standard deviation for all components. - * @param[out] ptr the output table + * @param[out] out the output table */ template void normalTable( From 2e9d2b491fb75dd289733a1320551fc8381ed59c Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 5 Oct 2022 16:02:34 -0700 Subject: [PATCH 8/9] Add an mdspan overload of raft::random::normalInt This touches a test file changed by PR #802, but this commit should not conflict with that PR. --- cpp/include/raft/random/rng.cuh | 29 +++++++++++++++++++++ cpp/test/stats/histogram.cu | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 0a4eff72d4..41617be26a 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -173,6 +173,35 @@ void normal(const raft::handle_t& handle, /** * @brief Generate normal distributed integers * + * @tparam OutputValueType Integral type; value type of the output vector + * @tparam IndexType Integral type of the output vector's length + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] mu mean of the distribution + * @param[in] sigma standard deviation of the distribution + */ +template +void normalInt(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType sigma) +{ + static_assert( + std::is_same::type>::value, + "normalInt: The output vector must be a view of nonconst, " + "so that we can write to it."); + static_assert(std::is_integral::value, + "normalInt: The output vector's value type must be an integer."); + + detail::normalInt(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `normalInt` + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index 58a3f5eaeb..537bde2272 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -95,6 +95,43 @@ class HistTest : public ::testing::TestWithParam { rmm::device_uvector in, bins, ref_bins; }; +class HistMdspanTest : public ::testing::TestWithParam { + protected: + HistMdspanTest() + : in(0, handle.get_stream()), bins(0, handle.get_stream()), ref_bins(0, handle.get_stream()) + { + } + + void SetUp() override + { + params = ::testing::TestWithParam::GetParam(); + raft::random::RngState r(params.seed); + auto stream = handle.get_stream(); + int len = params.nrows * params.ncols; + in.resize(len, stream); + + raft::device_vector_view in_view(in.data(), in.size()); + if (params.isNormal) { + normalInt(handle, r, in_view, params.start, params.end); + } else { + uniformInt(handle, r, in_view, params.start, params.end); + } + bins.resize(params.nbins * params.ncols, stream); + ref_bins.resize(params.nbins * params.ncols, stream); + RAFT_CUDA_TRY( + cudaMemsetAsync(ref_bins.data(), 0, sizeof(int) * params.nbins * params.ncols, stream)); + naiveHist(ref_bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); + histogram( + params.type, bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); + handle.sync_stream(); + } + + protected: + raft::handle_t handle; + HistInputs params; + rmm::device_uvector in, bins, ref_bins; +}; + static const int oneK = 1024; static const int oneM = oneK * oneK; const std::vector inputs = { @@ -252,6 +289,7 @@ const std::vector inputs = { {oneM + 2, 21, 2 * oneK, false, HistTypeAuto, 0, 2 * oneK, 1234ULL}, {oneM + 2, 21, 2 * oneK, true, HistTypeAuto, 1000, 50, 1234ULL}, }; + TEST_P(HistTest, Result) { ASSERT_TRUE(raft::devArrMatch( @@ -259,5 +297,12 @@ TEST_P(HistTest, Result) } INSTANTIATE_TEST_CASE_P(HistTests, HistTest, ::testing::ValuesIn(inputs)); +TEST_P(HistMdspanTest, Result) +{ + ASSERT_TRUE(raft::devArrMatch( + ref_bins.data(), bins.data(), params.nbins * params.ncols, raft::Compare())); +} +INSTANTIATE_TEST_CASE_P(HistMdspanTests, HistMdspanTest, ::testing::ValuesIn(inputs)); + } // end namespace stats } // end namespace raft From 4743341f3bd0980f50d39d0653f3eb7e7cf282e2 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 6 Oct 2022 06:25:14 -0400 Subject: [PATCH 9/9] Proper doxygen --- cpp/include/raft/random/rng.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 41617be26a..8ea985b559 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -237,7 +237,7 @@ void normalInt(const raft::handle_t& handle, * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state * @param[in] mu_vec mean vector (of length `out.extent(1)`) - * @param[in] sigma_vec Either the standard-deviation vector + * @param[in] sigma Either the standard-deviation vector * (of length `out.extent(1)`) of each component, * or a scalar standard deviation for all components. * @param[out] out the output table