diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 3ff591627a..550b818c9a 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -344,108 +344,78 @@ void laplace(const raft::handle_t& handle, } /** - * @brief Sample the input array without replacement, optionally based on the - * input weight vector for each element in the array + * @brief Sample the input vector without replacement, optionally based on the + * input weight vector for each element in the array. * - * Implementation here is based on the `one-pass sampling` algo described here: - * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * The implementation is based on the `one-pass sampling` algorithm described in + * ["Accelerating weighted random sampling without + * replacement,"](https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf) + * a technical report by Kirill Mueller. * - * @note In the sampled array the elements which are picked will always appear - * in the increasing order of their weights as computed using the exponential - * distribution. So, if you're particular about the order (for eg. array - * permutations), then this might not be the right choice! + * If no input weight vector is provided, then input elements will be + * sampled uniformly. Otherwise, the elements sampled from the input + * vector will always appear in increasing order of their weights as + * computed using the exponential distribution. So, if you are + * particular about the order (for e.g., array permutations), then + * this might not be the right choice. * - * @tparam DataT data type - * @tparam WeightsT weights type - * @tparam IdxT index type - * @param[in] handle raft handle for resource management - * @param[in] rng_state random number generator state - * @param[out] out output sampled array (of length 'sampledLen') - * @param[out] outIdx indices of the sampled array (of length 'sampledLen'). Pass - * a nullptr if this is not required. - * @param[in] in input array to be sampled (of length 'len') - * @param[in] wts weights array (of length 'len'). Pass a nullptr if uniform - * sampling is desired - * @param[in] sampledLen output sampled array length - * @param[in] len input array length - */ -template -void sampleWithoutReplacement(const raft::handle_t& handle, - RngState& rng_state, - DataT* out, - IdxT* outIdx, - const DataT* in, - const WeightsT* wts, - IdxT sampledLen, - IdxT len) -{ - detail::sampleWithoutReplacement( - rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); -} - -/** - * @brief Sample the input array without replacement, optionally based on the - * input weight vector for each element in the array + * @tparam DataT type of each element of the input array @c in + * @tparam IdxT type of the dimensions of the arrays; output index type + * @tparam WeightsT type of each elements of the weights array @c wts * - * Implementation here is based on the `one-pass sampling` algo described here: - * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * @note Please do not specify template parameters explicitly, + * as the compiler can deduce them from the arguments. * - * @note In the sampled array the elements which are picked will always appear - * in the increasing order of their weights as computed using the exponential - * distribution. So, if you're particular about the order (for eg. array - * permutations), then this might not be the right choice! + * @param[in] handle RAFT handle containing (among other resources) + * the CUDA stream on which to run. + * @param[inout] rng_state Pseudorandom number generator state. + * @param[in] in Input vector to be sampled. + * @param[in] wts Optional weights vector. + * If not provided, uniform sampling will be used. + * @param[out] out Vector of samples from the input vector. + * @param[out] outIdx If provided, vector of the indices + * sampled from the input array. * - * @tparam DataT data type (do not specify this explicitly, - * as the compiler will deduce it from @c out ) - * @tparam IdxT index type (do not specify this explicitly, - * as the compiler will deduce it from the arguments) - * @tparam WeightsT weights type (defaults to @c double , - * so that passing in std::nullopt works) + * @pre The number of samples `out.extent(0)` + * is less than or equal to the number of inputs `in.extent(0)`. * - * @param handle - * @param rng_state random number generator state - * @param out Array (of length 'sampledLen') - * of samples from the input array - * @param outIdx If provided, array (of length 'sampledLen') - * of the indices sampled from the input array - * @param in Input array to be sampled (of length 'len') - * @param wts Weights array (of length 'len'). - * If not provided, uniform sampling will be used. + * @pre The number of weights `wts.extent(0)` + * equals the number of inputs `in.extent(0)`. */ -template -void sampleWithoutReplacement(const raft::handle_t& handle, - RngState& rng_state, - raft::device_vector_view out, - std::optional> outIdx, - raft::device_vector_view in, - std::optional> wts) +template +void sample_without_replacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view in, + std::optional> wts, + raft::device_vector_view out, + std::optional> outIdx) { static_assert(std::is_integral::value, "IdxT must be an integral type."); const IdxT sampledLen = out.extent(0); const IdxT len = in.extent(0); - ASSERT(sampledLen <= len, - "sampleWithoutReplacement: " - "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); - ASSERT(len == 0 || in.data_handle() != nullptr, - "sampleWithoutReplacement: " - "If in.data_handle() is not null, then in.extent(0) must be nonzero"); - ASSERT(sampledLen == 0 || out.data_handle() != nullptr, - "sampleWithoutReplacement: " - "If out.data_handle() is not null, then out.extent(0) must be nonzero"); + RAFT_EXPECTS(sampledLen <= len, + "sampleWithoutReplacement: " + "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); + RAFT_EXPECTS(len == 0 || in.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If in.extents(0) != 0, then in.data_handle() must be nonnull"); + RAFT_EXPECTS(sampledLen == 0 || out.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If out.extents(0) != 0, then out.data_handle() must be nonnull"); const bool outIdx_has_value = outIdx.has_value(); if (outIdx_has_value) { - ASSERT((*outIdx).extent(0) == sampledLen, - "sampleWithoutReplacement: " - "If outIdx is provided, its extent(0) must equal out.extent(0)"); + RAFT_EXPECTS((*outIdx).extent(0) == sampledLen, + "sampleWithoutReplacement: " + "If outIdx is provided, its extent(0) must equal out.extent(0)"); } IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr; const bool wts_has_value = wts.has_value(); if (wts_has_value) { - ASSERT((*wts).extent(0) == len, - "sampleWithoutReplacement: " - "If wts is provided, its extent(0) must equal in.extent(0)"); + RAFT_EXPECTS((*wts).extent(0) == len, + "sampleWithoutReplacement: " + "If wts is provided, its extent(0) must equal in.extent(0)"); } const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr; @@ -459,6 +429,83 @@ void sampleWithoutReplacement(const raft::handle_t& handle, handle.get_stream()); } +namespace sample_without_replacement_impl { +template +struct weight_alias { +}; + +template <> +struct weight_alias { + using type = double; +}; + +template +struct weight_alias>> { + using type = typename raft::device_vector_view::value_type; +}; + +template +using weight_t = typename weight_alias::type; +} // namespace sample_without_replacement_impl + +/** + * @brief Overload of `sample_without_replacement` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `sample_without_replacement`. + */ +template +void sample_without_replacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view in, + WeightsVectorType&& wts, + raft::device_vector_view out, + OutIndexVectorType&& outIdx) +{ + using weight_type = sample_without_replacement_impl::weight_t< + std::remove_const_t>>; + std::optional> weights = + std::forward(wts); + std::optional> output_indices = + std::forward(outIdx); + + sample_without_replacement(handle, rng_state, in, weights, out, output_indices); +} + +/** + * @brief Legacy version of @c sample_without_replacement (see above) + * that takes raw arrays instead of device mdspan. + * + * @tparam DataT data type + * @tparam WeightsT weights type + * @tparam IdxT index type + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output sampled array (of length 'sampledLen') + * @param[out] outIdx indices of the sampled array (of length 'sampledLen'). Pass + * a nullptr if this is not required. + * @param[in] in input array to be sampled (of length 'len') + * @param[in] wts weights array (of length 'len'). Pass a nullptr if uniform + * sampling is desired + * @param[in] sampledLen output sampled array length + * @param[in] len input array length + */ +template +void sampleWithoutReplacement(const raft::handle_t& handle, + RngState& rng_state, + DataT* out, + IdxT* outIdx, + const DataT* in, + const WeightsT* wts, + IdxT sampledLen, + IdxT len) +{ + detail::sampleWithoutReplacement( + rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 5cf8106008..482b35168a 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -70,7 +70,7 @@ class SWoRTest : public ::testing::TestWithParam> { } sampleWithoutReplacement( handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len); - update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream); + update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); handle.sync_stream(stream); } @@ -93,7 +93,9 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { in(params.len, stream), wts(params.len, stream), out(params.sampledLen, stream), - outIdx(params.sampledLen, stream) + out2(params.sampledLen, stream), + outIdx(params.sampledLen, stream), + outIdx2(params.sampledLen, stream) { } @@ -115,7 +117,7 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { using output_idxs_view = raft::device_vector_view; std::optional outIdx_view{std::in_place, outIdx.data(), outIdx.size()}; - ASSERT_TRUE(outIdx_view.has_value() && outIdx_view.value().extent(0) == params.sampledLen); + ASSERT_TRUE(outIdx_view.value().extent(0) == params.sampledLen); using input_view = raft::device_vector_view; input_view in_view{in.data(), in.size()}; @@ -123,9 +125,19 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { using weights_view = raft::device_vector_view; std::optional wts_view{std::in_place, wts.data(), wts.size()}; - ASSERT_TRUE(wts_view.has_value() && wts_view.value().extent(0) == params.len); + ASSERT_TRUE(wts_view.value().extent(0) == params.len); + + sample_without_replacement(handle, r, in_view, wts_view, out_view, outIdx_view); - sampleWithoutReplacement(handle, r, out_view, outIdx_view, in_view, wts_view); + output_view out2_view{out2.data(), out2.size()}; + ASSERT_TRUE(out2_view.extent(0) == params.sampledLen); + std::optional outIdx2_view{std::in_place, outIdx2.data(), outIdx2.size()}; + ASSERT_TRUE(outIdx2_view.value().extent(0) == params.sampledLen); + + // For now, just test that these calls compile. + sample_without_replacement(handle, r, in_view, wts_view, out2_view, std::nullopt); + sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, outIdx2_view); + sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, std::nullopt); } update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); handle.sync_stream(stream); @@ -136,8 +148,8 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { cudaStream_t stream; SWoRInputs params; - rmm::device_uvector in, out, wts; - rmm::device_uvector outIdx; + rmm::device_uvector in, out, wts, out2; + rmm::device_uvector outIdx, outIdx2; std::vector h_outIdx; };