From 6b8a6ce90c6a145e8cd568d4bd0ee14139fe9b09 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Tue, 20 Sep 2022 16:40:20 -0700 Subject: [PATCH] mdspan-ify permute Add overloads of permute that take arrays as mdspan (std::optional> if the array is an optional output argument). Improve documentation of the existing overload and document the new overloads. Add tests. Make sure that the function compiles if users pass in std::nullopt for either or both optional arguments. --- cpp/include/raft/random/permute.cuh | 176 ++++++++++++++++++++++++---- cpp/test/random/permute.cu | 141 +++++++++++++++++++--- 2 files changed, 279 insertions(+), 38 deletions(-) diff --git a/cpp/include/raft/random/permute.cuh b/cpp/include/raft/random/permute.cuh index 1c01d589f4..710baa915c 100644 --- a/cpp/include/raft/random/permute.cuh +++ b/cpp/include/raft/random/permute.cuh @@ -21,30 +21,162 @@ #include "detail/permute.cuh" +#include +#include +#include + namespace raft::random { /** - * @brief Generate permutations of the input array. Pretty useful primitive for - * shuffling the input datasets in ML algos. See note at the end for some of its - * limitations! - * @tparam Type Data type of the array to be shuffled - * @tparam IntType Integer type used for ther perms array - * @tparam IdxType Integer type used for addressing indices - * @tparam TPB threads per block - * @param perms the output permutation indices. Typically useful only when - * one wants to refer back. If you don't need this, pass a nullptr - * @param out the output shuffled array. Pass nullptr if you don't want this to - * be written. For eg: when you only want the perms array to be filled. - * @param in input array (in-place is not supported due to race conditions!) - * @param D number of columns of the input array - * @param N length of the input array (or number of rows) - * @param rowMajor whether the input/output matrices are row or col major - * @param stream cuda stream where to launch the work - * - * @note This is NOT a uniform permutation generator! In fact, it only generates - * very small percentage of permutations. If your application really requires a - * high quality permutation generator, it is recommended that you pick - * Knuth Shuffle. + * @brief Randomly permute the rows of the input matrix. + * + * We do not support in-place permutation, so that we can compute + * in parallel without race conditions. This function is useful + * for shuffling input data sets in machine learning algorithms. + * + * @tparam InputOutputValueType Type of each element of the input matrix, + * and the type of each element of the output matrix (if provided) + * @tparam IntType Integer type of each element of `permsOut` + * @tparam IdxType Integer type of the extents of the mdspan parameters + * @tparam Layout Either `raft::row_major` or `raft::col_major` + * + * @param[in] handle RAFT handle containing the CUDA stream + * on which to run. + * @param[out] permsOut If provided, the indices of the permutation. + * @param[out] out If provided, the output matrix, containing the + * permuted rows of the input matrix `in`. (Not providing this + * is only useful if you provide `permsOut`.) + * @param[in] in input matrix + * + * @pre If `permsOut.has_value()` is `true`, + * then `(*permsOut).extent(0) == in.extent(0)` is `true`. + * + * @pre If `out.has_value()` is `true`, + * then `(*out).extents() == in.extents()` is `true`. + * + * @note This is NOT a uniform permutation generator! + * It only generates a small fraction of all possible random permutations. + * If your application needs a high-quality permutation generator, + * then we recommend Knuth Shuffle. + */ +template +void permute(const raft::handle_t& handle, + raft::device_matrix_view in, + std::optional> permsOut, + std::optional> out) +{ + static_assert(std::is_integral_v, + "permute: The type of each element " + "of permsOut (if provided) must be an integral type."); + static_assert(std::is_integral_v, + "permute: The index type " + "of each mdspan argument must be an integral type."); + constexpr bool is_row_major = std::is_same_v; + constexpr bool is_col_major = std::is_same_v; + static_assert(is_row_major || is_col_major, + "permute: Layout must be either " + "raft::row_major or raft::col_major (or one of their aliases)"); + + const bool permsOut_has_value = permsOut.has_value(); + const bool out_has_value = out.has_value(); + + RAFT_EXPECTS(!permsOut_has_value || (*permsOut).extent(0) == in.extent(0), + "permute: If 'permsOut' is provided, then its extent(0) " + "must equal the number of rows of the input matrix 'in'."); + RAFT_EXPECTS(!out_has_value || (*out).extents() == in.extents(), + "permute: If 'out' is provided, then both its extents " + "must match the extents of the input matrix 'in'."); + + IntType* permsOut_ptr = permsOut_has_value ? (*permsOut).data_handle() : nullptr; + InputOutputValueType* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + + if (permsOut_ptr != nullptr || out_ptr != nullptr) { + const IdxType N = in.extent(0); + const IdxType D = in.extent(1); + detail::permute( + permsOut_ptr, out_ptr, in.data_handle(), D, N, is_row_major, handle.get_stream()); + } +} + +namespace permute_impl { + +template +struct perms_out_view { +}; + +template +struct perms_out_view { + // permsOut won't have a value anyway, + // so we can pick any integral value type we want. + using type = raft::device_vector_view; +}; + +template +struct perms_out_view>, + InputOutputValueType, + IdxType, + Layout> { + using type = raft::device_vector_view; +}; + +template +using perms_out_view_t = typename perms_out_view::type; + +} // namespace permute_impl + +/** + * @brief Overload of `permute` that compiles if users pass in `std::nullopt` + * for either or both of `permsOut` and `out`. + */ +template +void permute(const raft::handle_t& handle, + raft::device_matrix_view in, + PermsOutType&& permsOut, + OutType&& out) +{ + // If PermsOutType is std::optional> + // for some T, then that type T need not be related to any of the + // other template parameters. Thus, we have to deduce it specially. + using perms_out_view_type = permute_impl:: + perms_out_view_t, InputOutputValueType, IdxType, Layout>; + using out_view_type = raft::device_matrix_view; + + static_assert(std::is_same_v, std::nullopt_t> || + std::is_same_v, std::optional>, + "permute: The type of 'out' must be either std::optional<" + "raft::device_matrix_view>, " + "or std::nullopt."); + + std::optional permsOut_arg = std::forward(permsOut); + std::optional out_arg = std::forward(out); + permute(handle, in, permsOut_arg, out_arg); +} + +/** + * @brief Legacy overload of `permute` that takes raw arrays instead of mdspan. + * + * @tparam Type Type of each element of the input matrix to be permuted + * @tparam IntType Integer type of each element of the permsOut matrix + * @tparam IdxType Integer type of the dimensions of the matrices + * @tparam TPB threads per block (do not use any value other than the default) + * + * @param[out] perms If nonnull, the indices of the permutation + * @param[out] out If nonnull, the output matrix, containing the + * permuted rows of the input matrix @c in. (Not providing this + * is only useful if you provide @c perms.) + * @param[in] in input matrix + * @param[in] D number of columns in the matrices + * @param[in] N number of rows in the matrices + * @param[in] rowMajor true if the matrices are row major, + * false if they are column major + * @param[in] stream CUDA stream on which to run */ template void permute(IntType* perms, @@ -60,4 +192,4 @@ void permute(IntType* perms, }; // end namespace raft::random -#endif \ No newline at end of file +#endif diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index a0e9f2f25f..32e5540d51 100644 --- a/cpp/test/random/permute.cu +++ b/cpp/test/random/permute.cu @@ -40,6 +40,9 @@ template template class PermTest : public ::testing::TestWithParam> { + public: + using test_data_type = T; + protected: PermTest() : in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream()) @@ -81,6 +84,89 @@ class PermTest : public ::testing::TestWithParam> { int* outPerms_ptr = nullptr; }; +template +class PermMdspanTest : public ::testing::TestWithParam> { + public: + using test_data_type = T; + + protected: + PermMdspanTest() + : in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream()) + { + } + + private: + using index_type = int; + + template + using matrix_view_t = raft::device_matrix_view; + + template + using vector_view_t = raft::device_vector_view; + + protected: + void SetUp() override + { + auto stream = handle.get_stream(); + params = ::testing::TestWithParam>::GetParam(); + // forcefully set needPerms, since we need it for unit-testing! + if (params.needShuffle) { params.needPerms = true; } + raft::random::RngState r(params.seed); + int N = params.N; + int D = params.D; + int len = N * D; + if (params.needPerms) { + outPerms.resize(N, stream); + outPerms_ptr = outPerms.data(); + } + if (params.needShuffle) { + in.resize(len, stream); + out.resize(len, stream); + in_ptr = in.data(); + out_ptr = out.data(); + uniform(handle, r, in_ptr, len, T(-1.0), T(1.0)); + } + + auto set_up_views_and_test = [&](auto layout) { + using layout_type = std::decay_t; + + matrix_view_t in_view(in_ptr, N, D); + std::optional> out_view; + if (out_ptr != nullptr) { out_view.emplace(out_ptr, N, D); } + std::optional> outPerms_view; + if (outPerms_ptr != nullptr) { outPerms_view.emplace(outPerms_ptr, N); } + + permute(handle, in_view, outPerms_view, out_view); + + // None of these three permute calls should have an effect. + // The point is to test whether the function can deduce the + // element type of outPerms if given nullopt. + std::optional> out_view_empty; + std::optional> outPerms_view_empty; + permute(handle, in_view, std::nullopt, out_view_empty); + permute(handle, in_view, outPerms_view_empty, std::nullopt); + permute(handle, in_view, std::nullopt, std::nullopt); + }; + + if (params.rowMajor) { + set_up_views_and_test(raft::row_major{}); + } else { + set_up_views_and_test(raft::col_major{}); + } + + handle.sync_stream(); + } + + protected: + raft::handle_t handle; + PermInputs params; + rmm::device_uvector in, out; + T* in_ptr = nullptr; + T* out_ptr = nullptr; + rmm::device_uvector outPerms; + int* outPerms_ptr = nullptr; +}; + template ::testing::AssertionResult devArrMatchRange( const T* actual, size_t size, T start, L eq_compare, bool doSort = true, cudaStream_t stream = 0) @@ -169,19 +255,38 @@ const std::vector> inputsf = { {100000, 32, true, true, false, 1234567890ULL}, {100001, 33, true, true, false, 1234567890ULL}}; -typedef PermTest PermTestF; +#define _PERMTEST_BODY(DATA_TYPE) \ + do { \ + if (params.needPerms) { \ + ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare())); \ + } \ + if (params.needShuffle) { \ + ASSERT_TRUE(devArrMatchShuffle(outPerms_ptr, \ + out_ptr, \ + in_ptr, \ + params.D, \ + params.N, \ + params.rowMajor, \ + raft::Compare())); \ + } \ + } while (false) + +using PermTestF = PermTest; TEST_P(PermTestF, Result) { - if (params.needPerms) { - ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare())); - } - if (params.needShuffle) { - ASSERT_TRUE(devArrMatchShuffle( - outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare())); - } + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); } INSTANTIATE_TEST_CASE_P(PermTests, PermTestF, ::testing::ValuesIn(inputsf)); +using PermMdspanTestF = PermMdspanTest; +TEST_P(PermMdspanTestF, Result) +{ + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); +} +INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = { // only generate permutations {32, 8, true, false, true, 1234ULL}, @@ -219,18 +324,22 @@ const std::vector> inputsd = { {100000, 32, true, true, false, 1234ULL}, {100000, 32, true, true, false, 1234567890ULL}, {100001, 33, true, true, false, 1234567890ULL}}; -typedef PermTest PermTestD; + +using PermTestD = PermTest; TEST_P(PermTestD, Result) { - if (params.needPerms) { - ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare())); - } - if (params.needShuffle) { - ASSERT_TRUE(devArrMatchShuffle( - outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare())); - } + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); } INSTANTIATE_TEST_CASE_P(PermTests, PermTestD, ::testing::ValuesIn(inputsd)); +using PermMdspanTestD = PermMdspanTest; +TEST_P(PermMdspanTestD, Result) +{ + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); +} +INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestD, ::testing::ValuesIn(inputsd)); + } // end namespace random } // end namespace raft