From 3ae4e34462dbd792a36bad22c065fa52c39e8eef Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Thu, 8 Sep 2022 09:42:00 -0700 Subject: [PATCH] sampleWithoutReplacement: Add overload taking mdspans --- cpp/include/raft/random/rng.cuh | 71 ++++++++++ cpp/test/random/sample_without_replacement.cu | 127 +++++++++++++----- 2 files changed, 168 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 85d9abe263..fc30925651 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -20,6 +20,10 @@ #include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed) #include "rng_state.hpp" #include +#include +#include +#include +#include namespace raft::random { @@ -379,6 +383,72 @@ void sampleWithoutReplacement(const raft::handle_t& handle, rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); } +/** + * @brief Sample the input array without replacement, optionally based on the + * input weight vector for each element in the array + * + * Implementation here is based on the `one-pass sampling` algo described here: + * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * + * @note In the sampled array the elements which are picked will always appear + * in the increasing order of their weights as computed using the exponential + * distribution. So, if you're particular about the order (for eg. array + * permutations), then this might not be the right choice! + * + * @tparam DataT data type (do not specify this explicitly, + * as the compiler will deduce it from @c out ) + * @tparam IdxT index type (do not specify this explicitly, + * as the compiler will deduce it from the arguments) + * @tparam WeightsT weights type (defaults to @c double , + * so that passing in std::nullopt works) + * + * @param handle + * @param rng_state random number generator state + * @param out Array (of length 'sampledLen') + * of samples from the input array + * @param outIdx If provided, array (of length 'sampledLen') + * of the indices sampled from the input array + * @param in Input array to be sampled (of length 'len') + * @param wts Weights array (of length 'len'). + * If not provided, uniform sampling will be used. + */ +template +void sampleWithoutReplacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + std::optional> outIdx, + raft::device_vector_view in, + std::optional> wts) +{ + static_assert(std::is_integral::value, "IdxT must be an integral type."); + const IdxT sampledLen = out.extent(0); + const IdxT len = in.extent(0); + ASSERT(sampledLen <= len, "sampleWithoutReplacement: " + "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); + ASSERT(len == 0 || in.data_handle() != nullptr, "sampleWithoutReplacement: " + "If in.data_handle() is not null, then in.extent(0) must be nonzero"); + ASSERT(sampledLen == 0 || out.data_handle() != nullptr, "sampleWithoutReplacement: " + "If out.data_handle() is not null, then out.extent(0) must be nonzero"); + + const bool outIdx_has_value = outIdx.has_value(); + if (outIdx_has_value) { + ASSERT((*outIdx).extent(0) == sampledLen, "sampleWithoutReplacement: " + "If outIdx is provided, its extent(0) must equal out.extent(0)"); + } + IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr; + + const bool wts_has_value = wts.has_value(); + if (wts_has_value) { + ASSERT((*wts).extent(0) == len, "sampleWithoutReplacement: " + "If wts is provided, its extent(0) must equal in.extent(0)"); + } + const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr; + + detail::sampleWithoutReplacement( + rng_state, out.data_handle(), outIdx_ptr, in.data_handle(), wts_ptr, + sampledLen, len, handle.get_stream()); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` @@ -714,6 +784,7 @@ class DEPR Rng : public detail::RngImpl { detail::RngImpl::sampleWithoutReplacement( handle, out, outIdx, in, wts, sampledLen, len, stream); } + }; #undef DEPR diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 653a9f9bc9..51e7ad2594 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -84,7 +84,63 @@ class SWoRTest : public ::testing::TestWithParam> { std::vector h_outIdx; }; -typedef SWoRTest SWoRTestF; +template +class SWoRMdspanTest : public ::testing::TestWithParam> { + public: + SWoRMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + in(params.len, stream), + wts(params.len, stream), + out(params.sampledLen, stream), + outIdx(params.sampledLen, stream) + { + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + h_outIdx.resize(params.sampledLen); + uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0)); + uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0)); + if (params.largeWeightIndex >= 0) { + update_device(wts.data() + params.largeWeightIndex, ¶ms.largeWeight, 1, stream); + } + { + using index_type = int; + using output_view = raft::device_vector_view; + output_view out_view{out.data(), out.size()}; + ASSERT_TRUE(out_view.extent(0) == params.sampledLen); + + using output_idxs_view = raft::device_vector_view; + std::optional outIdx_view{std::in_place, outIdx.data(), outIdx.size()}; + ASSERT_TRUE(outIdx_view.has_value() && outIdx_view.value().extent(0) == params.sampledLen); + + using input_view = raft::device_vector_view; + input_view in_view{in.data(), in.size()}; + ASSERT_TRUE(in_view.extent(0) == params.len); + + using weights_view = raft::device_vector_view; + std::optional wts_view{std::in_place, wts.data(), wts.size()}; + ASSERT_TRUE(wts_view.has_value() && wts_view.value().extent(0) == params.len); + + sampleWithoutReplacement(handle, r, out_view, outIdx_view, in_view, wts_view); + } + update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); + handle.sync_stream(stream); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + SWoRInputs params; + rmm::device_uvector in, out, wts; + rmm::device_uvector outIdx; + std::vector h_outIdx; +}; + const std::vector> inputsf = {{1024, 512, -1, 0.f, GenPhilox, 1234ULL}, {1024, 1024, -1, 0.f, GenPhilox, 1234ULL}, {1024, 512 + 1, -1, 0.f, GenPhilox, 1234ULL}, @@ -125,26 +181,42 @@ const std::vector> inputsf = {{1024, 512, -1, 0.f, GenPhilox, {1024 + 2, 1024 + 2, -1, 0.f, GenPC, 1234ULL}, {1024, 512, 10, 100000.f, GenPC, 1234ULL}}; +// This needs to be a macro because it has to live in the scope +// of the class whose name is the first parameter of TEST_P. +// +// We test the following. +// +// 1. Output indices are in the given range. +// 2. Output indices do not repeat. +// 3. If there's a skewed distribution, the top index should +// correspond to the particular item with a large weight. +#define _RAFT_SWOR_TEST_CONTENTS() do { \ + std::set occurrence; \ + for (int i = 0; i < params.sampledLen; ++i) { \ + auto val = h_outIdx[i]; \ + ASSERT_TRUE(0 <= val && val < params.len) \ + << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; \ + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) \ + << "repeated index @i=" << i << " idx=" << val; \ + occurrence.insert(val); \ + } \ + if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } \ +} while(false) \ + +using SWoRTestF = SWoRTest; TEST_P(SWoRTestF, Result) { - std::set occurence; - for (int i = 0; i < params.sampledLen; ++i) { - auto val = h_outIdx[i]; - // indices must be in the given range - ASSERT_TRUE(0 <= val && val < params.len) - << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; - // indices should not repeat - ASSERT_TRUE(occurence.find(val) == occurence.end()) - << "repeated index @i=" << i << " idx=" << val; - occurence.insert(val); - } - // if there's a skewed distribution, the top index should correspond to the - // particular item with a large weight - if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } + _RAFT_SWOR_TEST_CONTENTS(); } INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestF, ::testing::ValuesIn(inputsf)); -typedef SWoRTest SWoRTestD; +using SWoRMdspanTestF = SWoRMdspanTest; +TEST_P(SWoRMdspanTestF, Result) +{ + _RAFT_SWOR_TEST_CONTENTS(); +} +INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = {{1024, 512, -1, 0.0, GenPhilox, 1234ULL}, {1024, 1024, -1, 0.0, GenPhilox, 1234ULL}, {1024, 512 + 1, -1, 0.0, GenPhilox, 1234ULL}, @@ -185,24 +257,19 @@ const std::vector> inputsd = {{1024, 512, -1, 0.0, GenPhilox, {1024 + 2, 1024 + 2, -1, 0.0, GenPC, 1234ULL}, {1024, 512, 10, 100000.0, GenPC, 1234ULL}}; +using SWoRTestD = SWoRTest; TEST_P(SWoRTestD, Result) { - std::set occurence; - for (int i = 0; i < params.sampledLen; ++i) { - auto val = h_outIdx[i]; - // indices must be in the given range - ASSERT_TRUE(0 <= val && val < params.len) - << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; - // indices should not repeat - ASSERT_TRUE(occurence.find(val) == occurence.end()) - << "repeated index @i=" << i << " idx=" << val; - occurence.insert(val); - } - // if there's a skewed distribution, the top index should correspond to the - // particular item with a large weight - if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } + _RAFT_SWOR_TEST_CONTENTS(); } INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestD, ::testing::ValuesIn(inputsd)); +using SWoRMdspanTestD = SWoRMdspanTest; +TEST_P(SWoRMdspanTestD, Result) +{ + _RAFT_SWOR_TEST_CONTENTS(); +} +INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestD, ::testing::ValuesIn(inputsd)); + } // namespace random } // namespace raft