From 13cb362898f6baf437b74a340318df204a67783b Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Mon, 19 Sep 2022 16:56:02 -0700 Subject: [PATCH 1/5] Fix existing build errors --- cpp/include/raft/random/rng_state.hpp | 2 ++ cpp/test/nvtx.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/random/rng_state.hpp b/cpp/include/raft/random/rng_state.hpp index 44372902b1..ec15ef286f 100644 --- a/cpp/include/raft/random/rng_state.hpp +++ b/cpp/include/raft/random/rng_state.hpp @@ -19,6 +19,8 @@ #pragma once +#include + namespace raft { namespace random { diff --git a/cpp/test/nvtx.cpp b/cpp/test/nvtx.cpp index 81f692a215..d982642929 100644 --- a/cpp/test/nvtx.cpp +++ b/cpp/test/nvtx.cpp @@ -15,7 +15,7 @@ */ #ifdef NVTX_ENABLED #include -#include +#include /** * tests for the functionality of generating next color based on string * entered in the NVTX Range marker wrappers From 9af3166e8654c13b05cdf91859259d67aa95edb5 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Thu, 8 Sep 2022 09:42:00 -0700 Subject: [PATCH 2/5] sampleWithoutReplacement: Add overload taking mdspans --- cpp/include/raft/random/rng.cuh | 80 +++++++++++ cpp/test/random/sample_without_replacement.cu | 128 +++++++++++++----- 2 files changed, 172 insertions(+), 36 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 85d9abe263..3ff591627a 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -19,7 +19,11 @@ #include "detail/rng_impl.cuh" #include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed) #include "rng_state.hpp" +#include +#include #include +#include +#include namespace raft::random { @@ -379,6 +383,82 @@ void sampleWithoutReplacement(const raft::handle_t& handle, rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); } +/** + * @brief Sample the input array without replacement, optionally based on the + * input weight vector for each element in the array + * + * Implementation here is based on the `one-pass sampling` algo described here: + * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * + * @note In the sampled array the elements which are picked will always appear + * in the increasing order of their weights as computed using the exponential + * distribution. So, if you're particular about the order (for eg. array + * permutations), then this might not be the right choice! + * + * @tparam DataT data type (do not specify this explicitly, + * as the compiler will deduce it from @c out ) + * @tparam IdxT index type (do not specify this explicitly, + * as the compiler will deduce it from the arguments) + * @tparam WeightsT weights type (defaults to @c double , + * so that passing in std::nullopt works) + * + * @param handle + * @param rng_state random number generator state + * @param out Array (of length 'sampledLen') + * of samples from the input array + * @param outIdx If provided, array (of length 'sampledLen') + * of the indices sampled from the input array + * @param in Input array to be sampled (of length 'len') + * @param wts Weights array (of length 'len'). + * If not provided, uniform sampling will be used. + */ +template +void sampleWithoutReplacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + std::optional> outIdx, + raft::device_vector_view in, + std::optional> wts) +{ + static_assert(std::is_integral::value, "IdxT must be an integral type."); + const IdxT sampledLen = out.extent(0); + const IdxT len = in.extent(0); + ASSERT(sampledLen <= len, + "sampleWithoutReplacement: " + "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); + ASSERT(len == 0 || in.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If in.data_handle() is not null, then in.extent(0) must be nonzero"); + ASSERT(sampledLen == 0 || out.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If out.data_handle() is not null, then out.extent(0) must be nonzero"); + + const bool outIdx_has_value = outIdx.has_value(); + if (outIdx_has_value) { + ASSERT((*outIdx).extent(0) == sampledLen, + "sampleWithoutReplacement: " + "If outIdx is provided, its extent(0) must equal out.extent(0)"); + } + IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr; + + const bool wts_has_value = wts.has_value(); + if (wts_has_value) { + ASSERT((*wts).extent(0) == len, + "sampleWithoutReplacement: " + "If wts is provided, its extent(0) must equal in.extent(0)"); + } + const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr; + + detail::sampleWithoutReplacement(rng_state, + out.data_handle(), + outIdx_ptr, + in.data_handle(), + wts_ptr, + sampledLen, + len, + handle.get_stream()); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 653a9f9bc9..5cf8106008 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -84,7 +84,63 @@ class SWoRTest : public ::testing::TestWithParam> { std::vector h_outIdx; }; -typedef SWoRTest SWoRTestF; +template +class SWoRMdspanTest : public ::testing::TestWithParam> { + public: + SWoRMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + in(params.len, stream), + wts(params.len, stream), + out(params.sampledLen, stream), + outIdx(params.sampledLen, stream) + { + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + h_outIdx.resize(params.sampledLen); + uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0)); + uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0)); + if (params.largeWeightIndex >= 0) { + update_device(wts.data() + params.largeWeightIndex, ¶ms.largeWeight, 1, stream); + } + { + using index_type = int; + using output_view = raft::device_vector_view; + output_view out_view{out.data(), out.size()}; + ASSERT_TRUE(out_view.extent(0) == params.sampledLen); + + using output_idxs_view = raft::device_vector_view; + std::optional outIdx_view{std::in_place, outIdx.data(), outIdx.size()}; + ASSERT_TRUE(outIdx_view.has_value() && outIdx_view.value().extent(0) == params.sampledLen); + + using input_view = raft::device_vector_view; + input_view in_view{in.data(), in.size()}; + ASSERT_TRUE(in_view.extent(0) == params.len); + + using weights_view = raft::device_vector_view; + std::optional wts_view{std::in_place, wts.data(), wts.size()}; + ASSERT_TRUE(wts_view.has_value() && wts_view.value().extent(0) == params.len); + + sampleWithoutReplacement(handle, r, out_view, outIdx_view, in_view, wts_view); + } + update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); + handle.sync_stream(stream); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + SWoRInputs params; + rmm::device_uvector in, out, wts; + rmm::device_uvector outIdx; + std::vector h_outIdx; +}; + const std::vector> inputsf = {{1024, 512, -1, 0.f, GenPhilox, 1234ULL}, {1024, 1024, -1, 0.f, GenPhilox, 1234ULL}, {1024, 512 + 1, -1, 0.f, GenPhilox, 1234ULL}, @@ -125,26 +181,37 @@ const std::vector> inputsf = {{1024, 512, -1, 0.f, GenPhilox, {1024 + 2, 1024 + 2, -1, 0.f, GenPC, 1234ULL}, {1024, 512, 10, 100000.f, GenPC, 1234ULL}}; -TEST_P(SWoRTestF, Result) -{ - std::set occurence; - for (int i = 0; i < params.sampledLen; ++i) { - auto val = h_outIdx[i]; - // indices must be in the given range - ASSERT_TRUE(0 <= val && val < params.len) - << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; - // indices should not repeat - ASSERT_TRUE(occurence.find(val) == occurence.end()) - << "repeated index @i=" << i << " idx=" << val; - occurence.insert(val); - } - // if there's a skewed distribution, the top index should correspond to the - // particular item with a large weight - if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } -} +// This needs to be a macro because it has to live in the scope +// of the class whose name is the first parameter of TEST_P. +// +// We test the following. +// +// 1. Output indices are in the given range. +// 2. Output indices do not repeat. +// 3. If there's a skewed distribution, the top index should +// correspond to the particular item with a large weight. +#define _RAFT_SWOR_TEST_CONTENTS() \ + do { \ + std::set occurrence; \ + for (int i = 0; i < params.sampledLen; ++i) { \ + auto val = h_outIdx[i]; \ + ASSERT_TRUE(0 <= val && val < params.len) \ + << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; \ + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) \ + << "repeated index @i=" << i << " idx=" << val; \ + occurrence.insert(val); \ + } \ + if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } \ + } while (false) + +using SWoRTestF = SWoRTest; +TEST_P(SWoRTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); } INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestF, ::testing::ValuesIn(inputsf)); -typedef SWoRTest SWoRTestD; +using SWoRMdspanTestF = SWoRMdspanTest; +TEST_P(SWoRMdspanTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); } +INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = {{1024, 512, -1, 0.0, GenPhilox, 1234ULL}, {1024, 1024, -1, 0.0, GenPhilox, 1234ULL}, {1024, 512 + 1, -1, 0.0, GenPhilox, 1234ULL}, @@ -185,24 +252,13 @@ const std::vector> inputsd = {{1024, 512, -1, 0.0, GenPhilox, {1024 + 2, 1024 + 2, -1, 0.0, GenPC, 1234ULL}, {1024, 512, 10, 100000.0, GenPC, 1234ULL}}; -TEST_P(SWoRTestD, Result) -{ - std::set occurence; - for (int i = 0; i < params.sampledLen; ++i) { - auto val = h_outIdx[i]; - // indices must be in the given range - ASSERT_TRUE(0 <= val && val < params.len) - << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; - // indices should not repeat - ASSERT_TRUE(occurence.find(val) == occurence.end()) - << "repeated index @i=" << i << " idx=" << val; - occurence.insert(val); - } - // if there's a skewed distribution, the top index should correspond to the - // particular item with a large weight - if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } -} +using SWoRTestD = SWoRTest; +TEST_P(SWoRTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); } INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestD, ::testing::ValuesIn(inputsd)); +using SWoRMdspanTestD = SWoRMdspanTest; +TEST_P(SWoRMdspanTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); } +INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestD, ::testing::ValuesIn(inputsd)); + } // namespace random } // namespace raft From c391f9e4a77e3b7a5b9d021a4e04b8ce86b3c3a7 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Tue, 20 Sep 2022 12:41:27 -0700 Subject: [PATCH 3/5] Fix copyright year --- cpp/test/nvtx.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/nvtx.cpp b/cpp/test/nvtx.cpp index d982642929..635fe55012 100644 --- a/cpp/test/nvtx.cpp +++ b/cpp/test/nvtx.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 5a0f01236890808fe0cff3d3c7c715047ea8b7cf Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Wed, 21 Sep 2022 07:13:32 -0700 Subject: [PATCH 4/5] Address review feedback Rename new sampleWithoutReplacement overload to sample_without_replacement. Fix order of parameters: handle, rng, in, out, params. Use RAFT_EXPECTS instead of ASSERT. Make sure that users can pass in std::nullopt explicitly for either or both std::optional parameters. Use "vector" instead of "array" in documentation, use [in], [out], and [inout] as appropriate, and make other documentation improvements. Make error messages more clear. --- cpp/include/raft/random/rng.cuh | 211 +++++++++++------- cpp/test/random/sample_without_replacement.cu | 26 ++- 2 files changed, 148 insertions(+), 89 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 3ff591627a..550b818c9a 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -344,108 +344,78 @@ void laplace(const raft::handle_t& handle, } /** - * @brief Sample the input array without replacement, optionally based on the - * input weight vector for each element in the array + * @brief Sample the input vector without replacement, optionally based on the + * input weight vector for each element in the array. * - * Implementation here is based on the `one-pass sampling` algo described here: - * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * The implementation is based on the `one-pass sampling` algorithm described in + * ["Accelerating weighted random sampling without + * replacement,"](https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf) + * a technical report by Kirill Mueller. * - * @note In the sampled array the elements which are picked will always appear - * in the increasing order of their weights as computed using the exponential - * distribution. So, if you're particular about the order (for eg. array - * permutations), then this might not be the right choice! + * If no input weight vector is provided, then input elements will be + * sampled uniformly. Otherwise, the elements sampled from the input + * vector will always appear in increasing order of their weights as + * computed using the exponential distribution. So, if you are + * particular about the order (for e.g., array permutations), then + * this might not be the right choice. * - * @tparam DataT data type - * @tparam WeightsT weights type - * @tparam IdxT index type - * @param[in] handle raft handle for resource management - * @param[in] rng_state random number generator state - * @param[out] out output sampled array (of length 'sampledLen') - * @param[out] outIdx indices of the sampled array (of length 'sampledLen'). Pass - * a nullptr if this is not required. - * @param[in] in input array to be sampled (of length 'len') - * @param[in] wts weights array (of length 'len'). Pass a nullptr if uniform - * sampling is desired - * @param[in] sampledLen output sampled array length - * @param[in] len input array length - */ -template -void sampleWithoutReplacement(const raft::handle_t& handle, - RngState& rng_state, - DataT* out, - IdxT* outIdx, - const DataT* in, - const WeightsT* wts, - IdxT sampledLen, - IdxT len) -{ - detail::sampleWithoutReplacement( - rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); -} - -/** - * @brief Sample the input array without replacement, optionally based on the - * input weight vector for each element in the array + * @tparam DataT type of each element of the input array @c in + * @tparam IdxT type of the dimensions of the arrays; output index type + * @tparam WeightsT type of each elements of the weights array @c wts * - * Implementation here is based on the `one-pass sampling` algo described here: - * https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf + * @note Please do not specify template parameters explicitly, + * as the compiler can deduce them from the arguments. * - * @note In the sampled array the elements which are picked will always appear - * in the increasing order of their weights as computed using the exponential - * distribution. So, if you're particular about the order (for eg. array - * permutations), then this might not be the right choice! + * @param[in] handle RAFT handle containing (among other resources) + * the CUDA stream on which to run. + * @param[inout] rng_state Pseudorandom number generator state. + * @param[in] in Input vector to be sampled. + * @param[in] wts Optional weights vector. + * If not provided, uniform sampling will be used. + * @param[out] out Vector of samples from the input vector. + * @param[out] outIdx If provided, vector of the indices + * sampled from the input array. * - * @tparam DataT data type (do not specify this explicitly, - * as the compiler will deduce it from @c out ) - * @tparam IdxT index type (do not specify this explicitly, - * as the compiler will deduce it from the arguments) - * @tparam WeightsT weights type (defaults to @c double , - * so that passing in std::nullopt works) + * @pre The number of samples `out.extent(0)` + * is less than or equal to the number of inputs `in.extent(0)`. * - * @param handle - * @param rng_state random number generator state - * @param out Array (of length 'sampledLen') - * of samples from the input array - * @param outIdx If provided, array (of length 'sampledLen') - * of the indices sampled from the input array - * @param in Input array to be sampled (of length 'len') - * @param wts Weights array (of length 'len'). - * If not provided, uniform sampling will be used. + * @pre The number of weights `wts.extent(0)` + * equals the number of inputs `in.extent(0)`. */ -template -void sampleWithoutReplacement(const raft::handle_t& handle, - RngState& rng_state, - raft::device_vector_view out, - std::optional> outIdx, - raft::device_vector_view in, - std::optional> wts) +template +void sample_without_replacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view in, + std::optional> wts, + raft::device_vector_view out, + std::optional> outIdx) { static_assert(std::is_integral::value, "IdxT must be an integral type."); const IdxT sampledLen = out.extent(0); const IdxT len = in.extent(0); - ASSERT(sampledLen <= len, - "sampleWithoutReplacement: " - "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); - ASSERT(len == 0 || in.data_handle() != nullptr, - "sampleWithoutReplacement: " - "If in.data_handle() is not null, then in.extent(0) must be nonzero"); - ASSERT(sampledLen == 0 || out.data_handle() != nullptr, - "sampleWithoutReplacement: " - "If out.data_handle() is not null, then out.extent(0) must be nonzero"); + RAFT_EXPECTS(sampledLen <= len, + "sampleWithoutReplacement: " + "sampledLen (out.extent(0)) must be <= len (in.extent(0))"); + RAFT_EXPECTS(len == 0 || in.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If in.extents(0) != 0, then in.data_handle() must be nonnull"); + RAFT_EXPECTS(sampledLen == 0 || out.data_handle() != nullptr, + "sampleWithoutReplacement: " + "If out.extents(0) != 0, then out.data_handle() must be nonnull"); const bool outIdx_has_value = outIdx.has_value(); if (outIdx_has_value) { - ASSERT((*outIdx).extent(0) == sampledLen, - "sampleWithoutReplacement: " - "If outIdx is provided, its extent(0) must equal out.extent(0)"); + RAFT_EXPECTS((*outIdx).extent(0) == sampledLen, + "sampleWithoutReplacement: " + "If outIdx is provided, its extent(0) must equal out.extent(0)"); } IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr; const bool wts_has_value = wts.has_value(); if (wts_has_value) { - ASSERT((*wts).extent(0) == len, - "sampleWithoutReplacement: " - "If wts is provided, its extent(0) must equal in.extent(0)"); + RAFT_EXPECTS((*wts).extent(0) == len, + "sampleWithoutReplacement: " + "If wts is provided, its extent(0) must equal in.extent(0)"); } const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr; @@ -459,6 +429,83 @@ void sampleWithoutReplacement(const raft::handle_t& handle, handle.get_stream()); } +namespace sample_without_replacement_impl { +template +struct weight_alias { +}; + +template <> +struct weight_alias { + using type = double; +}; + +template +struct weight_alias>> { + using type = typename raft::device_vector_view::value_type; +}; + +template +using weight_t = typename weight_alias::type; +} // namespace sample_without_replacement_impl + +/** + * @brief Overload of `sample_without_replacement` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `sample_without_replacement`. + */ +template +void sample_without_replacement(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view in, + WeightsVectorType&& wts, + raft::device_vector_view out, + OutIndexVectorType&& outIdx) +{ + using weight_type = sample_without_replacement_impl::weight_t< + std::remove_const_t>>; + std::optional> weights = + std::forward(wts); + std::optional> output_indices = + std::forward(outIdx); + + sample_without_replacement(handle, rng_state, in, weights, out, output_indices); +} + +/** + * @brief Legacy version of @c sample_without_replacement (see above) + * that takes raw arrays instead of device mdspan. + * + * @tparam DataT data type + * @tparam WeightsT weights type + * @tparam IdxT index type + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output sampled array (of length 'sampledLen') + * @param[out] outIdx indices of the sampled array (of length 'sampledLen'). Pass + * a nullptr if this is not required. + * @param[in] in input array to be sampled (of length 'len') + * @param[in] wts weights array (of length 'len'). Pass a nullptr if uniform + * sampling is desired + * @param[in] sampledLen output sampled array length + * @param[in] len input array length + */ +template +void sampleWithoutReplacement(const raft::handle_t& handle, + RngState& rng_state, + DataT* out, + IdxT* outIdx, + const DataT* in, + const WeightsT* wts, + IdxT sampledLen, + IdxT len) +{ + detail::sampleWithoutReplacement( + rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 5cf8106008..482b35168a 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -70,7 +70,7 @@ class SWoRTest : public ::testing::TestWithParam> { } sampleWithoutReplacement( handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len); - update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream); + update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); handle.sync_stream(stream); } @@ -93,7 +93,9 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { in(params.len, stream), wts(params.len, stream), out(params.sampledLen, stream), - outIdx(params.sampledLen, stream) + out2(params.sampledLen, stream), + outIdx(params.sampledLen, stream), + outIdx2(params.sampledLen, stream) { } @@ -115,7 +117,7 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { using output_idxs_view = raft::device_vector_view; std::optional outIdx_view{std::in_place, outIdx.data(), outIdx.size()}; - ASSERT_TRUE(outIdx_view.has_value() && outIdx_view.value().extent(0) == params.sampledLen); + ASSERT_TRUE(outIdx_view.value().extent(0) == params.sampledLen); using input_view = raft::device_vector_view; input_view in_view{in.data(), in.size()}; @@ -123,9 +125,19 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { using weights_view = raft::device_vector_view; std::optional wts_view{std::in_place, wts.data(), wts.size()}; - ASSERT_TRUE(wts_view.has_value() && wts_view.value().extent(0) == params.len); + ASSERT_TRUE(wts_view.value().extent(0) == params.len); + + sample_without_replacement(handle, r, in_view, wts_view, out_view, outIdx_view); - sampleWithoutReplacement(handle, r, out_view, outIdx_view, in_view, wts_view); + output_view out2_view{out2.data(), out2.size()}; + ASSERT_TRUE(out2_view.extent(0) == params.sampledLen); + std::optional outIdx2_view{std::in_place, outIdx2.data(), outIdx2.size()}; + ASSERT_TRUE(outIdx2_view.value().extent(0) == params.sampledLen); + + // For now, just test that these calls compile. + sample_without_replacement(handle, r, in_view, wts_view, out2_view, std::nullopt); + sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, outIdx2_view); + sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, std::nullopt); } update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream); handle.sync_stream(stream); @@ -136,8 +148,8 @@ class SWoRMdspanTest : public ::testing::TestWithParam> { cudaStream_t stream; SWoRInputs params; - rmm::device_uvector in, out, wts; - rmm::device_uvector outIdx; + rmm::device_uvector in, out, wts, out2; + rmm::device_uvector outIdx, outIdx2; std::vector h_outIdx; }; From 6371101dfd360d52107e92f363995c7cb6b7b22a Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Thu, 22 Sep 2022 09:27:21 -0700 Subject: [PATCH 5/5] Fix include after rebase --- cpp/include/raft/random/rng.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 550b818c9a..ba6254bfc3 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -21,8 +21,8 @@ #include "rng_state.hpp" #include #include +#include #include -#include #include namespace raft::random {