Skip to content

Commit

Permalink
Address review feedback
Browse files Browse the repository at this point in the history
Rename new sampleWithoutReplacement overload to
sample_without_replacement.

Fix order of parameters: handle, rng, in, out, params.

Use RAFT_EXPECTS instead of ASSERT.

Make sure that users can pass in std::nullopt explicitly
for either or both std::optional parameters.

Use "vector" instead of "array" in documentation,
use [in], [out], and [inout] as appropriate,
and make other documentation improvements.

Make error messages more clear.
  • Loading branch information
mhoemmen committed Sep 22, 2022
1 parent c391f9e commit 5a0f012
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 89 deletions.
211 changes: 129 additions & 82 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -344,108 +344,78 @@ void laplace(const raft::handle_t& handle,
}

/**
* @brief Sample the input array without replacement, optionally based on the
* input weight vector for each element in the array
* @brief Sample the input vector without replacement, optionally based on the
* input weight vector for each element in the vector.
*
* Implementation here is based on the `one-pass sampling` algo described here:
* https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
* The implementation is based on the `one-pass sampling` algorithm described in
* ["Accelerating weighted random sampling without
* replacement,"](https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf)
* a technical report by Kirill Mueller.
*
* @note In the sampled array the elements which are picked will always appear
* in the increasing order of their weights as computed using the exponential
* distribution. So, if you're particular about the order (for eg. array
* permutations), then this might not be the right choice!
* If no input weight vector is provided, then input elements will be
* sampled uniformly. Otherwise, the elements sampled from the input
* vector will always appear in increasing order of their weights as
* computed using the exponential distribution. So, if you are
* particular about the order (e.g., array permutations), then
* this might not be the right choice.
*
* @tparam DataT data type
* @tparam WeightsT weights type
* @tparam IdxT index type
* @param[in] handle raft handle for resource management
* @param[in] rng_state random number generator state
* @param[out] out output sampled array (of length 'sampledLen')
* @param[out] outIdx indices of the sampled array (of length 'sampledLen'). Pass
* a nullptr if this is not required.
* @param[in] in input array to be sampled (of length 'len')
* @param[in] wts weights array (of length 'len'). Pass a nullptr if uniform
* sampling is desired
* @param[in] sampledLen output sampled array length
* @param[in] len input array length
*/
template <typename DataT, typename WeightsT, typename IdxT = int>
void sampleWithoutReplacement(const raft::handle_t& handle,
                              RngState& rng_state,
                              DataT* out,
                              IdxT* outIdx,
                              const DataT* in,
                              const WeightsT* wts,
                              IdxT sampledLen,
                              IdxT len)
{
  // Thin wrapper: fetch the stream from the handle, then hand every
  // argument unchanged to the detail implementation.
  const auto stream = handle.get_stream();
  detail::sampleWithoutReplacement(rng_state, out, outIdx, in, wts, sampledLen, len, stream);
}

