mdspan-ify permute
Add overloads of permute that take arrays as mdspan
(std::optional<mdspan<...>> if the array is an optional
output argument).

Improve documentation of the existing overload
and document the new overloads.  Add tests.
Make sure that the function compiles if users pass in
std::nullopt for either or both optional arguments.
mhoemmen committed Sep 22, 2022
1 parent 02017b6 commit 6b8a6ce
Showing 2 changed files with 279 additions and 38 deletions.
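Before the diff itself, a minimal usage sketch of the new mdspan-based overload described in the commit message. This sketch is not part of the commit; it is modeled on the new test code below, and the problem size, the float/int element types, and the row-major layout are assumptions chosen for illustration.

// Assumes the usual includes, e.g. <raft/random/permute.cuh> and the RMM headers.
raft::handle_t handle;
auto stream = handle.get_stream();
constexpr int N = 100;  // number of rows (assumed)
constexpr int D = 8;    // number of columns (assumed)

// Device buffers backing the views; `in` would be filled elsewhere
// (the tests use raft::random::uniform for this).
rmm::device_uvector<float> in(N * D, stream);
rmm::device_uvector<float> out(N * D, stream);
rmm::device_uvector<int> perms(N, stream);

raft::device_matrix_view<const float, int, raft::row_major> in_view(in.data(), N, D);
std::optional<raft::device_matrix_view<float, int, raft::row_major>> out_view;
out_view.emplace(out.data(), N, D);
std::optional<raft::device_vector_view<int, int>> perms_view;
perms_view.emplace(perms.data(), N);

// Permute the rows of `in` into `out` and record the permutation indices.
raft::random::permute(handle, in_view, perms_view, out_view);

// Either (or both) optional output arguments may be std::nullopt;
// the second new overload exists precisely so that this compiles.
raft::random::permute(handle, in_view, std::nullopt, out_view);

handle.sync_stream();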
176 changes: 154 additions & 22 deletions cpp/include/raft/random/permute.cuh
@@ -21,30 +21,162 @@

#include "detail/permute.cuh"

#include <optional>
#include <raft/mdarray.hpp>
#include <type_traits>

namespace raft::random {

/**
* @brief Generate permutations of the input array. Pretty useful primitive for
* shuffling the input datasets in ML algos. See note at the end for some of its
* limitations!
* @tparam Type Data type of the array to be shuffled
* @tparam IntType Integer type used for the perms array
* @tparam IdxType Integer type used for addressing indices
* @tparam TPB threads per block
* @param perms the output permutation indices. Typically useful only when
* one wants to refer back. If you don't need this, pass a nullptr
* @param out the output shuffled array. Pass nullptr if you don't want this to
* be written, e.g. when you only want the perms array to be filled.
* @param in input array (in-place is not supported due to race conditions!)
* @param D number of columns of the input array
* @param N length of the input array (or number of rows)
* @param rowMajor whether the input/output matrices are row or col major
* @param stream CUDA stream on which to launch the work
*
* @note This is NOT a uniform permutation generator! In fact, it only generates
* very small percentage of permutations. If your application really requires a
* high quality permutation generator, it is recommended that you pick
* Knuth Shuffle.
* @brief Randomly permute the rows of the input matrix.
*
* We do not support in-place permutation, so that we can compute
* in parallel without race conditions. This function is useful
* for shuffling input data sets in machine learning algorithms.
*
* @tparam InputOutputValueType Type of each element of the input matrix,
* and the type of each element of the output matrix (if provided)
* @tparam IntType Integer type of each element of `permsOut`
* @tparam IdxType Integer type of the extents of the mdspan parameters
* @tparam Layout Either `raft::row_major` or `raft::col_major`
*
* @param[in] handle RAFT handle containing the CUDA stream
* on which to run.
* @param[out] permsOut If provided, the indices of the permutation.
* @param[out] out If provided, the output matrix, containing the
* permuted rows of the input matrix `in`. (Not providing this
* is only useful if you provide `permsOut`.)
* @param[in] in input matrix
*
* @pre If `permsOut.has_value()` is `true`,
* then `(*permsOut).extent(0) == in.extent(0)` is `true`.
*
* @pre If `out.has_value()` is `true`,
* then `(*out).extents() == in.extents()` is `true`.
*
* @note This is NOT a uniform permutation generator!
* It only generates a small fraction of all possible random permutations.
* If your application needs a high-quality permutation generator,
* then we recommend Knuth Shuffle.
*/
template <typename InputOutputValueType, typename IntType, typename IdxType, typename Layout>
void permute(const raft::handle_t& handle,
raft::device_matrix_view<const InputOutputValueType, IdxType, Layout> in,
std::optional<raft::device_vector_view<IntType, IdxType>> permsOut,
std::optional<raft::device_matrix_view<InputOutputValueType, IdxType, Layout>> out)
{
static_assert(std::is_integral_v<IntType>,
"permute: The type of each element "
"of permsOut (if provided) must be an integral type.");
static_assert(std::is_integral_v<IdxType>,
"permute: The index type "
"of each mdspan argument must be an integral type.");
constexpr bool is_row_major = std::is_same_v<Layout, raft::row_major>;
constexpr bool is_col_major = std::is_same_v<Layout, raft::col_major>;
static_assert(is_row_major || is_col_major,
"permute: Layout must be either "
"raft::row_major or raft::col_major (or one of their aliases)");

const bool permsOut_has_value = permsOut.has_value();
const bool out_has_value = out.has_value();

RAFT_EXPECTS(!permsOut_has_value || (*permsOut).extent(0) == in.extent(0),
"permute: If 'permsOut' is provided, then its extent(0) "
"must equal the number of rows of the input matrix 'in'.");
RAFT_EXPECTS(!out_has_value || (*out).extents() == in.extents(),
"permute: If 'out' is provided, then both its extents "
"must match the extents of the input matrix 'in'.");

IntType* permsOut_ptr = permsOut_has_value ? (*permsOut).data_handle() : nullptr;
InputOutputValueType* out_ptr = out_has_value ? (*out).data_handle() : nullptr;

if (permsOut_ptr != nullptr || out_ptr != nullptr) {
const IdxType N = in.extent(0);
const IdxType D = in.extent(1);
detail::permute<InputOutputValueType, IntType, IdxType>(
permsOut_ptr, out_ptr, in.data_handle(), D, N, is_row_major, handle.get_stream());
}
}

namespace permute_impl {

template <typename T, typename InputOutputValueType, typename IdxType, typename Layout>
struct perms_out_view {
};

template <typename InputOutputValueType, typename IdxType, typename Layout>
struct perms_out_view<std::nullopt_t, InputOutputValueType, IdxType, Layout> {
// permsOut won't have a value anyway,
// so we can pick any integral value type we want.
using type = raft::device_vector_view<IdxType, IdxType>;
};

template <typename PermutationIndexType,
typename InputOutputValueType,
typename IdxType,
typename Layout>
struct perms_out_view<std::optional<raft::device_vector_view<PermutationIndexType, IdxType>>,
InputOutputValueType,
IdxType,
Layout> {
using type = raft::device_vector_view<PermutationIndexType, IdxType>;
};

template <typename T, typename InputOutputValueType, typename IdxType, typename Layout>
using perms_out_view_t = typename perms_out_view<T, InputOutputValueType, IdxType, Layout>::type;

} // namespace permute_impl

/**
* @brief Overload of `permute` that compiles if users pass in `std::nullopt`
* for either or both of `permsOut` and `out`.
*/
template <typename InputOutputValueType,
typename IdxType,
typename Layout,
typename PermsOutType,
typename OutType>
void permute(const raft::handle_t& handle,
raft::device_matrix_view<const InputOutputValueType, IdxType, Layout> in,
PermsOutType&& permsOut,
OutType&& out)
{
// If PermsOutType is std::optional<device_vector_view<T, IdxType>>
// for some T, then that type T need not be related to any of the
// other template parameters. Thus, we have to deduce it specially.
using perms_out_view_type = permute_impl::
perms_out_view_t<std::decay_t<PermsOutType>, InputOutputValueType, IdxType, Layout>;
using out_view_type = raft::device_matrix_view<InputOutputValueType, IdxType, Layout>;

static_assert(std::is_same_v<std::decay_t<OutType>, std::nullopt_t> ||
std::is_same_v<std::decay_t<OutType>, std::optional<out_view_type>>,
"permute: The type of 'out' must be either std::optional<"
"raft::device_matrix_view<InputOutputViewType, IdxType, Layout>>, "
"or std::nullopt.");

std::optional<perms_out_view_type> permsOut_arg = std::forward<PermsOutType>(permsOut);
std::optional<out_view_type> out_arg = std::forward<OutType>(out);
permute(handle, in, permsOut_arg, out_arg);
}

/**
* @brief Legacy overload of `permute` that takes raw arrays instead of mdspan.
*
* @tparam Type Type of each element of the input matrix to be permuted
* @tparam IntType Integer type of each element of the permsOut matrix
* @tparam IdxType Integer type of the dimensions of the matrices
* @tparam TPB threads per block (do not use any value other than the default)
*
* @param[out] perms If nonnull, the indices of the permutation
* @param[out] out If nonnull, the output matrix, containing the
* permuted rows of the input matrix @c in. (Not providing this
* is only useful if you provide @c perms.)
* @param[in] in input matrix
* @param[in] D number of columns in the matrices
* @param[in] N number of rows in the matrices
* @param[in] rowMajor true if the matrices are row major,
* false if they are column major
* @param[in] stream CUDA stream on which to run
*/
template <typename Type, typename IntType = int, typename IdxType = int, int TPB = 256>
void permute(IntType* perms,
@@ -60,4 +192,4 @@ void permute(IntType* perms,

}; // end namespace raft::random

#endif
#endif
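As an aside (not part of the diff), the following sketch illustrates how the permute_impl::perms_out_view_t helper above resolves the element type of permsOut. The std::int64_t permutation-index type is an assumption chosen for illustration; <cstdint> and <type_traits> are assumed to be included.

using raft::random::permute_impl::perms_out_view_t;

// std::nullopt_t carries no element type, so the trait falls back to the
// mdspan index type (here int) as the permutation element type.
static_assert(std::is_same_v<
  perms_out_view_t<std::nullopt_t, float, int, raft::row_major>,
  raft::device_vector_view<int, int>>);

// A real std::optional<device_vector_view<...>> argument keeps the caller's
// permutation-index element type (here std::int64_t).
static_assert(std::is_same_v<
  perms_out_view_t<std::optional<raft::device_vector_view<std::int64_t, int>>,
                   float, int, raft::row_major>,
  raft::device_vector_view<std::int64_t, int>>);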
141 changes: 125 additions & 16 deletions cpp/test/random/permute.cu
@@ -40,6 +40,9 @@ template <typename T>

template <typename T>
class PermTest : public ::testing::TestWithParam<PermInputs<T>> {
public:
using test_data_type = T;

protected:
PermTest()
: in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream())
@@ -81,6 +84,89 @@ class PermTest : public ::testing::TestWithParam<PermInputs<T>> {
int* outPerms_ptr = nullptr;
};

template <typename T>
class PermMdspanTest : public ::testing::TestWithParam<PermInputs<T>> {
public:
using test_data_type = T;

protected:
PermMdspanTest()
: in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream())
{
}

private:
using index_type = int;

template <class ElementType, class Layout>
using matrix_view_t = raft::device_matrix_view<ElementType, index_type, Layout>;

template <class ElementType>
using vector_view_t = raft::device_vector_view<ElementType, index_type>;

protected:
void SetUp() override
{
auto stream = handle.get_stream();
params = ::testing::TestWithParam<PermInputs<T>>::GetParam();
// forcefully set needPerms, since we need it for unit-testing!
if (params.needShuffle) { params.needPerms = true; }
raft::random::RngState r(params.seed);
int N = params.N;
int D = params.D;
int len = N * D;
if (params.needPerms) {
outPerms.resize(N, stream);
outPerms_ptr = outPerms.data();
}
if (params.needShuffle) {
in.resize(len, stream);
out.resize(len, stream);
in_ptr = in.data();
out_ptr = out.data();
uniform(handle, r, in_ptr, len, T(-1.0), T(1.0));
}

auto set_up_views_and_test = [&](auto layout) {
using layout_type = std::decay_t<decltype(layout)>;

matrix_view_t<const T, layout_type> in_view(in_ptr, N, D);
std::optional<matrix_view_t<T, layout_type>> out_view;
if (out_ptr != nullptr) { out_view.emplace(out_ptr, N, D); }
std::optional<vector_view_t<index_type>> outPerms_view;
if (outPerms_ptr != nullptr) { outPerms_view.emplace(outPerms_ptr, N); }

permute(handle, in_view, outPerms_view, out_view);

// None of these three permute calls should have an effect.
// The point is to test whether the function can deduce the
// element type of outPerms if given nullopt.
std::optional<matrix_view_t<T, layout_type>> out_view_empty;
std::optional<vector_view_t<index_type>> outPerms_view_empty;
permute(handle, in_view, std::nullopt, out_view_empty);
permute(handle, in_view, outPerms_view_empty, std::nullopt);
permute(handle, in_view, std::nullopt, std::nullopt);
};

if (params.rowMajor) {
set_up_views_and_test(raft::row_major{});
} else {
set_up_views_and_test(raft::col_major{});
}

handle.sync_stream();
}

protected:
raft::handle_t handle;
PermInputs<T> params;
rmm::device_uvector<T> in, out;
T* in_ptr = nullptr;
T* out_ptr = nullptr;
rmm::device_uvector<int> outPerms;
int* outPerms_ptr = nullptr;
};

template <typename T, typename L>
::testing::AssertionResult devArrMatchRange(
const T* actual, size_t size, T start, L eq_compare, bool doSort = true, cudaStream_t stream = 0)
@@ -169,19 +255,38 @@ const std::vector<PermInputs<float>> inputsf = {
{100000, 32, true, true, false, 1234567890ULL},
{100001, 33, true, true, false, 1234567890ULL}};

typedef PermTest<float> PermTestF;
#define _PERMTEST_BODY(DATA_TYPE) \
do { \
if (params.needPerms) { \
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>())); \
} \
if (params.needShuffle) { \
ASSERT_TRUE(devArrMatchShuffle(outPerms_ptr, \
out_ptr, \
in_ptr, \
params.D, \
params.N, \
params.rowMajor, \
raft::Compare<DATA_TYPE>())); \
} \
} while (false)

using PermTestF = PermTest<float>;
TEST_P(PermTestF, Result)
{
if (params.needPerms) {
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>()));
}
if (params.needShuffle) {
ASSERT_TRUE(devArrMatchShuffle(
outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare<float>()));
}
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermTests, PermTestF, ::testing::ValuesIn(inputsf));

using PermMdspanTestF = PermMdspanTest<float>;
TEST_P(PermMdspanTestF, Result)
{
using test_data_type = PermMdspanTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestF, ::testing::ValuesIn(inputsf));

const std::vector<PermInputs<double>> inputsd = {
// only generate permutations
{32, 8, true, false, true, 1234ULL},
@@ -219,18 +324,22 @@ const std::vector<PermInputs<double>> inputsd = {
{100000, 32, true, true, false, 1234ULL},
{100000, 32, true, true, false, 1234567890ULL},
{100001, 33, true, true, false, 1234567890ULL}};
typedef PermTest<double> PermTestD;

using PermTestD = PermTest<double>;
TEST_P(PermTestD, Result)
{
if (params.needPerms) {
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>()));
}
if (params.needShuffle) {
ASSERT_TRUE(devArrMatchShuffle(
outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare<double>()));
}
using test_data_type = PermTestD::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermTests, PermTestD, ::testing::ValuesIn(inputsd));

using PermMdspanTestD = PermMdspanTest<double>;
TEST_P(PermMdspanTestD, Result)
{
using test_data_type = PermMdspanTestD::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestD, ::testing::ValuesIn(inputsd));

} // end namespace random
} // end namespace raft
