Skip to content

Commit

Permalink
mdspanify sampleWithoutReplacement (#830)
Browse files Browse the repository at this point in the history
I added an overload of `sampleWithoutReplacement` that takes device `mdspan` instead of raw arrays.  The overload uses `std::optional<mdspan<...>>` to express optional mdspan input or output arguments.  I've added a unit test that imitates the existing unit test for the raw-array overload; it builds and passes.

I also opportunistically fixed some unrelated existing small build errors that were blocking forward progress.

Authors:
  - Mark Hoemmen (https://github.com/mhoemmen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #830
  • Loading branch information
mhoemmen authored Sep 23, 2022
1 parent f25e4c8 commit 1dd2feb
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 45 deletions.
143 changes: 135 additions & 8 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
#include "detail/rng_impl.cuh"
#include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed)
#include "rng_state.hpp"
#include <cassert>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <type_traits>

namespace raft::random {

Expand Down Expand Up @@ -340,20 +344,143 @@ void laplace(const raft::handle_t& handle,
}

/**
* @brief Sample the input array without replacement, optionally based on the
* input weight vector for each element in the array
* @brief Sample the input vector without replacement, optionally based on the
* input weight vector for each element in the array.
*
* Implementation here is based on the `one-pass sampling` algo described here:
* https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
* The implementation is based on the `one-pass sampling` algorithm described in
* ["Accelerating weighted random sampling without
* replacement,"](https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf)
* a technical report by Kirill Mueller.
*
* @note In the sampled array the elements which are picked will always appear
* in the increasing order of their weights as computed using the exponential
* distribution. So, if you're particular about the order (for eg. array
* permutations), then this might not be the right choice!
* If no input weight vector is provided, then input elements will be
* sampled uniformly. Otherwise, the elements sampled from the input
* vector will always appear in increasing order of their weights as
* computed using the exponential distribution. So, if you are
* particular about the order (e.g., for array permutations), then
* this might not be the right choice.
*
* @tparam DataT type of each element of the input array @c in
* @tparam IdxT type of the dimensions of the arrays; output index type
* @tparam WeightsT type of each element of the weights array @c wts
*
* @note Please do not specify template parameters explicitly,
* as the compiler can deduce them from the arguments.
*
* @param[in] handle RAFT handle containing (among other resources)
* the CUDA stream on which to run.
* @param[inout] rng_state Pseudorandom number generator state.
* @param[in] in Input vector to be sampled.
* @param[in] wts Optional weights vector.
* If not provided, uniform sampling will be used.
* @param[out] out Vector of samples from the input vector.
* @param[out] outIdx If provided, vector of the indices
* sampled from the input array.
*
* @pre The number of samples `out.extent(0)`
* is less than or equal to the number of inputs `in.extent(0)`.
*
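* @pre If `wts` is provided, the number of weights `wts.extent(0)`
* equals the number of inputs `in.extent(0)`.
*
* @pre If `outIdx` is provided, the number of output indices
* `outIdx.extent(0)` equals the number of samples `out.extent(0)`.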
*/
template <typename DataT, typename IdxT, typename WeightsT>
void sample_without_replacement(const raft::handle_t& handle,
                                RngState& rng_state,
                                raft::device_vector_view<const DataT, IdxT> in,
                                std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts,
                                raft::device_vector_view<DataT, IdxT> out,
                                std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx)
{
  static_assert(std::is_integral<IdxT>::value, "IdxT must be an integral type.");

  // The number of samples is taken from the output view, the population
  // size from the input view.
  const IdxT sampledLen = out.extent(0);
  const IdxT len        = in.extent(0);

  RAFT_EXPECTS(sampledLen <= len,
               "sampleWithoutReplacement: "
               "sampledLen (out.extent(0)) must be <= len (in.extent(0))");
  // A null data handle is only acceptable for an empty view.
  RAFT_EXPECTS(len == 0 || in.data_handle() != nullptr,
               "sampleWithoutReplacement: "
               "If in.extent(0) != 0, then in.data_handle() must be nonnull");
  RAFT_EXPECTS(sampledLen == 0 || out.data_handle() != nullptr,
               "sampleWithoutReplacement: "
               "If out.extent(0) != 0, then out.data_handle() must be nonnull");

  // Optional output indices: must match `out` in length when provided.
  // A null pointer is what the raw-pointer implementation receives when
  // the caller did not ask for indices.
  IdxT* outIdx_ptr = nullptr;
  if (outIdx.has_value()) {
    RAFT_EXPECTS(outIdx->extent(0) == sampledLen,
                 "sampleWithoutReplacement: "
                 "If outIdx is provided, its extent(0) must equal out.extent(0)");
    // Same rule as for `in` and `out`: a provided, non-empty view must
    // actually point at memory, otherwise the request would be silently lost.
    RAFT_EXPECTS(sampledLen == 0 || outIdx->data_handle() != nullptr,
                 "sampleWithoutReplacement: "
                 "If outIdx is provided and its extent(0) != 0, "
                 "then its data_handle() must be nonnull");
    outIdx_ptr = outIdx->data_handle();
  }

  // Optional weights: must match `in` in length when provided.
  const WeightsT* wts_ptr = nullptr;
  if (wts.has_value()) {
    RAFT_EXPECTS(wts->extent(0) == len,
                 "sampleWithoutReplacement: "
                 "If wts is provided, its extent(0) must equal in.extent(0)");
    RAFT_EXPECTS(len == 0 || wts->data_handle() != nullptr,
                 "sampleWithoutReplacement: "
                 "If wts is provided and its extent(0) != 0, "
                 "then its data_handle() must be nonnull");
    wts_ptr = wts->data_handle();
  }

  // Delegate to the raw-pointer implementation on the handle's stream.
  detail::sampleWithoutReplacement(rng_state,
                                   out.data_handle(),
                                   outIdx_ptr,
                                   in.data_handle(),
                                   wts_ptr,
                                   sampledLen,
                                   len,
                                   handle.get_stream());
}

namespace sample_without_replacement_impl {

/**
 * @brief Trait mapping the (cv/ref-stripped) type of a weights argument
 *        to the element type used for the weights.
 *
 * The primary template has no `type` member; only the specializations
 * below provide one.
 */
template <class ArgT>
struct weight_alias {
};

/**
 * @brief `std::nullopt` carries no element type, so `double` is used.
 */
template <>
struct weight_alias<std::nullopt_t> {
  using type = double;
};

/**
 * @brief For an optional device vector view, use the view's `value_type`.
 */
template <class ElemT, class IndexT>
struct weight_alias<std::optional<raft::device_vector_view<ElemT, IndexT>>> {
  using view_type = raft::device_vector_view<ElemT, IndexT>;
  using type      = typename view_type::value_type;
};

/// Shorthand for `weight_alias<ArgT>::type`.
template <class ArgT>
using weight_t = typename weight_alias<ArgT>::type;

}  // namespace sample_without_replacement_impl

/**
 * @brief Forwarding overload of `sample_without_replacement`.
 *
 * It exists so that callers can pass `std::nullopt` for the weights
 * and/or the output indices: both arguments are first converted to the
 * `std::optional<device_vector_view<...>>` types expected by the main
 * overload, which is then called.
 *
 * Please see the main overload for the documentation of
 * `sample_without_replacement`.
 */
template <typename DataT, typename IdxT, typename WeightsVectorType, class OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
                                RngState& rng_state,
                                raft::device_vector_view<const DataT, IdxT> in,
                                WeightsVectorType&& wts,
                                raft::device_vector_view<DataT, IdxT> out,
                                OutIndexVectorType&& outIdx)
{
  // Work out the weight element type from the argument actually passed
  // (`double` when it is `std::nullopt`).
  using wts_arg_type = std::remove_const_t<std::remove_reference_t<WeightsVectorType>>;
  using weight_type  = sample_without_replacement_impl::weight_t<wts_arg_type>;

  using weights_view_type = raft::device_vector_view<const weight_type, IdxT>;
  using indices_view_type = raft::device_vector_view<IdxT, IdxT>;

  std::optional<weights_view_type> wts_opt    = std::forward<WeightsVectorType>(wts);
  std::optional<indices_view_type> outIdx_opt = std::forward<OutIndexVectorType>(outIdx);

  sample_without_replacement(handle, rng_state, in, wts_opt, out, outIdx_opt);
}

/**
* @brief Legacy version of @c sample_without_replacement (see above)
* that takes raw arrays instead of device mdspan.
*
* @tparam DataT data type
* @tparam WeightsT weights type
* @tparam IdxT index type
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output sampled array (of length 'sampledLen')
Expand Down
142 changes: 105 additions & 37 deletions cpp/test/random/sample_without_replacement.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SWoRTest : public ::testing::TestWithParam<SWoRInputs<T>> {
}
sampleWithoutReplacement(
handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len);
update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream);
update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
handle.sync_stream(stream);
}

Expand All @@ -84,7 +84,75 @@ class SWoRTest : public ::testing::TestWithParam<SWoRInputs<T>> {
std::vector<int> h_outIdx;
};

typedef SWoRTest<float> SWoRTestF;
/**
 * @brief Test fixture for the mdspan overload of `sample_without_replacement`.
 *
 * `SetUp` fills the input and weight buffers via `uniform`, optionally
 * overwrites one weight with `params.largeWeight`, runs the sampler through
 * device vector views, and copies the sampled indices to `h_outIdx` for the
 * test body to inspect.
 */
template <typename T>
class SWoRMdspanTest : public ::testing::TestWithParam<SWoRInputs<T>> {
 public:
  // Members are initialized in declaration order regardless of how the
  // initializer list is written, so the list mirrors the declarations:
  // `stream` needs `handle`, and the buffers need `params` and `stream`.
  SWoRMdspanTest()
    : stream(handle.get_stream()),
      params(::testing::TestWithParam<SWoRInputs<T>>::GetParam()),
      in(params.len, stream),
      out(params.sampledLen, stream),
      wts(params.len, stream),
      out2(params.sampledLen, stream),
      outIdx(params.sampledLen, stream),
      outIdx2(params.sampledLen, stream)
  {
  }

 protected:
  void SetUp() override
  {
    RngState r(params.seed, params.gtype);
    h_outIdx.resize(params.sampledLen);
    uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0));
    uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0));
    // A non-negative index requests one weight to be replaced by
    // `params.largeWeight`.
    if (params.largeWeightIndex >= 0) {
      update_device(wts.data() + params.largeWeightIndex, &params.largeWeight, 1, stream);
    }
    {
      using index_type  = int;
      using output_view = raft::device_vector_view<T, index_type>;
      output_view out_view{out.data(), out.size()};
      ASSERT_TRUE(out_view.extent(0) == params.sampledLen);

      using output_idxs_view = raft::device_vector_view<index_type, index_type>;
      std::optional<output_idxs_view> outIdx_view{std::in_place, outIdx.data(), outIdx.size()};
      ASSERT_TRUE(outIdx_view.value().extent(0) == params.sampledLen);

      using input_view = raft::device_vector_view<const T, index_type>;
      input_view in_view{in.data(), in.size()};
      ASSERT_TRUE(in_view.extent(0) == params.len);

      using weights_view = raft::device_vector_view<const T, index_type>;
      std::optional<weights_view> wts_view{std::in_place, wts.data(), wts.size()};
      ASSERT_TRUE(wts_view.value().extent(0) == params.len);

      // This is the call whose indices the test body checks.
      sample_without_replacement(handle, r, in_view, wts_view, out_view, outIdx_view);

      output_view out2_view{out2.data(), out2.size()};
      ASSERT_TRUE(out2_view.extent(0) == params.sampledLen);
      std::optional<output_idxs_view> outIdx2_view{std::in_place, outIdx2.data(), outIdx2.size()};
      ASSERT_TRUE(outIdx2_view.value().extent(0) == params.sampledLen);

      // For now, just test that these calls compile.
      // They write to the separate out2 / outIdx2 buffers.
      sample_without_replacement(handle, r, in_view, wts_view, out2_view, std::nullopt);
      sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, outIdx2_view);
      sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, std::nullopt);
    }
    // Copy the indices of the first call to the host, then wait on the stream.
    update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
    handle.sync_stream(stream);
  }

  // NOTE: declaration order matters here (see the constructor).
  raft::handle_t handle;
  cudaStream_t stream;

  SWoRInputs<T> params;
  rmm::device_uvector<T> in, out, wts, out2;
  rmm::device_uvector<int> outIdx, outIdx2;
  std::vector<int> h_outIdx;
};

const std::vector<SWoRInputs<float>> inputsf = {{1024, 512, -1, 0.f, GenPhilox, 1234ULL},
{1024, 1024, -1, 0.f, GenPhilox, 1234ULL},
{1024, 512 + 1, -1, 0.f, GenPhilox, 1234ULL},
Expand Down Expand Up @@ -125,26 +193,37 @@ const std::vector<SWoRInputs<float>> inputsf = {{1024, 512, -1, 0.f, GenPhilox,
{1024 + 2, 1024 + 2, -1, 0.f, GenPC, 1234ULL},
{1024, 512, 10, 100000.f, GenPC, 1234ULL}};

// Checks the sampled indices produced by the float fixture:
// each must be in range, none may repeat, and with a skewed weight
// distribution the first index must be the heavily weighted item.
TEST_P(SWoRTestF, Result)
{
  std::set<int> seen;
  for (int i = 0; i < params.sampledLen; ++i) {
    auto val = h_outIdx[i];
    // range check
    ASSERT_TRUE(0 <= val && val < params.len)
      << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen;
    // uniqueness check: insertion reports whether the index was new
    const bool first_time = seen.insert(val).second;
    ASSERT_TRUE(first_time) << "repeated index @i=" << i << " idx=" << val;
  }
  if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }
}
// Shared body of the `Result` tests for the SWoR fixtures.
//
// This is a macro rather than a function because it has to expand
// inside the scope of the class named by the first parameter of TEST_P,
// where the fixture members `params` and `h_outIdx` are visible.
//
// Checks performed:
//
// 1. Every output index lies in the given range.
// 2. No output index occurs twice.
// 3. If there's a skewed distribution, the top index corresponds
//    to the particular item with a large weight.
#define _RAFT_SWOR_TEST_CONTENTS()                                                               \
  do {                                                                                           \
    std::set<int> seen;                                                                          \
    for (int i = 0; i < params.sampledLen; ++i) {                                                \
      auto val = h_outIdx[i];                                                                    \
      ASSERT_TRUE(0 <= val && val < params.len)                                                  \
        << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; \
      const bool first_time = seen.insert(val).second;                                           \
      ASSERT_TRUE(first_time) << "repeated index @i=" << i << " idx=" << val;                    \
    }                                                                                            \
    if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }       \
  } while (false)

// Raw-array fixture, float: run the shared checks over every entry of inputsf.
using SWoRTestF = SWoRTest<float>;
TEST_P(SWoRTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestF, ::testing::ValuesIn(inputsf));

// NOTE(review): alias for the raw-array fixture, double; presumably a
// leftover of the pre-`using` spelling -- confirm it is still wanted.
typedef SWoRTest<double> SWoRTestD;
// mdspan fixture, float: same checks and same inputs, registered under SWoRTests2.
using SWoRMdspanTestF = SWoRMdspanTest<float>;
TEST_P(SWoRMdspanTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestF, ::testing::ValuesIn(inputsf));

const std::vector<SWoRInputs<double>> inputsd = {{1024, 512, -1, 0.0, GenPhilox, 1234ULL},
{1024, 1024, -1, 0.0, GenPhilox, 1234ULL},
{1024, 512 + 1, -1, 0.0, GenPhilox, 1234ULL},
Expand Down Expand Up @@ -185,24 +264,13 @@ const std::vector<SWoRInputs<double>> inputsd = {{1024, 512, -1, 0.0, GenPhilox,
{1024 + 2, 1024 + 2, -1, 0.0, GenPC, 1234ULL},
{1024, 512, 10, 100000.0, GenPC, 1234ULL}};

// Checks the sampled indices produced by the double fixture:
// each must be in range, none may repeat, and with a skewed weight
// distribution the first index must be the heavily weighted item.
TEST_P(SWoRTestD, Result)
{
  std::set<int> seen;
  for (int i = 0; i < params.sampledLen; ++i) {
    auto val = h_outIdx[i];
    // range check
    ASSERT_TRUE(0 <= val && val < params.len)
      << "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen;
    // uniqueness check: insertion reports whether the index was new
    const bool first_time = seen.insert(val).second;
    ASSERT_TRUE(first_time) << "repeated index @i=" << i << " idx=" << val;
  }
  if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }
}
// Raw-array fixture, double: run the shared checks over every entry of inputsd.
using SWoRTestD = SWoRTest<double>;
TEST_P(SWoRTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestD, ::testing::ValuesIn(inputsd));

// mdspan fixture, double: same checks and same inputs, registered under SWoRTests2.
using SWoRMdspanTestD = SWoRMdspanTest<double>;
TEST_P(SWoRMdspanTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestD, ::testing::ValuesIn(inputsd));

} // namespace random
} // namespace raft

0 comments on commit 1dd2feb

Please sign in to comment.