/**
* @brief Sample the input array without replacement, optionally based on the
* input weight vector for each element in the array
* @tparam DataT type of each element of the input array @c in
* @tparam IdxT type of the dimensions of the arrays; output index type
* @tparam WeightsT type of each element of the weights array @c wts
*
* Implementation here is based on the `one-pass sampling` algo described here:
* https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
* @note Please do not specify template parameters explicitly,
* as the compiler can deduce them from the arguments.
*
* @note In the sampled array the elements which are picked will always appear
* in the increasing order of their weights as computed using the exponential
* distribution. So, if you're particular about the order (for eg. array
* permutations), then this might not be the right choice!
* @param[in] handle RAFT handle containing (among other resources)
* the CUDA stream on which to run.
* @param[inout] rng_state Pseudorandom number generator state.
* @param[in] in Input vector to be sampled.
* @param[in] wts Optional weights vector.
* If not provided, uniform sampling will be used.
* @param[out] out Vector of samples from the input vector.
* @param[out] outIdx If provided, vector of the indices
* sampled from the input vector.
*
* @tparam DataT data type (do not specify this explicitly,
* as the compiler will deduce it from @c out )
* @tparam IdxT index type (do not specify this explicitly,
* as the compiler will deduce it from the arguments)
* @tparam WeightsT weights type (defaults to @c double ,
* so that passing in std::nullopt works)
* @pre The number of samples `out.extent(0)`
* is less than or equal to the number of inputs `in.extent(0)`.
*
* @param handle
* @param rng_state random number generator state
* @param out Array (of length 'sampledLen')
* of samples from the input array
* @param outIdx If provided, array (of length 'sampledLen')
* of the indices sampled from the input array
* @param in Input array to be sampled (of length 'len')
* @param wts Weights array (of length 'len').
* If not provided, uniform sampling will be used.
* @pre The number of weights `wts.extent(0)`
* equals the number of inputs `in.extent(0)`.
*/
template <typename DataT, typename IdxT, typename WeightsT = double>
void sampleWithoutReplacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<DataT, IdxT> out,
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx,
raft::device_vector_view<const DataT, IdxT> in,
std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts)
template <typename DataT, typename IdxT, typename WeightsT>
void sample_without_replacement(const raft::handle_t& handle,
RngState& rng_state,
raft::device_vector_view<const DataT, IdxT> in,
std::optional<raft::device_vector_view<const WeightsT, IdxT>> wts,
raft::device_vector_view<DataT, IdxT> out,
std::optional<raft::device_vector_view<IdxT, IdxT>> outIdx)
{
static_assert(std::is_integral<IdxT>::value, "IdxT must be an integral type.");
const IdxT sampledLen = out.extent(0);
const IdxT len = in.extent(0);
ASSERT(sampledLen <= len,
"sampleWithoutReplacement: "
"sampledLen (out.extent(0)) must be <= len (in.extent(0))");
ASSERT(len == 0 || in.data_handle() != nullptr,
"sampleWithoutReplacement: "
"If in.data_handle() is not null, then in.extent(0) must be nonzero");
ASSERT(sampledLen == 0 || out.data_handle() != nullptr,
"sampleWithoutReplacement: "
"If out.data_handle() is not null, then out.extent(0) must be nonzero");
RAFT_EXPECTS(sampledLen <= len,
"sampleWithoutReplacement: "
"sampledLen (out.extent(0)) must be <= len (in.extent(0))");
RAFT_EXPECTS(len == 0 || in.data_handle() != nullptr,
"sampleWithoutReplacement: "
"If in.extents(0) != 0, then in.data_handle() must be nonnull");
RAFT_EXPECTS(sampledLen == 0 || out.data_handle() != nullptr,
"sampleWithoutReplacement: "
"If out.extents(0) != 0, then out.data_handle() must be nonnull");

const bool outIdx_has_value = outIdx.has_value();
if (outIdx_has_value) {
ASSERT((*outIdx).extent(0) == sampledLen,
"sampleWithoutReplacement: "
"If outIdx is provided, its extent(0) must equal out.extent(0)");
RAFT_EXPECTS((*outIdx).extent(0) == sampledLen,
"sampleWithoutReplacement: "
"If outIdx is provided, its extent(0) must equal out.extent(0)");
}
IdxT* outIdx_ptr = outIdx_has_value ? (*outIdx).data_handle() : nullptr;

const bool wts_has_value = wts.has_value();
if (wts_has_value) {
ASSERT((*wts).extent(0) == len,
"sampleWithoutReplacement: "
"If wts is provided, its extent(0) must equal in.extent(0)");
RAFT_EXPECTS((*wts).extent(0) == len,
"sampleWithoutReplacement: "
"If wts is provided, its extent(0) must equal in.extent(0)");
}
const WeightsT* wts_ptr = wts_has_value ? (*wts).data_handle() : nullptr;

Expand All @@ -459,6 +429,83 @@ void sampleWithoutReplacement(const raft::handle_t& handle,
handle.get_stream());
}

