diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 106881fa1a..8ea985b559 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -24,6 +24,7 @@ #include #include #include +#include namespace raft::random { @@ -75,6 +76,34 @@ void uniform(const raft::handle_t& handle, /** * @brief Generate uniformly distributed integers in the given range * + * @tparam OutputValueType Integral type; value type of the output vector + * @tparam IndexType Type used to represent length of the output vector + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output vector of random numbers + * @param[in] start start of the range + * @param[in] end end of the range + */ +template +void uniformInt(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType start, + OutputValueType end) +{ + static_assert( + std::is_same::type>::value, + "uniformInt: The output vector must be a view of nonconst, " + "so that we can write to it."); + static_assert(std::is_integral::value, + "uniformInt: The elements of the output vector must have integral type."); + detail::uniformInt(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `uniformInt` + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -144,6 +173,35 @@ void normal(const raft::handle_t& handle, /** * @brief Generate normal distributed integers * + * @tparam OutputValueType Integral type; value type of the output vector + * @tparam IndexType Integral type of the output vector's length + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] mu mean of the distribution + * @param[in] sigma standard deviation of the distribution + */ +template +void normalInt(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType sigma) +{ + static_assert( + std::is_same::type>::value, + "normalInt: The output vector must be a view of nonconst, " + "so that we can write to it."); + static_assert(std::is_integral::value, + "normalInt: The output vector's value type must be an integer."); + + detail::normalInt(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `normalInt` + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -170,7 +228,70 @@ void normalInt(const raft::handle_t& handle, * * Each row in this table conforms to a normally distributed n-dim vector * whose mean is the input vector and standard deviation is the corresponding - * vector or scalar. Correlations among the dimensions itself is assumed to + * vector or scalar. Correlations among the dimensions itself are assumed to + * be absent. + * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[in] mu_vec mean vector (of length `out.extent(1)`) + * @param[in] sigma Either the standard-deviation vector + * (of length `out.extent(1)`) of each component, + * or a scalar standard deviation for all components. + * @param[out] out the output table + */ +template +void normalTable( + const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view mu_vec, + std::variant, OutputValueType> sigma, + raft::device_matrix_view out) +{ + const OutputValueType* sigma_vec_ptr = nullptr; + OutputValueType sigma_value{}; + + using sigma_vec_type = raft::device_vector_view; + if (std::holds_alternative(sigma)) { + auto sigma_vec = std::get(sigma); + RAFT_EXPECTS(sigma_vec.extent(0) == out.extent(1), + "normalTable: The sigma vector " + "has length %zu, which does not equal the number of columns " + "in the output table %zu.", + static_cast(sigma_vec.extent(0)), + static_cast(out.extent(1))); + // The extra length check makes this work even if sigma_vec views a std::vector, + // where .data() need not return nullptr even if .size() is zero. + sigma_vec_ptr = sigma_vec.extent(0) == 0 ? nullptr : sigma_vec.data_handle(); + } else { + sigma_value = std::get(sigma); + } + + RAFT_EXPECTS(mu_vec.extent(0) == out.extent(1), + "normalTable: The mu vector " + "has length %zu, which does not equal the number of columns " + "in the output table %zu.", + static_cast(mu_vec.extent(0)), + static_cast(out.extent(1))); + + detail::normalTable(rng_state, + out.data_handle(), + out.extent(0), + out.extent(1), + mu_vec.data_handle(), + sigma_vec_ptr, + sigma_value, + handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `normalTable`. + * + * Each row in this table conforms to a normally distributed n-dim vector + * whose mean is the input vector and standard deviation is the corresponding + * vector or scalar. Correlations among the dimensions itself are assumed to * be absent. * * @tparam OutType data type of output random number @@ -200,7 +321,27 @@ void normalTable(const raft::handle_t& handle, } /** - * @brief Fill an array with the given value + * @brief Fill a vector with the given value + * + * @tparam OutputValueType Value type of the output vector + * @tparam IndexType Integral type used to represent length of the output vector + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[in] val value with which to fill the output vector + * @param[out] out the output vector + */ +template +void fill(const raft::handle_t& handle, + RngState& rng_state, + OutputValueType val, + raft::device_vector_view out) +{ + detail::fill(rng_state, out.data_handle(), out.extent(0), val, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `fill` * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays @@ -219,6 +360,28 @@ void fill(const raft::handle_t& handle, RngState& rng_state, OutType* ptr, LenTy /** * @brief Generate bernoulli distributed boolean array * + * @tparam OutputValueType Type of each element of the output vector; + * must be able to represent boolean values (e.g., `bool`) + * @tparam IndexType Integral type of the output vector's length + * @tparam Type Data type in which to compute the probabilities + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output vector + * @param[in] prob coin-toss probability for heads + */ +template +void bernoulli(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + Type prob) +{ + detail::bernoulli(rng_state, out.data_handle(), out.extent(0), prob, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `bernoulli` + * * @tparam Type data type in which to compute the probabilities * @tparam OutType output data type * @tparam LenType data type used to represent length of the arrays @@ -239,6 +402,29 @@ void bernoulli( /** * @brief Generate bernoulli distributed array and applies scale * + * @tparam OutputValueType Data type in which to compute the probabilities + * @tparam IndexType Integral type of the output vector's length + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output vector + * @param[in] prob coin-toss probability for heads + * @param[in] scale scaling factor + */ +template +void scaled_bernoulli(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType prob, + OutputValueType scale) +{ + detail::scaled_bernoulli( + rng_state, out.data_handle(), out.extent(0), prob, scale, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `scaled_bernoulli` + * * @tparam OutType data type in which to compute the probabilities * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 8b32742f34..82f6e0e247 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -464,12 +464,47 @@ class ScaledBernoulliTest : public ::testing::Test { rmm::device_uvector data; }; -typedef ScaledBernoulliTest ScaledBernoulliTest1; +template +class ScaledBernoulliMdspanTest : public ::testing::Test { + public: + ScaledBernoulliMdspanTest() : stream(handle.get_stream()), data(len, stream) {} + + protected: + void SetUp() override + { + RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + RngState r(42); + + raft::device_vector_view data_view(data.data(), data.size()); + scaled_bernoulli(handle, r, data_view, T(0.5), T(scale)); + } + + void rangeCheck() + { + auto h_data = std::make_unique(len); + update_host(h_data.get(), data.data(), len, stream); + ASSERT_TRUE(std::none_of( + h_data.get(), h_data.get() + len, [](const T& a) { return a < -scale || a > scale; })); + } + + raft::handle_t handle; + cudaStream_t stream; + + rmm::device_uvector data; +}; + +using ScaledBernoulliTest1 = ScaledBernoulliTest; TEST_F(ScaledBernoulliTest1, RangeCheck) { rangeCheck(); } -typedef ScaledBernoulliTest ScaledBernoulliTest2; +using ScaledBernoulliMdspanTest1 = ScaledBernoulliMdspanTest; +TEST_F(ScaledBernoulliMdspanTest1, RangeCheck) { rangeCheck(); } + +using ScaledBernoulliTest2 = ScaledBernoulliTest; TEST_F(ScaledBernoulliTest2, RangeCheck) { rangeCheck(); } +using ScaledBernoulliMdspanTest2 = ScaledBernoulliMdspanTest; +TEST_F(ScaledBernoulliMdspanTest2, RangeCheck) { rangeCheck(); } + template class BernoulliTest : public ::testing::Test { public: @@ -499,12 +534,49 @@ class BernoulliTest : public ::testing::Test { rmm::device_uvector data; }; -typedef BernoulliTest BernoulliTest1; +template +class BernoulliMdspanTest : public ::testing::Test { + public: + BernoulliMdspanTest() : stream(handle.get_stream()), data(len, stream) {} + + protected: + void SetUp() override + { + RngState r(42); + + raft::device_vector_view data_view(data.data(), data.size()); + + bernoulli(handle, r, data_view, T(0.5)); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void trueFalseCheck() + { + // both true and false values must be present + auto h_data = std::make_unique(len); + update_host(h_data.get(), data.data(), len, stream); + ASSERT_TRUE(std::any_of(h_data.get(), h_data.get() + len, [](bool a) { return a; })); + ASSERT_TRUE(std::any_of(h_data.get(), h_data.get() + len, [](bool a) { return !a; })); + } + + raft::handle_t handle; + cudaStream_t stream; + + rmm::device_uvector data; +}; + +using BernoulliTest1 = BernoulliTest; TEST_F(BernoulliTest1, TrueFalseCheck) { trueFalseCheck(); } -typedef BernoulliTest BernoulliTest2; +using BernoulliMdspanTest1 = BernoulliMdspanTest; +TEST_F(BernoulliMdspanTest1, TrueFalseCheck) { trueFalseCheck(); } + +using BernoulliTest2 = BernoulliTest; TEST_F(BernoulliTest2, TrueFalseCheck) { trueFalseCheck(); } +using BernoulliMdspanTest2 = BernoulliMdspanTest; +TEST_F(BernoulliMdspanTest2, TrueFalseCheck) { trueFalseCheck(); } + /** Rng::normalTable tests */ template struct RngNormalTableInputs { @@ -572,7 +644,62 @@ class RngNormalTableTest : public ::testing::TestWithParam RngNormalTableTestF; +template +class RngNormalTableMdspanTest : public ::testing::TestWithParam> { + public: + RngNormalTableMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(params.rows * params.cols, stream), + stats(2, stream), + mu_vec(params.cols, stream) + { + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, 2 * sizeof(T), stream)); + } + + protected: + void SetUp() override + { + // Tests are configured with their expected test-values sigma. For example, + // 4 x sigma indicates the test shouldn't fail 99.9% of the time. + num_sigma = 10; + int len = params.rows * params.cols; + RngState r(params.seed, params.gtype); + + raft::device_matrix_view data_view( + data.data(), params.rows, params.cols); + raft::device_vector_view mu_vec_view(mu_vec.data(), params.cols); + raft::device_vector_view mu_vec_nc_view(mu_vec.data(), params.cols); + std::variant, T> sigma_var(params.sigma); + + fill(handle, r, params.mu, mu_vec_nc_view); + normalTable(handle, r, mu_vec_view, sigma_var, data_view); + static const int threads = 128; + meanKernel + <<>>(stats.data(), data.data(), len); + update_host(h_stats, stats.data(), 2, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h_stats[0] /= len; + h_stats[1] = (h_stats[1] / len) - (h_stats[0] * h_stats[0]); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void getExpectedMeanVar(T meanvar[2]) + { + meanvar[0] = params.mu; + meanvar[1] = params.sigma * params.sigma; + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngNormalTableInputs params; + rmm::device_uvector data, stats, mu_vec; + T h_stats[2]; // mean, var + int num_sigma; +}; + const std::vector> inputsf_t = { {0.0055, 32, 1024, 1.f, 1.f, GenPhilox, 1234ULL}, {0.011, 8, 1024, 1.f, 1.f, GenPhilox, 1234ULL}, @@ -580,6 +707,7 @@ const std::vector> inputsf_t = { {0.0055, 32, 1024, 1.f, 1.f, GenPC, 1234ULL}, {0.011, 8, 1024, 1.f, 1.f, GenPC, 1234ULL}}; +using RngNormalTableTestF = RngNormalTableTest; TEST_P(RngNormalTableTestF, Result) { float meanvar[2]; @@ -589,13 +717,26 @@ TEST_P(RngNormalTableTestF, Result) } INSTANTIATE_TEST_SUITE_P(RngNormalTableTests, RngNormalTableTestF, ::testing::ValuesIn(inputsf_t)); -typedef RngNormalTableTest RngNormalTableTestD; +using RngNormalTableMdspanTestF = RngNormalTableMdspanTest; +TEST_P(RngNormalTableMdspanTestF, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, + RngNormalTableMdspanTestF, + ::testing::ValuesIn(inputsf_t)); + const std::vector> inputsd_t = { {0.0055, 32, 1024, 1.0, 1.0, GenPhilox, 1234ULL}, {0.011, 8, 1024, 1.0, 1.0, GenPhilox, 1234ULL}, {0.0055, 32, 1024, 1.0, 1.0, GenPC, 1234ULL}, {0.011, 8, 1024, 1.0, 1.0, GenPC, 1234ULL}}; + +using RngNormalTableTestD = RngNormalTableTest; TEST_P(RngNormalTableTestD, Result) { double meanvar[2]; @@ -605,6 +746,18 @@ TEST_P(RngNormalTableTestD, Result) } INSTANTIATE_TEST_SUITE_P(RngNormalTableTests, RngNormalTableTestD, ::testing::ValuesIn(inputsd_t)); +using RngNormalTableMdspanTestD = RngNormalTableMdspanTest; +TEST_P(RngNormalTableMdspanTestD, Result) +{ + double meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(num_sigma * params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(num_sigma * params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngNormalTableMdspanTests, + RngNormalTableMdspanTestD, + ::testing::ValuesIn(inputsd_t)); + struct RngAffineInputs { int n; unsigned long long seed; diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu index 8efd9cd5af..d5270c456e 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -121,13 +121,67 @@ class RngTest : public ::testing::TestWithParam> { float h_stats[2]; // mean, var }; -typedef RngTest RngTestU32; +template +class RngMdspanTest : public ::testing::TestWithParam> { + public: + RngMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(0, stream), + stats(2, stream) + { + data.resize(params.len, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, 2 * sizeof(float), stream)); + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + raft::device_vector_view data_view(data.data(), data.size()); + + switch (params.type) { + case RNG_Uniform: uniformInt(handle, r, data_view, params.start, params.end); break; + }; + static const int threads = 128; + meanKernel<<>>( + stats.data(), data.data(), params.len); + update_host(h_stats, stats.data(), 2, stream); + handle.sync_stream(stream); + h_stats[0] /= params.len; + h_stats[1] = (h_stats[1] / params.len) - (h_stats[0] * h_stats[0]); + handle.sync_stream(stream); + } + + void getExpectedMeanVar(float meanvar[2]) + { + switch (params.type) { + case RNG_Uniform: + meanvar[0] = (params.start + params.end) * 0.5f; + meanvar[1] = params.end - params.start; + meanvar[1] = meanvar[1] * meanvar[1] / 12.f; + break; + }; + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngInputs params; + rmm::device_uvector data; + rmm::device_uvector stats; + float h_stats[2]; // mean, var +}; + const std::vector> inputs_u32 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestU32 = RngTest; TEST_P(RngTestU32, Result) { float meanvar[2]; @@ -137,13 +191,24 @@ TEST_P(RngTestU32, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestU32, ::testing::ValuesIn(inputs_u32)); -typedef RngTest RngTestU64; +using RngMdspanTestU32 = RngMdspanTest; +TEST_P(RngMdspanTestU32, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestU32, ::testing::ValuesIn(inputs_u32)); + const std::vector> inputs_u64 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestU64 = RngTest; TEST_P(RngTestU64, Result) { float meanvar[2]; @@ -153,13 +218,24 @@ TEST_P(RngTestU64, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestU64, ::testing::ValuesIn(inputs_u64)); -typedef RngTest RngTestS32; +using RngMdspanTestU64 = RngMdspanTest; +TEST_P(RngMdspanTestU64, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestU64, ::testing::ValuesIn(inputs_u64)); + const std::vector> inputs_s32 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestS32 = RngTest; TEST_P(RngTestS32, Result) { float meanvar[2]; @@ -169,13 +245,24 @@ TEST_P(RngTestS32, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestS32, ::testing::ValuesIn(inputs_s32)); -typedef RngTest RngTestS64; +using RngMdspanTestS32 = RngMdspanTest; +TEST_P(RngMdspanTestS32, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestS32, ::testing::ValuesIn(inputs_s32)); + const std::vector> inputs_s64 = { {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPhilox, 1234ULL}, {0.1f, 32 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}, {0.1f, 8 * 1024, 0, 20, RNG_Uniform, GenPC, 1234ULL}}; + +using RngTestS64 = RngTest; TEST_P(RngTestS64, Result) { float meanvar[2]; @@ -185,5 +272,15 @@ TEST_P(RngTestS64, Result) } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestS64, ::testing::ValuesIn(inputs_s64)); +using RngMdspanTestS64 = RngMdspanTest; +TEST_P(RngMdspanTestS64, Result) +{ + float meanvar[2]; + getExpectedMeanVar(meanvar); + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(params.tolerance))); + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestS64, ::testing::ValuesIn(inputs_s64)); + } // namespace random } // namespace raft diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index 58a3f5eaeb..537bde2272 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -95,6 +95,43 @@ class HistTest : public ::testing::TestWithParam { rmm::device_uvector in, bins, ref_bins; }; +class HistMdspanTest : public ::testing::TestWithParam { + protected: + HistMdspanTest() + : in(0, handle.get_stream()), bins(0, handle.get_stream()), ref_bins(0, handle.get_stream()) + { + } + + void SetUp() override + { + params = ::testing::TestWithParam::GetParam(); + raft::random::RngState r(params.seed); + auto stream = handle.get_stream(); + int len = params.nrows * params.ncols; + in.resize(len, stream); + + raft::device_vector_view in_view(in.data(), in.size()); + if (params.isNormal) { + normalInt(handle, r, in_view, params.start, params.end); + } else { + uniformInt(handle, r, in_view, params.start, params.end); + } + bins.resize(params.nbins * params.ncols, stream); + ref_bins.resize(params.nbins * params.ncols, stream); + RAFT_CUDA_TRY( + cudaMemsetAsync(ref_bins.data(), 0, sizeof(int) * params.nbins * params.ncols, stream)); + naiveHist(ref_bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); + histogram( + params.type, bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); + handle.sync_stream(); + } + + protected: + raft::handle_t handle; + HistInputs params; + rmm::device_uvector in, bins, ref_bins; +}; + static const int oneK = 1024; static const int oneM = oneK * oneK; const std::vector inputs = { @@ -252,6 +289,7 @@ const std::vector inputs = { {oneM + 2, 21, 2 * oneK, false, HistTypeAuto, 0, 2 * oneK, 1234ULL}, {oneM + 2, 21, 2 * oneK, true, HistTypeAuto, 1000, 50, 1234ULL}, }; + TEST_P(HistTest, Result) { ASSERT_TRUE(raft::devArrMatch( @@ -259,5 +297,12 @@ TEST_P(HistTest, Result) } INSTANTIATE_TEST_CASE_P(HistTests, HistTest, ::testing::ValuesIn(inputs)); +TEST_P(HistMdspanTest, Result) +{ + ASSERT_TRUE(raft::devArrMatch( + ref_bins.data(), bins.data(), params.nbins * params.ncols, raft::Compare())); +} +INSTANTIATE_TEST_CASE_P(HistMdspanTests, HistMdspanTest, ::testing::ValuesIn(inputs)); + } // end namespace stats } // end namespace raft