Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] mdspanify sampleWithoutReplacement #830

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 135 additions & 8 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
#include "detail/rng_impl.cuh"
#include "detail/rng_impl_deprecated.cuh" // necessary for now (to be removed)
#include "rng_state.hpp"
#include <cassert>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <type_traits>

namespace raft::random {

Expand Down Expand Up @@ -340,20 +344,143 @@ void laplace(const raft::handle_t& handle,
}

/**
* @brief Sample the input array without replacement, optionally based on the
* input weight vector for each element in the array
* @brief Sample the input vector without replacement, optionally based on the
* input weight vector for each element in the array.
*
* Implementation here is based on the `one-pass sampling` algo described here:
* https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
* The implementation is based on the `one-pass sampling` algorithm described in
* ["Accelerating weighted random sampling without
* replacement,"](https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf)
* a technical report by Kirill Mueller.
*
* @note In the sampled array the elements which are picked will always appear
* in the increasing order of their weights as computed using the exponential
* distribution. So, if you're particular about the order (for eg. array
* permutations), then this might not be the right choice!
* If no input weight vector is provided, then input elements will be
* sampled uniformly. Otherwise, the elements sampled from the input
* vector will always appear in increasing order of their weights as
* computed using the exponential distribution. So, if you are
* particular about the order (e.g., for array permutations), then
* this might not be the right choice.
*
* @tparam DataT type of each element of the input array @c in
* @tparam IdxT type of the dimensions of the arrays; output index type
* @tparam WeightsT type of each element of the weights array @c wts
*
* @note Please do not specify template parameters explicitly,
* as the compiler can deduce them from the arguments.
*
* @param[in] handle RAFT handle containing (among other resources)
* the CUDA stream on which to run.
* @param[inout] rng_state Pseudorandom number generator state.
* @param[in] in Input vector to be sampled.
* @param[in] wts Optional weights vector.
* If not provided, uniform sampling will be used.
* @param[out] out Vector of samples from the input vector.
* @param[out] outIdx If provided, vector of the indices
* sampled from the input array.
*
* @pre The number of samples `out.extent(0)`
* is less than or equal to the number of inputs `in.extent(0)`.
*
* @pre The number of weights `wts.extent(0)`
* equals the number of inputs `in.extent(0)`.
*/
template <typename DataT, typename IdxT, typename WeightsT>
void sample_without_replacement(const raft::handle_t& handle,
                                RngState& rng_state,
                                raft::device_vector_view<const DataT, IdxT> in,
                                std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts,
                                raft::device_vector_view<DataT, IdxT> out,
                                std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx)
{
  // mdspan front end for detail::sampleWithoutReplacement: validates the views,
  // unwraps them into raw pointers plus lengths, and forwards to the
  // implementation on the stream owned by `handle`.
  static_assert(std::is_integral_v<IdxT>, "IdxT must be an integral type.");

  const IdxT sampledLen = out.extent(0);  // number of samples requested
  const IdxT len        = in.extent(0);   // number of candidates to sample from

  // Sampling is without replacement, so no more than `len` samples can be drawn.
  RAFT_EXPECTS(sampledLen <= len,
               "sampleWithoutReplacement: "
               "sampledLen (out.extent(0)) must be <= len (in.extent(0))");
  // A null data handle is only acceptable for an empty view.
  RAFT_EXPECTS(len == 0 || in.data_handle() != nullptr,
               "sampleWithoutReplacement: "
               "If in.extent(0) != 0, then in.data_handle() must be nonnull");
  RAFT_EXPECTS(sampledLen == 0 || out.data_handle() != nullptr,
               "sampleWithoutReplacement: "
               "If out.extent(0) != 0, then out.data_handle() must be nonnull");

  // Optional output indices: when present there must be one slot per sample.
  // An absent optional is passed to the implementation as a null pointer.
  IdxT* outIdx_ptr = nullptr;
  if (outIdx.has_value()) {
    RAFT_EXPECTS((*outIdx).extent(0) == sampledLen,
                 "sampleWithoutReplacement: "
                 "If outIdx is provided, its extent(0) must equal out.extent(0)");
    outIdx_ptr = outIdx->data_handle();
  }

  // Optional weights: when present there must be one weight per input element.
  // An absent optional is passed to the implementation as a null pointer.
  const WeightsT* wts_ptr = nullptr;
  if (wts.has_value()) {
    RAFT_EXPECTS((*wts).extent(0) == len,
                 "sampleWithoutReplacement: "
                 "If wts is provided, its extent(0) must equal in.extent(0)");
    wts_ptr = wts->data_handle();
  }

  detail::sampleWithoutReplacement(rng_state,
                                   out.data_handle(),
                                   outIdx_ptr,
                                   in.data_handle(),
                                   wts_ptr,
                                   sampledLen,
                                   len,
                                   handle.get_stream());
}

namespace sample_without_replacement_impl {
// Maps the type of a weights argument to the element type of the weights.
//
// The primary template has no `type` member, so `weight_t<T>` is ill-formed
// for any argument type that is not covered by a specialization below.
template <typename T>
struct weight_alias {
};

// `std::nullopt` carries no element type of its own, so `double` is used to
// pick the weight type when no weights are given.
template <>
struct weight_alias<std::nullopt_t> {
  using type = double;
};

// A weights view passed directly, i.e. not wrapped in `std::optional`.
template <typename ElementType, typename IndexType>
struct weight_alias<raft::device_vector_view<ElementType, IndexType>> {
  using type = typename raft::device_vector_view<ElementType, IndexType>::value_type;
};

// A weights view wrapped in `std::optional`.
template <typename ElementType, typename IndexType>
struct weight_alias<std::optional<raft::device_vector_view<ElementType, IndexType>>> {
  using type = typename raft::device_vector_view<ElementType, IndexType>::value_type;
};

// Convenience alias for `weight_alias<T>::type`.
template <typename T>
using weight_t = typename weight_alias<T>::type;
}  // namespace sample_without_replacement_impl

/**
* @brief Overload of `sample_without_replacement` to help the
* compiler find the above overload, in case users pass in
* `std::nullopt` for one or both of the optional arguments.
*
* Please see above for documentation of `sample_without_replacement`.
*/
template <typename DataT, typename IdxT, typename WeightsVectorType, class OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
                                RngState& rng_state,
                                raft::device_vector_view<const DataT, IdxT> in,
                                WeightsVectorType&& wts,
                                raft::device_vector_view<DataT, IdxT> out,
                                OutIndexVectorType&& outIdx)
{
  // Strip reference and const from the deduced argument type so that the
  // weight_alias trait can recognize it.
  using wts_arg_t = std::remove_const_t<std::remove_reference_t<WeightsVectorType>>;
  using weight_elem_t = sample_without_replacement_impl::weight_t<wts_arg_t>;

  // Concrete optional-of-view types for the two optional arguments.
  using optional_wts_t     = std::optional<raft::device_vector_view<const weight_elem_t, IdxT>>;
  using optional_indices_t = std::optional<raft::device_vector_view<IdxT, IdxT>>;

  // Convert whatever the caller passed (e.g. std::nullopt) into those types.
  optional_wts_t wts_opt        = std::forward<WeightsVectorType>(wts);
  optional_indices_t outIdx_opt = std::forward<OutIndexVectorType>(outIdx);

  // Both optionals are now fully typed, so this call can deduce every
  // template parameter of the overload that takes std::optional arguments.
  sample_without_replacement(handle, rng_state, in, wts_opt, out, outIdx_opt);
}