namespace sample_without_replacement_impl {
/**
 * @brief Maps the type of a "weights" argument to the weights' element type.
 *
 * The primary template has no `type` member, so `weight_t<ArgT>` is
 * ill-formed for any `ArgT` that is not covered by a specialization below.
 *
 * @tparam ArgT type of the weights argument (without cv/ref qualifiers)
 */
template <class ArgT>
struct weight_alias {
};

/**
 * @brief Specialization for `std::nullopt_t`: there is no view to take
 * an element type from, so `double` is used.
 */
template <>
struct weight_alias<std::nullopt_t> {
  using type = double;
};

/**
 * @brief Specialization for an optional device vector view: the weight
 * type is the `value_type` of the wrapped view.
 */
template <class ElemT, class IndexT>
struct weight_alias<std::optional<raft::device_vector_view<ElemT, IndexT>>> {
  using type = typename raft::device_vector_view<ElemT, IndexT>::value_type;
};

/** @brief Shorthand for `weight_alias<ArgT>::type`. */
template <class ArgT>
using weight_t = typename weight_alias<ArgT>::type;
}  // namespace sample_without_replacement_impl

/**
 * @brief Convenience overload of `sample_without_replacement` that accepts
 * `std::nullopt` for `wts`, for `outIdx`, or for both.
 *
 * Both optional arguments are taken as forwarding references, converted to
 * the concrete `std::optional<raft::device_vector_view<...>>` types, and
 * then passed on to another `sample_without_replacement` call.
 * The element type of the weights view is obtained from
 * `sample_without_replacement_impl::weight_t`.
 *
 * Please see above for documentation of `sample_without_replacement`.
 */
template <typename DataT, typename IdxT, typename WeightsVectorType, typename OutIndexVectorType>
void sample_without_replacement(const raft::handle_t& handle,
                                RngState& rng_state,
                                raft::device_vector_view<const DataT, IdxT> in,
                                WeightsVectorType&& wts,
                                raft::device_vector_view<DataT, IdxT> out,
                                OutIndexVectorType&& outIdx)
{
  // Strip reference and const from the deduced argument type before
  // looking up the weights' element type.
  using wts_arg_type = std::remove_const_t<std::remove_reference_t<WeightsVectorType>>;
  using weight_type  = sample_without_replacement_impl::weight_t<wts_arg_type>;

  using weights_view_type = raft::device_vector_view<const weight_type, IdxT>;
  using indices_view_type = raft::device_vector_view<IdxT, IdxT>;

  // Copy-initialization: std::nullopt becomes an empty optional, while an
  // existing optional/view is converted to the target optional type.
  std::optional<weights_view_type> wts_opt    = std::forward<WeightsVectorType>(wts);
  std::optional<indices_view_type> outIdx_opt = std::forward<OutIndexVectorType>(outIdx);

  sample_without_replacement(handle, rng_state, in, wts_opt, out, outIdx_opt);
}

/**
 * @brief Legacy raw-pointer version of @c sample_without_replacement
 * (see above); it takes plain arrays and explicit lengths instead of
 * device mdspan views.
 *
 * @tparam DataT    data type
 * @tparam WeightsT weights type
 * @tparam IdxT     index type
 *
 * @param[in]    handle     raft handle for resource management; its stream
 *                          is passed to the implementation
 * @param[inout] rng_state  random number generator state
 * @param[out]   out        output sampled array (of length 'sampledLen')
 * @param[out]   outIdx     indices of the sampled array (of length
 *                          'sampledLen'). Pass a nullptr if this is not
 *                          required.
 * @param[in]    in         input array to be sampled (of length 'len')
 * @param[in]    wts        weights array (of length 'len'). Pass a nullptr
 *                          if uniform sampling is desired
 * @param[in]    sampledLen output sampled array length
 * @param[in]    len        input array length
 */
