diff --git a/cpp/include/raft/random/permute.cuh b/cpp/include/raft/random/permute.cuh index 1c01d589f4..710baa915c 100644 --- a/cpp/include/raft/random/permute.cuh +++ b/cpp/include/raft/random/permute.cuh @@ -21,30 +21,162 @@ #include "detail/permute.cuh" +#include +#include +#include + namespace raft::random { /** - * @brief Generate permutations of the input array. Pretty useful primitive for - * shuffling the input datasets in ML algos. See note at the end for some of its - * limitations! - * @tparam Type Data type of the array to be shuffled - * @tparam IntType Integer type used for ther perms array - * @tparam IdxType Integer type used for addressing indices - * @tparam TPB threads per block - * @param perms the output permutation indices. Typically useful only when - * one wants to refer back. If you don't need this, pass a nullptr - * @param out the output shuffled array. Pass nullptr if you don't want this to - * be written. For eg: when you only want the perms array to be filled. - * @param in input array (in-place is not supported due to race conditions!) - * @param D number of columns of the input array - * @param N length of the input array (or number of rows) - * @param rowMajor whether the input/output matrices are row or col major - * @param stream cuda stream where to launch the work - * - * @note This is NOT a uniform permutation generator! In fact, it only generates - * very small percentage of permutations. If your application really requires a - * high quality permutation generator, it is recommended that you pick - * Knuth Shuffle. + * @brief Randomly permute the rows of the input matrix. + * + * We do not support in-place permutation, so that we can compute + * in parallel without race conditions. This function is useful + * for shuffling input data sets in machine learning algorithms. + * + * @tparam InputOutputValueType Type of each element of the input matrix, + * and the type of each element of the output matrix (if provided) + * @tparam IntType Integer type of each element of `permsOut` + * @tparam IdxType Integer type of the extents of the mdspan parameters + * @tparam Layout Either `raft::row_major` or `raft::col_major` + * + * @param[in] handle RAFT handle containing the CUDA stream + * on which to run. + * @param[out] permsOut If provided, the indices of the permutation. + * @param[out] out If provided, the output matrix, containing the + * permuted rows of the input matrix `in`. (Not providing this + * is only useful if you provide `permsOut`.) + * @param[in] in input matrix + * + * @pre If `permsOut.has_value()` is `true`, + * then `(*permsOut).extent(0) == in.extent(0)` is `true`. + * + * @pre If `out.has_value()` is `true`, + * then `(*out).extents() == in.extents()` is `true`. + * + * @note This is NOT a uniform permutation generator! + * It only generates a small fraction of all possible random permutations. + * If your application needs a high-quality permutation generator, + * then we recommend Knuth Shuffle. + */ +template +void permute(const raft::handle_t& handle, + raft::device_matrix_view in, + std::optional> permsOut, + std::optional> out) +{ + static_assert(std::is_integral_v, + "permute: The type of each element " + "of permsOut (if provided) must be an integral type."); + static_assert(std::is_integral_v, + "permute: The index type " + "of each mdspan argument must be an integral type."); + constexpr bool is_row_major = std::is_same_v; + constexpr bool is_col_major = std::is_same_v; + static_assert(is_row_major || is_col_major, + "permute: Layout must be either " + "raft::row_major or raft::col_major (or one of their aliases)"); + + const bool permsOut_has_value = permsOut.has_value(); + const bool out_has_value = out.has_value(); + + RAFT_EXPECTS(!permsOut_has_value || (*permsOut).extent(0) == in.extent(0), + "permute: If 'permsOut' is provided, then its extent(0) " + "must equal the number of rows of the input matrix 'in'."); + RAFT_EXPECTS(!out_has_value || (*out).extents() == in.extents(), + "permute: If 'out' is provided, then both its extents " + "must match the extents of the input matrix 'in'."); + + IntType* permsOut_ptr = permsOut_has_value ? (*permsOut).data_handle() : nullptr; + InputOutputValueType* out_ptr = out_has_value ? (*out).data_handle() : nullptr; + + if (permsOut_ptr != nullptr || out_ptr != nullptr) { + const IdxType N = in.extent(0); + const IdxType D = in.extent(1); + detail::permute( + permsOut_ptr, out_ptr, in.data_handle(), D, N, is_row_major, handle.get_stream()); + } +} + +namespace permute_impl { + +template +struct perms_out_view { +}; + +template +struct perms_out_view { + // permsOut won't have a value anyway, + // so we can pick any integral value type we want. + using type = raft::device_vector_view; +}; + +template +struct perms_out_view>, + InputOutputValueType, + IdxType, + Layout> { + using type = raft::device_vector_view; +}; + +template +using perms_out_view_t = typename perms_out_view::type; + +} // namespace permute_impl + +/** + * @brief Overload of `permute` that compiles if users pass in `std::nullopt` + * for either or both of `permsOut` and `out`. + */ +template +void permute(const raft::handle_t& handle, + raft::device_matrix_view in, + PermsOutType&& permsOut, + OutType&& out) +{ + // If PermsOutType is std::optional> + // for some T, then that type T need not be related to any of the + // other template parameters. Thus, we have to deduce it specially. + using perms_out_view_type = permute_impl:: + perms_out_view_t, InputOutputValueType, IdxType, Layout>; + using out_view_type = raft::device_matrix_view; + + static_assert(std::is_same_v, std::nullopt_t> || + std::is_same_v, std::optional>, + "permute: The type of 'out' must be either std::optional<" + "raft::device_matrix_view>, " + "or std::nullopt."); + + std::optional permsOut_arg = std::forward(permsOut); + std::optional out_arg = std::forward(out); + permute(handle, in, permsOut_arg, out_arg); +} + +/** + * @brief Legacy overload of `permute` that takes raw arrays instead of mdspan. + * + * @tparam Type Type of each element of the input matrix to be permuted + * @tparam IntType Integer type of each element of the permsOut matrix + * @tparam IdxType Integer type of the dimensions of the matrices + * @tparam TPB threads per block (do not use any value other than the default) + * + * @param[out] perms If nonnull, the indices of the permutation + * @param[out] out If nonnull, the output matrix, containing the + * permuted rows of the input matrix @c in. (Not providing this + * is only useful if you provide @c perms.) + * @param[in] in input matrix + * @param[in] D number of columns in the matrices + * @param[in] N number of rows in the matrices + * @param[in] rowMajor true if the matrices are row major, + * false if they are column major + * @param[in] stream CUDA stream on which to run */ template void permute(IntType* perms, @@ -60,4 +192,4 @@ void permute(IntType* perms, }; // end namespace raft::random -#endif \ No newline at end of file +#endif diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index a0e9f2f25f..32e5540d51 100644 --- a/cpp/test/random/permute.cu +++ b/cpp/test/random/permute.cu @@ -40,6 +40,9 @@ template template class PermTest : public ::testing::TestWithParam> { + public: + using test_data_type = T; + protected: PermTest() : in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream()) @@ -81,6 +84,89 @@ class PermTest : public ::testing::TestWithParam> { int* outPerms_ptr = nullptr; }; +template +class PermMdspanTest : public ::testing::TestWithParam> { + public: + using test_data_type = T; + + protected: + PermMdspanTest() + : in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream()) + { + } + + private: + using index_type = int; + + template + using matrix_view_t = raft::device_matrix_view; + + template + using vector_view_t = raft::device_vector_view; + + protected: + void SetUp() override + { + auto stream = handle.get_stream(); + params = ::testing::TestWithParam>::GetParam(); + // forcefully set needPerms, since we need it for unit-testing! + if (params.needShuffle) { params.needPerms = true; } + raft::random::RngState r(params.seed); + int N = params.N; + int D = params.D; + int len = N * D; + if (params.needPerms) { + outPerms.resize(N, stream); + outPerms_ptr = outPerms.data(); + } + if (params.needShuffle) { + in.resize(len, stream); + out.resize(len, stream); + in_ptr = in.data(); + out_ptr = out.data(); + uniform(handle, r, in_ptr, len, T(-1.0), T(1.0)); + } + + auto set_up_views_and_test = [&](auto layout) { + using layout_type = std::decay_t; + + matrix_view_t in_view(in_ptr, N, D); + std::optional> out_view; + if (out_ptr != nullptr) { out_view.emplace(out_ptr, N, D); } + std::optional> outPerms_view; + if (outPerms_ptr != nullptr) { outPerms_view.emplace(outPerms_ptr, N); } + + permute(handle, in_view, outPerms_view, out_view); + + // None of these three permute calls should have an effect. + // The point is to test whether the function can deduce the + // element type of outPerms if given nullopt. + std::optional> out_view_empty; + std::optional> outPerms_view_empty; + permute(handle, in_view, std::nullopt, out_view_empty); + permute(handle, in_view, outPerms_view_empty, std::nullopt); + permute(handle, in_view, std::nullopt, std::nullopt); + }; + + if (params.rowMajor) { + set_up_views_and_test(raft::row_major{}); + } else { + set_up_views_and_test(raft::col_major{}); + } + + handle.sync_stream(); + } + + protected: + raft::handle_t handle; + PermInputs params; + rmm::device_uvector in, out; + T* in_ptr = nullptr; + T* out_ptr = nullptr; + rmm::device_uvector outPerms; + int* outPerms_ptr = nullptr; +}; + template ::testing::AssertionResult devArrMatchRange( const T* actual, size_t size, T start, L eq_compare, bool doSort = true, cudaStream_t stream = 0) @@ -169,19 +255,38 @@ const std::vector> inputsf = { {100000, 32, true, true, false, 1234567890ULL}, {100001, 33, true, true, false, 1234567890ULL}}; -typedef PermTest PermTestF; +#define _PERMTEST_BODY(DATA_TYPE) \ + do { \ + if (params.needPerms) { \ + ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare())); \ + } \ + if (params.needShuffle) { \ + ASSERT_TRUE(devArrMatchShuffle(outPerms_ptr, \ + out_ptr, \ + in_ptr, \ + params.D, \ + params.N, \ + params.rowMajor, \ + raft::Compare())); \ + } \ + } while (false) + +using PermTestF = PermTest; TEST_P(PermTestF, Result) { - if (params.needPerms) { - ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare())); - } - if (params.needShuffle) { - ASSERT_TRUE(devArrMatchShuffle( - outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare())); - } + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); } INSTANTIATE_TEST_CASE_P(PermTests, PermTestF, ::testing::ValuesIn(inputsf)); +using PermMdspanTestF = PermMdspanTest; +TEST_P(PermMdspanTestF, Result) +{ + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); +} +INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = { // only generate permutations {32, 8, true, false, true, 1234ULL}, @@ -219,18 +324,22 @@ const std::vector> inputsd = { {100000, 32, true, true, false, 1234ULL}, {100000, 32, true, true, false, 1234567890ULL}, {100001, 33, true, true, false, 1234567890ULL}}; -typedef PermTest PermTestD; + +using PermTestD = PermTest; TEST_P(PermTestD, Result) { - if (params.needPerms) { - ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare())); - } - if (params.needShuffle) { - ASSERT_TRUE(devArrMatchShuffle( - outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare())); - } + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); } INSTANTIATE_TEST_CASE_P(PermTests, PermTestD, ::testing::ValuesIn(inputsd)); +using PermMdspanTestD = PermMdspanTest; +TEST_P(PermMdspanTestD, Result) +{ + using test_data_type = PermTestF::test_data_type; + _PERMTEST_BODY(test_data_type); +} +INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestD, ::testing::ValuesIn(inputsd)); + } // end namespace random } // end namespace raft