Skip to content

Commit

Permalink
sampleWithoutReplacement: Add overload taking mdspans
Browse files Browse the repository at this point in the history
  • Loading branch information
mhoemmen committed Sep 20, 2022
1 parent a4155a2 commit 3ae4e34
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 30 deletions.
71 changes: 71 additions & 0 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed)
#include "rng_state.hpp"
#include <raft/core/handle.hpp>
#include <raft/mdarray.hpp>
#include <cassert>
#include <optional>
#include <type_traits>

namespace raft::random {

Expand Down Expand Up @@ -379,6 +383,72 @@ void sampleWithoutReplacement(const raft::handle_t& handle,
rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream());
}

/**
* @brief Sample the input array without replacement, optionally based on the
* input weight vector for each element in the array
*
* Implementation here is based on the `one-pass sampling` algo described here:
* https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
*
* @note In the sampled array the elements which are picked will always appear
* in the increasing order of their weights as computed using the exponential
* distribution. So, if you're particular about the order (for eg. array
* permutations), then this might not be the right choice!
*
* @tparam DataT data type (do not specify this explicitly,
* as the compiler will deduce it from @c out )
* @tparam IdxT index type (do not specify this explicitly,
* as the compiler will deduce it from the arguments)
* @tparam WeightsT weights type (defaults to @c double ,
* so that passing in std::nullopt works)
*
* @param handle
* @param rng_state random number generator state
* @param out Array (of length 'sampledLen')
* of samples from the input array
* @param outIdx If provided, array (of length 'sampledLen')
* of the indices sampled from the input array
* @param in Input array to be sampled (of length 'len')
* @param wts Weights array (of length 'len').
* If not provided, uniform sampling will be used.
*/
template <typename DataT, typename IdxT, typename WeightsT = double>
void sampleWithoutReplacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<DataT, IdxT> out,
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx,
raft::device_vector_view<const DataT, IdxT> in,
std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts)
{
static_assert(std::is_integral<IdxT>::value, "IdxT must be an integral type.");
const IdxT sampledLen = out.extent(0);
const IdxT len = in.extent(0);
ASSERT(sampledLen <= len, "sampleWithoutReplacement: "
"sampledLen (out.extent(0)) must be <= len (in.extent(0))");
ASSERT(len == 0 || in.data_handle() != nullptr, "sampleWithoutReplacement: "
"If in.data_handle() is not null, then in.extent(0) must be nonzero");
ASSERT(sampledLen == 0 || out.data_handle() != nullptr, "sampleWithoutReplacement: "
"If out.data_handle() is not null, then out.extent(0) must be nonzero");

const bool outIdx_has_value = outIdx.has_value();
if (outIdx_has_value) {
ASSERT((*outIdx).extent(0) == sampledLen, "sampleWithoutReplacement: "
"If outIdx is provided, its extent(0) must equal out.extent(0)");
}
IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr;

const bool wts_has_value = wts.has_value();
if (wts_has_value) {
ASSERT((*wts).extent(0) == len, "sampleWithoutReplacement: "
"If wts is provided, its extent(0) must equal in.extent(0)");
}
const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr;

detail::sampleWithoutReplacement(
rng_state, out.data_handle(), outIdx_ptr, in.data_handle(), wts_ptr,
sampledLen, len, handle.get_stream());
}