template <typename DataT, typename WeightsT, typename IdxT = int>
void sampleWithoutReplacement(const raft::handle_t& handle,
                              RngState& rng_state,
                              DataT* out,
                              IdxT* outIdx,
                              const DataT* in,
                              const WeightsT* wts,
                              IdxT sampledLen,
                              IdxT len)
{
  // Pure forwarding wrapper: only the stream is extracted from the handle.
  const auto stream = handle.get_stream();
  detail::sampleWithoutReplacement(rng_state, out, outIdx, in, wts, sampledLen, len, stream);
}

/**
* @brief Generates the 'a' and 'b' parameters for a modulo affine
* transformation equation: `(ax + b) % n`
Expand Down
26 changes: 19 additions & 7 deletions cpp/test/random/sample_without_replacement.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SWoRTest : public ::testing::TestWithParam<SWoRInputs<T>> {
}
sampleWithoutReplacement(
handle, r, out.data(), outIdx.data(), in.data(), wts.data(), params.sampledLen, params.len);
update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream);
update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
handle.sync_stream(stream);
}

Expand All @@ -93,7 +93,9 @@ class SWoRMdspanTest : public ::testing::TestWithParam<SWoRInputs<T>> {
in(params.len, stream),
wts(params.len, stream),
out(params.sampledLen, stream),
outIdx(params.sampledLen, stream)
out2(params.sampledLen, stream),
outIdx(params.sampledLen, stream),
outIdx2(params.sampledLen, stream)
{
}

Expand All @@ -115,17 +117,27 @@ class SWoRMdspanTest : public ::testing::TestWithParam<SWoRInputs<T>> {

using output_idxs_view = raft::device_vector_view<index_type, index_type>;
std::optional<output_idxs_view> outIdx_view{std::in_place, outIdx.data(), outIdx.size()};
ASSERT_TRUE(outIdx_view.has_value() && outIdx_view.value().extent(0) == params.sampledLen);
ASSERT_TRUE(outIdx_view.value().extent(0) == params.sampledLen);

using input_view = raft::device_vector_view<const T, index_type>;
input_view in_view{in.data(), in.size()};
ASSERT_TRUE(in_view.extent(0) == params.len);

using weights_view = raft::device_vector_view<const T, index_type>;
std::optional<weights_view> wts_view{std::in_place, wts.data(), wts.size()};
ASSERT_TRUE(wts_view.has_value() && wts_view.value().extent(0) == params.len);
ASSERT_TRUE(wts_view.value().extent(0) == params.len);

sample_without_replacement(handle, r, in_view, wts_view, out_view, outIdx_view);

sampleWithoutReplacement(handle, r, out_view, outIdx_view, in_view, wts_view);
output_view out2_view{out2.data(), out2.size()};
ASSERT_TRUE(out2_view.extent(0) == params.sampledLen);
std::optional<output_idxs_view> outIdx2_view{std::in_place, outIdx2.data(), outIdx2.size()};
ASSERT_TRUE(outIdx2_view.value().extent(0) == params.sampledLen);

// For now, just test that these calls compile.
sample_without_replacement(handle, r, in_view, wts_view, out2_view, std::nullopt);
sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, outIdx2_view);
sample_without_replacement(handle, r, in_view, std::nullopt, out2_view, std::nullopt);
}
update_host(h_outIdx.data(), outIdx.data(), params.sampledLen, stream);
handle.sync_stream(stream);
Expand All @@ -136,8 +148,8 @@ class SWoRMdspanTest : public ::testing::TestWithParam<SWoRInputs<T>> {
cudaStream_t stream;

SWoRInputs<T> params;
rmm::device_uvector<T> in, out, wts;
rmm::device_uvector<int> outIdx;
rmm::device_uvector<T> in, out, wts, out2;
rmm::device_uvector<int> outIdx, outIdx2;
std::vector<int> h_outIdx;
};

Expand Down

0 comments on commit 5a0f012

Please sign in to comment.