/**
* @brief Legacy version of @c sample_without_replacement (see above)
* that takes raw arrays instead of device mdspan.
*
* @tparam DataT data type
* @tparam WeightsT weights type
* @tparam IdxT index type
*
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output sampled array (of length 'sampledLen')
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/random/rng_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#pragma once

#include <cstdint>

namespace raft {
namespace random {

Expand Down
4 changes: 2 additions & 2 deletions cpp/test/nvtx.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
*/
#ifdef NVTX_ENABLED
#include <gtest/gtest.h>
#include <raft/common/detail/nvtx.hpp>
#include <raft/core/detail/nvtx.hpp>
/**
* tests for the functionality of generating next color based on string
* entered in the NVTX Range marker wrappers
Expand Down
142 changes: 105 additions & 37 deletions cpp/test/random/sample_without_replacement.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SWoRTest : public ::testing::TestWithParam<SWoRInputs<T>> {
}
sampleWithoutReplacement(
handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len);
update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream);
update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
handle.sync_stream(stream);
}

Expand All @@ -84,7 +84,75 @@ class SWoRTest : public ::testing::TestWithParam<SWoRInputs<T>> {
std::vector<int> h_outIdx;
};

typedef SWoRTest<float> SWoRTestF;
// Fixture that exercises the mdspan-based `sample_without_replacement`
// overloads. SetUp() draws one weighted sample and copies the sampled indices
// to the host vector `h_outIdx`, which the test body then inspects.
template <typename T>
class SWoRMdspanTest : public ::testing::TestWithParam<SWoRInputs<T>> {
 public:
  // Members are initialized in declaration order regardless of how the
  // initializers are written, so they are listed in that order here:
  // `stream` needs `handle`, and the device buffers need `params` and `stream`.
  SWoRMdspanTest()
    : stream(handle.get_stream()),
      params(::testing::TestWithParam<SWoRInputs<T>>::GetParam()),
      in(params.len, stream),
      out(params.sampledLen, stream),
      wts(params.len, stream),
      out2(params.sampledLen, stream),
      outIdx(params.sampledLen, stream),
      outIdx2(params.sampledLen, stream)
  {
  }

 protected:
  void SetUp() override
  {
    RngState r(params.seed, params.gtype);
    h_outIdx.resize(params.sampledLen);

    // Random inputs in [-1, 1) and random weights in [1, 2).
    uniform(handle, r, in.data(), params.len, T(-1.0), T(1.0));
    uniform(handle, r, wts.data(), params.len, T(1.0), T(2.0));
    // Optionally overwrite a single weight with a much larger value.
    if (params.largeWeightIndex >= 0) {
      update_device(wts.data() + params.largeWeightIndex, &params.largeWeight, 1, stream);
    }
    {
      using index_type = int;

      using output_view = raft::device_vector_view<T, index_type>;
      output_view out_view{out.data(), out.size()};
      ASSERT_TRUE(out_view.extent(0) == params.sampledLen);

      using output_idxs_view = raft::device_vector_view<index_type, index_type>;
      std::optional<output_idxs_view> outIdx_view{std::in_place, outIdx.data(), outIdx.size()};
      ASSERT_TRUE(outIdx_view.value().extent(0) == params.sampledLen);

      using input_view = raft::device_vector_view<const T, index_type>;
      input_view in_view{in.data(), in.size()};
      ASSERT_TRUE(in_view.extent(0) == params.len);

      using weights_view = raft::device_vector_view<const T, index_type>;
      std::optional<weights_view> wts_view{std::in_place, wts.data(), wts.size()};
      ASSERT_TRUE(wts_view.value().extent(0) == params.len);

      // The call whose output indices are checked by the test body.
      sample_without_replacement(handle, r, in_view, wts_view, out_view, outIdx_view);

      output_view out2_view{out2.data(), out2.size()};
      ASSERT_TRUE(out2_view.extent(0) == params.sampledLen);
      std::optional<output_idxs_view> outIdx2_view{std::in_place, outIdx2.data(), outIdx2.size()};
      ASSERT_TRUE(outIdx2_view.value().extent(0) == params.sampledLen);

      // For now, just test that these calls compile.
      // Their results (out2 / outIdx2) are not inspected.
      sample_without_replacement(handle, r, in_view, wts_view, out2_view, std::nullopt);
      sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, outIdx2_view);
      sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, std::nullopt);
    }
    // Copy the indices from the first call to the host, then wait for the stream.
    update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
    handle.sync_stream(stream);
  }

 protected:
  raft::handle_t handle;
  cudaStream_t stream;  // stream obtained from `handle`

  SWoRInputs<T> params;                        // parameters of the current test case
  rmm::device_uvector<T> in, out, wts, out2;   // device buffers: inputs, samples, weights
  rmm::device_uvector<int> outIdx, outIdx2;    // device buffers: sampled indices
  std::vector<int> h_outIdx;                   // host copy of `outIdx`
};

const std::vector<SWoRInputs<float>> inputsf = {{1024, 512, -1, 0.f, GenPhilox, 1234ULL},
{1024, 1024, -1, 0.f, GenPhilox, 1234ULL},
{1024, 512 + 1, -1, 0.f, GenPhilox, 1234ULL},
Expand Down Expand Up @@ -125,26 +193,37 @@ const std::vector<SWoRInputs<float>> inputsf = {{1024, 512, -1, 0.f, GenPhilox,
{1024 + 2, 1024 + 2, -1, 0.f, GenPC, 1234ULL},
{1024, 512, 10, 100000.f, GenPC, 1234ULL}};

// Checks the sampled indices held in h_outIdx for the float fixture:
// every index is within [0, params.len), no index appears twice, and, when a
// large weight was configured, the first index is the large-weight index.
TEST_P(SWoRTestF, Result)
{
// indices seen so far, used to detect repeats
std::set<int> occurence;
for (int i = 0; i < params.sampledLen; ++i) {
auto val = h_outIdx[i];
// indices must be in the given range
ASSERT_TRUE(0 <= val && val < params.len)
<< "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen;
// indices should not repeat
ASSERT_TRUE(occurence.find(val) == occurence.end())
<< "repeated index @i=" << i << " idx=" << val;
occurence.insert(val);
}
// if there's a skewed distribution, the top index should correspond to the
// particular item with a large weight
if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }
}
// This needs to be a macro because it has to live in the scope
// of the class whose name is the first parameter of TEST_P.
//
// We test the following.
//
// 1. Output indices are in the given range.
// 2. Output indices do not repeat.
// 3. If there's a skewed distribution, the top index should
// correspond to the particular item with a large weight.
#define _RAFT_SWOR_TEST_CONTENTS() \
do { \
std::set<int> occurrence; \
for (int i = 0; i < params.sampledLen; ++i) { \
auto val = h_outIdx[i]; \
ASSERT_TRUE(0 <= val && val < params.len) \
<< "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen; \
ASSERT_TRUE(occurrence.find(val) == occurrence.end()) \
<< "repeated index @i=" << i << " idx=" << val; \
occurrence.insert(val); \
} \
if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); } \
} while (false)

// Float instantiation of the SWoRTest fixture; the Result test is run once
// for every entry of `inputsf`.
using SWoRTestF = SWoRTest<float>;
TEST_P(SWoRTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestF, ::testing::ValuesIn(inputsf));

typedef SWoRTest<double> SWoRTestD;
// Float instantiation of the SWoRMdspanTest fixture; the Result test is run
// once for every entry of `inputsf`.
using SWoRMdspanTestF = SWoRMdspanTest<float>;
TEST_P(SWoRMdspanTestF, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestF, ::testing::ValuesIn(inputsf));

const std::vector<SWoRInputs<double>> inputsd = {{1024, 512, -1, 0.0, GenPhilox, 1234ULL},
{1024, 1024, -1, 0.0, GenPhilox, 1234ULL},
{1024, 512 + 1, -1, 0.0, GenPhilox, 1234ULL},
Expand Down Expand Up @@ -185,24 +264,13 @@ const std::vector<SWoRInputs<double>> inputsd = {{1024, 512, -1, 0.0, GenPhilox,
{1024 + 2, 1024 + 2, -1, 0.0, GenPC, 1234ULL},
{1024, 512, 10, 100000.0, GenPC, 1234ULL}};

// Checks the sampled indices held in h_outIdx for the double fixture:
// every index is within [0, params.len), no index appears twice, and, when a
// large weight was configured, the first index is the large-weight index.
TEST_P(SWoRTestD, Result)
{
// indices seen so far, used to detect repeats
std::set<int> occurence;
for (int i = 0; i < params.sampledLen; ++i) {
auto val = h_outIdx[i];
// indices must be in the given range
ASSERT_TRUE(0 <= val && val < params.len)
<< "out-of-range index @i=" << i << " val=" << val << " sampledLen=" << params.sampledLen;
// indices should not repeat
ASSERT_TRUE(occurence.find(val) == occurence.end())
<< "repeated index @i=" << i << " idx=" << val;
occurence.insert(val);
}
// if there's a skewed distribution, the top index should correspond to the
// particular item with a large weight
if (params.largeWeightIndex >= 0) { ASSERT_EQ(h_outIdx[0], params.largeWeightIndex); }
}
// Double instantiation of the SWoRTest fixture; the Result test is run once
// for every entry of `inputsd`.
using SWoRTestD = SWoRTest<double>;
TEST_P(SWoRTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests, SWoRTestD, ::testing::ValuesIn(inputsd));

// Double instantiation of the SWoRMdspanTest fixture; the Result test is run
// once for every entry of `inputsd`.
using SWoRMdspanTestD = SWoRMdspanTest<double>;
TEST_P(SWoRMdspanTestD, Result) { _RAFT_SWOR_TEST_CONTENTS(); }
INSTANTIATE_TEST_SUITE_P(SWoRTests2, SWoRMdspanTestD, ::testing::ValuesIn(inputsd));

} // namespace random
} // namespace raft