/**
* @brief Generates the 'a' and 'b' parameters for a modulo affine
* transformation equation: `(ax + b) % n`
Expand Down Expand Up @@ -714,6 +784,7 @@ class DEPR Rng : public detail::RngImpl {
detail::RngImpl::sampleWithoutReplacement(
handle, out, outIdx, in, wts, sampledLen, len, stream);
}

};

#undef DEPR
Expand Down
127 changes: 97 additions & 30 deletions cpp/test/random/sample_without_replacement.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,63 @@ class SWoRTest : public ::testing::TestWithParam<SWoRInputs<T>> {
std::vector<int> h_outIdx;
};

typedef SWoRTest<float> SWoRTestF;
template <typename T>
class SWoRMdspanTest : public ::testing::TestWithParam<SWoRInputs<T>> {
public:
SWoRMdspanTest()
: params(::testing::TestWithParam<SWoRInputs<T>>::GetParam()),
stream(handle.get_stream()),
in(params.len, stream),
wts(params.len, stream),
out(params.sampledLen, stream),
outIdx(params.sampledLen, stream)
{
}

protected:
void SetUp() override
{
RngState r(params.seed, params.gtype);
h_outIdx.resize(params.sampledLen);
uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0));
uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0));
if (params.largeWeightIndex >= 0) {
update_device(wts.data() + params.largeWeightIndex, &params.largeWeight, 1, stream);
}
{
using index_type = int;
using output_view = raft::device_vector_view<T, index_type>;
output_view out_view{out.data(), out.size()};
ASSERT_TRUE(out_view.extent(0) == params.sampledLen);

using output_idxs_view = raft::device_vector_view<index_type, index_type>;
std::optional<output_idxs_view> outIdx_view{std::in_place, outIdx.data(), outIdx.size()};
ASSERT_TRUE(outIdx_view.has_value() && outIdx_view.value().extent(0) == params.sampledLen);

using input_view = raft::device_vector_view<const T, index_type>;
input_view in_view{in.data(), in.size()};
ASSERT_TRUE(in_view.extent(0) == params.len);

using weights_view = raft::device_vector_view<const T, index_type>;
std::optional<weights_view> wts_view{std::in_place, wts.data(), wts.size()};
ASSERT_TRUE(wts_view.has_value() && wts_view.value().extent(0) == params.len);

sampleWithoutReplacement(handle, r, out_view, outIdx_view, in_view, wts_view);
}
update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
handle.sync_stream(stream);
}

protected:
raft::handle_t handle;
cudaStream_t stream;

SWoRInputs<T> params;
rmm::device_uvector<T> in, out, wts;
rmm::device_uvector<int> outIdx;
std::vector<int> h_outIdx;
};

const std::vector<SWoRInputs<float>> inputsf = {{1024, 512, -1, 0.f, GenPhilox, 1234ULL},
{1024, 1024, -1, 0.f, GenPhilox, 1234ULL},
{1024, 512 + 1, -1, 0.f, GenPhilox, 1234ULL},
Expand Down Expand Up @@ -125,26 +181,42 @@ const std::vector<SWoRInputs<float>> inputsf = {{1024, 512, -1, 0.f, GenPhilox,
{1024 + 2, 1024 + 2, -1, 0.f, GenPC, 1234ULL},
{1024, 512, 10, 100000.f, GenPC, 1234ULL}};

// This needs to be a macro because it has to live in the scope
// of the class whose name is the first parameter of TEST_P.
//
// We test the following.
//
// 1. Output indices are in the given range.
// 2. Output indices do not repeat.
// 3. If there's a skewed distribution, the top index should
// correspond to the particular item with a large weight.
#define _RAFT_SWOR_TEST_CONTENTS() do { \
std::set<int> occurrence; \
for (int i = 0; i < params.sampledLen; ++i) { \
auto val = h_outIdx[i]; \
ASSERT_TRUE(0 <= val && val < params.len) \
<< "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; \
ASSERT_TRUE(occurrence.find(val) == occurrence.end()) \
<< "repeated index @i=" << i << " idx=" << val; \
occurrence.insert(val); \
} \
if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } \
} while(false) \

using SWoRTestF = SWoRTest<float>;
TEST_P(SWoRTestF, Result)
{
std::set<int> occurence;
for (int i = 0; i < params.sampledLen; ++i) {
auto val = h_outIdx[i];
// indices must be in the given range
ASSERT_TRUE(0 <= val && val < params.len)
<< "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen;
// indices should not repeat
ASSERT_TRUE(occurence.find(val) == occurence.end())
<< "repeated index @i=" << i << " idx=" << val;
occurence.insert(val);
}
// if there's a skewed distribution, the top index should correspond to the
// particular item with a large weight
if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }
_RAFT_SWOR_TEST_CONTENTS();
}
INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestF, ::testing::ValuesIn(inputsf));

typedef SWoRTest<double> SWoRTestD;
using SWoRMdspanTestF = SWoRMdspanTest<float>;
TEST_P(SWoRMdspanTestF, Result)
{
_RAFT_SWOR_TEST_CONTENTS();
}
INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestF, ::testing::ValuesIn(inputsf));

const std::vector<SWoRInputs<double>> inputsd = {{1024, 512, -1, 0.0, GenPhilox, 1234ULL},
{1024, 1024, -1, 0.0, GenPhilox, 1234ULL},
{1024, 512 + 1, -1, 0.0, GenPhilox, 1234ULL},
Expand Down Expand Up @@ -185,24 +257,19 @@ const std::vector<SWoRInputs<double>> inputsd = {{1024, 512, -1, 0.0, GenPhilox,
{1024 + 2, 1024 + 2, -1, 0.0, GenPC, 1234ULL},
{1024, 512, 10, 100000.0, GenPC, 1234ULL}};

using SWoRTestD = SWoRTest<double>;
TEST_P(SWoRTestD, Result)
{
std::set<int> occurence;
for (int i = 0; i < params.sampledLen; ++i) {
auto val = h_outIdx[i];
// indices must be in the given range
ASSERT_TRUE(0 <= val && val < params.len)
<< "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen;
// indices should not repeat
ASSERT_TRUE(occurence.find(val) == occurence.end())
<< "repeated index @i=" << i << " idx=" << val;
occurence.insert(val);
}
// if there's a skewed distribution, the top index should correspond to the
// particular item with a large weight
if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }
_RAFT_SWOR_TEST_CONTENTS();
}
INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestD, ::testing::ValuesIn(inputsd));

using SWoRMdspanTestD = SWoRMdspanTest<double>;
TEST_P(SWoRMdspanTestD, Result)
{
_RAFT_SWOR_TEST_CONTENTS();
}
INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestD, ::testing::ValuesIn(inputsd));

} // namespace random
} // namespace raft

0 comments on commit 3ae4e34

Please sign in to comment.