diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 85d9abe263..ba6254bfc3 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -19,7 +19,11 @@ #include "detail/rng_impl.cuh" #include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed) #include "rng_state.hpp" +#include +#include +#include #include +#include namespace raft::random { @@ -340,20 +344,143 @@ void laplace(const raft::handle_t& handle, } /** - * @brief Sample the input array without replacement, optionally based on the - * input weight vector for each element in the array + * @brief Sample the input vector without replacement, optionally based on the + * input weight vector for each element in the array. * - * Implementation here is based on the `one-pass sampling` algo described here: - * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * The implementation is based on the `one-pass sampling` algorithm described in + * ["Accelerating weighted random sampling without + * replacement,"](https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf) + * a technical report by Kirill Mueller. * - * @note In the sampled array the elements which are picked will always appear - * in the increasing order of their weights as computed using the exponential - * distribution. So, if you're particular about the order (for eg. array - * permutations), then this might not be the right choice! + * If no input weight vector is provided, then input elements will be + * sampled uniformly. Otherwise, the elements sampled from the input + * vector will always appear in increasing order of their weights as + * computed using the exponential distribution. So, if you are + * particular about the order (for e.g., array permutations), then + * this might not be the right choice. + * + * @tparam DataT type of each element of the input array @c in + * @tparam IdxT type of the dimensions of the arrays; output index type + * @tparam WeightsT type of each elements of the weights array @c wts + * + * @note Please do not specify template parameters explicitly, + * as the compiler can deduce them from the arguments. + * + * @param[in] handle RAFT handle containing (among other resources) + * the CUDA stream on which to run. + * @param[inout] rng_state Pseudorandom number generator state. + * @param[in] in Input vector to be sampled. + * @param[in] wts Optional weights vector. + * If not provided, uniform sampling will be used. + * @param[out] out Vector of samples from the input vector. + * @param[out] outIdx If provided, vector of the indices + * sampled from the input array. + * + * @pre The number of samples `out.extent(0)` + * is less than or equal to the number of inputs `in.extent(0)`. + * + * @pre The number of weights `wts.extent(0)` + * equals the number of inputs `in.extent(0)`. + */ +template +void sample_without_replacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view in, + std::optional> wts, + raft::device_vector_view out, + std::optional> outIdx) +{ + static_assert(std::is_integral::value, "IdxT must be an integral type."); + const IdxT sampledLen = out.extent(0); + const IdxT len = in.extent(0); + RAFT_EXPECTS(sampledLen <= len, + "sampleWithoutReplacement: " + "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); + RAFT_EXPECTS(len == 0 || in.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If in.extents(0) != 0, then in.data_handle() must be nonnull"); + RAFT_EXPECTS(sampledLen == 0 || out.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If out.extents(0) != 0, then out.data_handle() must be nonnull"); + + const bool outIdx_has_value = outIdx.has_value(); + if (outIdx_has_value) { + RAFT_EXPECTS((*outIdx).extent(0) == sampledLen, + "sampleWithoutReplacement: " + "If outIdx is provided, its extent(0) must equal out.extent(0)"); + } + IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr; + + const bool wts_has_value = wts.has_value(); + if (wts_has_value) { + RAFT_EXPECTS((*wts).extent(0) == len, + "sampleWithoutReplacement: " + "If wts is provided, its extent(0) must equal in.extent(0)"); + } + const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr; + + detail::sampleWithoutReplacement(rng_state, + out.data_handle(), + outIdx_ptr, + in.data_handle(), + wts_ptr, + sampledLen, + len, + handle.get_stream()); +} + +namespace sample_without_replacement_impl { +template +struct weight_alias { +}; + +template <> +struct weight_alias { + using type = double; +}; + +template +struct weight_alias>> { + using type = typename raft::device_vector_view::value_type; +}; + +template +using weight_t = typename weight_alias::type; +} // namespace sample_without_replacement_impl + +/** + * @brief Overload of `sample_without_replacement` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `sample_without_replacement`. + */ +template +void sample_without_replacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view in, + WeightsVectorType&& wts, + raft::device_vector_view out, + OutIndexVectorType&& outIdx) +{ + using weight_type = sample_without_replacement_impl::weight_t< + std::remove_const_t>>; + std::optional> weights = + std::forward(wts); + std::optional> output_indices = + std::forward(outIdx); + + sample_without_replacement(handle, rng_state, in, weights, out, output_indices); +} + +/** + * @brief Legacy version of @c sample_without_replacement (see above) + * that takes raw arrays instead of device mdspan. * * @tparam DataT data type * @tparam WeightsT weights type * @tparam IdxT index type + * * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state * @param[out] out output sampled array (of length 'sampledLen') diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 653a9f9bc9..482b35168a 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -70,7 +70,7 @@ class SWoRTest : public ::testing::TestWithParam> { } sampleWithoutReplacement( handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len); - update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream); + update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); handle.sync_stream(stream); } @@ -84,7 +84,75 @@ class SWoRTest : public ::testing::TestWithParam> { std::vector h_outIdx; }; -typedef SWoRTest SWoRTestF; +template +class SWoRMdspanTest : public ::testing::TestWithParam> { + public: + SWoRMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + in(params.len, stream), + wts(params.len, stream), + out(params.sampledLen, stream), + out2(params.sampledLen, stream), + outIdx(params.sampledLen, stream), + outIdx2(params.sampledLen, stream) + { + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + h_outIdx.resize(params.sampledLen); + uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0)); + uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0)); + if (params.largeWeightIndex >= 0) { + update_device(wts.data() + params.largeWeightIndex, ¶ms.largeWeight, 1, stream); + } + { + using index_type = int; + using output_view = raft::device_vector_view; + output_view out_view{out.data(), out.size()}; + ASSERT_TRUE(out_view.extent(0) == params.sampledLen); + + using output_idxs_view = raft::device_vector_view; + std::optional outIdx_view{std::in_place, outIdx.data(), outIdx.size()}; + ASSERT_TRUE(outIdx_view.value().extent(0) == params.sampledLen); + + using input_view = raft::device_vector_view; + input_view in_view{in.data(), in.size()}; + ASSERT_TRUE(in_view.extent(0) == params.len); + + using weights_view = raft::device_vector_view; + std::optional wts_view{std::in_place, wts.data(), wts.size()}; + ASSERT_TRUE(wts_view.value().extent(0) == params.len); + + sample_without_replacement(handle, r, in_view, wts_view, out_view, outIdx_view); + + output_view out2_view{out2.data(), out2.size()}; + ASSERT_TRUE(out2_view.extent(0) == params.sampledLen); + std::optional outIdx2_view{std::in_place, outIdx2.data(), outIdx2.size()}; + ASSERT_TRUE(outIdx2_view.value().extent(0) == params.sampledLen); + + // For now, just test that these calls compile. + sample_without_replacement(handle, r, in_view, wts_view, out2_view, std::nullopt); + sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, outIdx2_view); + sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, std::nullopt); + } + update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); + handle.sync_stream(stream); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + SWoRInputs params; + rmm::device_uvector in, out, wts, out2; + rmm::device_uvector outIdx, outIdx2; + std::vector h_outIdx; +}; + const std::vector> inputsf = {{1024, 512, -1, 0.f, GenPhilox, 1234ULL}, {1024, 1024, -1, 0.f, GenPhilox, 1234ULL}, {1024, 512 + 1, -1, 0.f, GenPhilox, 1234ULL}, @@ -125,26 +193,37 @@ const std::vector> inputsf = {{1024, 512, -1, 0.f, GenPhilox, {1024 + 2, 1024 + 2, -1, 0.f, GenPC, 1234ULL}, {1024, 512, 10, 100000.f, GenPC, 1234ULL}}; -TEST_P(SWoRTestF, Result) -{ - std::set occurence; - for (int i = 0; i < params.sampledLen; ++i) { - auto val = h_outIdx[i]; - // indices must be in the given range - ASSERT_TRUE(0 <= val && val < params.len) - << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; - // indices should not repeat - ASSERT_TRUE(occurence.find(val) == occurence.end()) - << "repeated index @i=" << i << " idx=" << val; - occurence.insert(val); - } - // if there's a skewed distribution, the top index should correspond to the - // particular item with a large weight - if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } -} +// This needs to be a macro because it has to live in the scope +// of the class whose name is the first parameter of TEST_P. +// +// We test the following. +// +// 1. Output indices are in the given range. +// 2. Output indices do not repeat. +// 3. If there's a skewed distribution, the top index should +// correspond to the particular item with a large weight. +#define _RAFT_SWOR_TEST_CONTENTS() \ + do { \ + std::set occurrence; \ + for (int i = 0; i < params.sampledLen; ++i) { \ + auto val = h_outIdx[i]; \ + ASSERT_TRUE(0 <= val && val < params.len) \ + << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; \ + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) \ + << "repeated index @i=" << i << " idx=" << val; \ + occurrence.insert(val); \ + } \ + if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } \ + } while (false) + +using SWoRTestF = SWoRTest; +TEST_P(SWoRTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); } INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestF, ::testing::ValuesIn(inputsf)); -typedef SWoRTest SWoRTestD; +using SWoRMdspanTestF = SWoRMdspanTest; +TEST_P(SWoRMdspanTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); } +INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = {{1024, 512, -1, 0.0, GenPhilox, 1234ULL}, {1024, 1024, -1, 0.0, GenPhilox, 1234ULL}, {1024, 512 + 1, -1, 0.0, GenPhilox, 1234ULL}, @@ -185,24 +264,13 @@ const std::vector> inputsd = {{1024, 512, -1, 0.0, GenPhilox, {1024 + 2, 1024 + 2, -1, 0.0, GenPC, 1234ULL}, {1024, 512, 10, 100000.0, GenPC, 1234ULL}}; -TEST_P(SWoRTestD, Result) -{ - std::set occurence; - for (int i = 0; i < params.sampledLen; ++i) { - auto val = h_outIdx[i]; - // indices must be in the given range - ASSERT_TRUE(0 <= val && val < params.len) - << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; - // indices should not repeat - ASSERT_TRUE(occurence.find(val) == occurence.end()) - << "repeated index @i=" << i << " idx=" << val; - occurence.insert(val); - } - // if there's a skewed distribution, the top index should correspond to the - // particular item with a large weight - if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } -} +using SWoRTestD = SWoRTest; +TEST_P(SWoRTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); } INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestD, ::testing::ValuesIn(inputsd)); +using SWoRMdspanTestD = SWoRMdspanTest; +TEST_P(SWoRMdspanTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); } +INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestD, ::testing::ValuesIn(inputsd)); + } // namespace random } // namespace raft