Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Mdspanify permute #834

Merged
merged 3 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 155 additions & 22 deletions cpp/include/raft/random/permute.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,163 @@

#include "detail/permute.cuh"

#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <type_traits>

namespace raft::random {

/**
* @brief Generate permutations of the input array. Pretty useful primitive for
* shuffling the input datasets in ML algos. See note at the end for some of its
* limitations!
* @tparam Type Data type of the array to be shuffled
* @tparam IntType Integer type used for ther perms array
* @tparam IdxType Integer type used for addressing indices
* @tparam TPB threads per block
* @param perms the output permutation indices. Typically useful only when
* one wants to refer back. If you don't need this, pass a nullptr
* @param out the output shuffled array. Pass nullptr if you don't want this to
* be written. For eg: when you only want the perms array to be filled.
* @param in input array (in-place is not supported due to race conditions!)
* @param D number of columns of the input array
* @param N length of the input array (or number of rows)
* @param rowMajor whether the input/output matrices are row or col major
* @param stream cuda stream where to launch the work
*
* @note This is NOT a uniform permutation generator! In fact, it only generates
* very small percentage of permutations. If your application really requires a
* high quality permutation generator, it is recommended that you pick
* Knuth Shuffle.
* @brief Randomly permute the rows of the input matrix.
*
* We do not support in-place permutation, so that we can compute
* in parallel without race conditions. This function is useful
* for shuffling input data sets in machine learning algorithms.
*
* @tparam InputOutputValueType Type of each element of the input matrix,
* and the type of each element of the output matrix (if provided)
* @tparam IntType Integer type of each element of `permsOut`
* @tparam IdxType Integer type of the extents of the mdspan parameters
* @tparam Layout Either `raft::row_major` or `raft::col_major`
*
* @param[in] handle RAFT handle containing the CUDA stream
* on which to run.
* @param[in] in input matrix
* @param[out] permsOut If provided, the indices of the permutation.
* @param[out] out If provided, the output matrix, containing the
* permuted rows of the input matrix `in`. (Not providing this
* is only useful if you provide `permsOut`.)
*
* @pre If `permsOut.has_value()` is `true`,
* then `(*permsOut).extent(0) == in.extent(0)` is `true`.
*
* @pre If `out.has_value()` is `true`,
* then `(*out).extents() == in.extents()` is `true`.
*
* @note This is NOT a uniform permutation generator!
* It only generates a small fraction of all possible random permutations.
* If your application needs a high-quality permutation generator,
* then we recommend Knuth Shuffle.
*/
template <typename InputOutputValueType, typename IntType, typename IdxType, typename Layout>
void permute(const raft::handle_t& handle,
raft::device_matrix_view<const InputOutputValueType, IdxType, Layout> in,
std::optional<raft::device_vector_view<IntType, IdxType>> permsOut,
std::optional<raft::device_matrix_view<InputOutputValueType, IdxType, Layout>> out)
{
static_assert(std::is_integral_v<IntType>,
"permute: The type of each element "
"of permsOut (if provided) must be an integral type.");
static_assert(std::is_integral_v<IdxType>,
"permute: The index type "
"of each mdspan argument must be an integral type.");
constexpr bool is_row_major = std::is_same_v<Layout, raft::row_major>;
constexpr bool is_col_major = std::is_same_v<Layout, raft::col_major>;
static_assert(is_row_major || is_col_major,
"permute: Layout must be either "
"raft::row_major or raft::col_major (or one of their aliases)");

const bool permsOut_has_value = permsOut.has_value();
const bool out_has_value = out.has_value();

RAFT_EXPECTS(!permsOut_has_value || (*permsOut).extent(0) == in.extent(0),
"permute: If 'permsOut' is provided, then its extent(0) "
"must equal the number of rows of the input matrix 'in'.");
RAFT_EXPECTS(!out_has_value || (*out).extents() == in.extents(),
"permute: If 'out' is provided, then both its extents "
"must match the extents of the input matrix 'in'.");

IntType* permsOut_ptr = permsOut_has_value ? (*permsOut).data_handle() : nullptr;
InputOutputValueType* out_ptr = out_has_value ? (*out).data_handle() : nullptr;

if (permsOut_ptr != nullptr || out_ptr != nullptr) {
const IdxType N = in.extent(0);
const IdxType D = in.extent(1);
detail::permute<InputOutputValueType, IntType, IdxType>(
permsOut_ptr, out_ptr, in.data_handle(), D, N, is_row_major, handle.get_stream());
}
}

namespace permute_impl {

template <typename T, typename InputOutputValueType, typename IdxType, typename Layout>
struct perms_out_view {
};

template <typename InputOutputValueType, typename IdxType, typename Layout>
struct perms_out_view<std::nullopt_t, InputOutputValueType, IdxType, Layout> {
// permsOut won't have a value anyway,
// so we can pick any integral value type we want.
using type = raft::device_vector_view<IdxType, IdxType>;
};

template <typename PermutationIndexType,
typename InputOutputValueType,
typename IdxType,
typename Layout>
struct perms_out_view<std::optional<raft::device_vector_view<PermutationIndexType, IdxType>>,
InputOutputValueType,
IdxType,
Layout> {
using type = raft::device_vector_view<PermutationIndexType, IdxType>;
};

template <typename T, typename InputOutputValueType, typename IdxType, typename Layout>
using perms_out_view_t = typename perms_out_view<T, InputOutputValueType, IdxType, Layout>::type;

} // namespace permute_impl

/**
* @brief Overload of `permute` that compiles if users pass in `std::nullopt`
* for either or both of `permsOut` and `out`.
*/
template <typename InputOutputValueType,
typename IdxType,
typename Layout,
typename PermsOutType,
typename OutType>
void permute(const raft::handle_t& handle,
raft::device_matrix_view<const InputOutputValueType, IdxType, Layout> in,
PermsOutType&& permsOut,
OutType&& out)
{
// If PermsOutType is std::optional<device_vector_view<T, IdxType>>
// for some T, then that type T need not be related to any of the
// other template parameters. Thus, we have to deduce it specially.
using perms_out_view_type = permute_impl::
perms_out_view_t<std::decay_t<PermsOutType>, InputOutputValueType, IdxType, Layout>;
using out_view_type = raft::device_matrix_view<InputOutputValueType, IdxType, Layout>;

static_assert(std::is_same_v<std::decay_t<OutType>, std::nullopt_t> ||
std::is_same_v<std::decay_t<OutType>, std::optional<out_view_type>>,
"permute: The type of 'out' must be either std::optional<"
"raft::device_matrix_view<InputOutputViewType, IdxType, Layout>>, "
"or std::nullopt.");

std::optional<perms_out_view_type> permsOut_arg = std::forward<PermsOutType>(permsOut);
std::optional<out_view_type> out_arg = std::forward<OutType>(out);
permute(handle, in, permsOut_arg, out_arg);
}

/**
* @brief Legacy overload of `permute` that takes raw arrays instead of mdspan.
*
* @tparam Type Type of each element of the input matrix to be permuted
* @tparam IntType Integer type of each element of the permsOut matrix
* @tparam IdxType Integer type of the dimensions of the matrices
* @tparam TPB threads per block (do not use any value other than the default)
*
* @param[out] perms If nonnull, the indices of the permutation
* @param[out] out If nonnull, the output matrix, containing the
* permuted rows of the input matrix @c in. (Not providing this
* is only useful if you provide @c perms.)
* @param[in] in input matrix
* @param[in] D number of columns in the matrices
* @param[in] N number of rows in the matrices
* @param[in] rowMajor true if the matrices are row major,
* false if they are column major
* @param[in] stream CUDA stream on which to run
*/
template <typename Type, typename IntType = int, typename IdxType = int, int TPB = 256>
void permute(IntType* perms,
Expand All @@ -60,4 +193,4 @@ void permute(IntType* perms,

}; // end namespace raft::random

#endif
#endif
2 changes: 2 additions & 0 deletions cpp/include/raft/random/rng_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#pragma once

#include <cstdint>

namespace raft {
namespace random {

Expand Down
4 changes: 2 additions & 2 deletions cpp/test/nvtx.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
*/
#ifdef NVTX_ENABLED
#include <gtest/gtest.h>
#include <raft/common/detail/nvtx.hpp>
#include <raft/core/detail/nvtx.hpp>
/**
* tests for the functionality of generating next color based on string
* entered in the NVTX Range marker wrappers
Expand Down
141 changes: 125 additions & 16 deletions cpp/test/random/permute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ template <typename T>

template <typename T>
class PermTest : public ::testing::TestWithParam<PermInputs<T>> {
public:
using test_data_type = T;

protected:
PermTest()
: in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream())
Expand Down Expand Up @@ -81,6 +84,89 @@ class PermTest : public ::testing::TestWithParam<PermInputs<T>> {
int* outPerms_ptr = nullptr;
};

template <typename T>
class PermMdspanTest : public ::testing::TestWithParam<PermInputs<T>> {
public:
using test_data_type = T;

protected:
PermMdspanTest()
: in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream())
{
}

private:
using index_type = int;

template <class ElementType, class Layout>
using matrix_view_t = raft::device_matrix_view<ElementType, index_type, Layout>;

template <class ElementType>
using vector_view_t = raft::device_vector_view<ElementType, index_type>;

protected:
void SetUp() override
{
auto stream = handle.get_stream();
params = ::testing::TestWithParam<PermInputs<T>>::GetParam();
// forcefully set needPerms, since we need it for unit-testing!
if (params.needShuffle) { params.needPerms = true; }
raft::random::RngState r(params.seed);
int N = params.N;
int D = params.D;
int len = N * D;
if (params.needPerms) {
outPerms.resize(N, stream);
outPerms_ptr = outPerms.data();
}
if (params.needShuffle) {
in.resize(len, stream);
out.resize(len, stream);
in_ptr = in.data();
out_ptr = out.data();
uniform(handle, r, in_ptr, len, T(-1.0), T(1.0));
}

auto set_up_views_and_test = [&](auto layout) {
using layout_type = std::decay_t<decltype(layout)>;

matrix_view_t<const T, layout_type> in_view(in_ptr, N, D);
std::optional<matrix_view_t<T, layout_type>> out_view;
if (out_ptr != nullptr) { out_view.emplace(out_ptr, N, D); }
std::optional<vector_view_t<index_type>> outPerms_view;
if (outPerms_ptr != nullptr) { outPerms_view.emplace(outPerms_ptr, N); }

permute(handle, in_view, outPerms_view, out_view);

// None of these three permute calls should have an effect.
// The point is to test whether the function can deduce the
// element type of outPerms if given nullopt.
std::optional<matrix_view_t<T, layout_type>> out_view_empty;
std::optional<vector_view_t<index_type>> outPerms_view_empty;
permute(handle, in_view, std::nullopt, out_view_empty);
permute(handle, in_view, outPerms_view_empty, std::nullopt);
permute(handle, in_view, std::nullopt, std::nullopt);
};

if (params.rowMajor) {
set_up_views_and_test(raft::row_major{});
} else {
set_up_views_and_test(raft::col_major{});
}

handle.sync_stream();
}

protected:
raft::handle_t handle;
PermInputs<T> params;
rmm::device_uvector<T> in, out;
T* in_ptr = nullptr;
T* out_ptr = nullptr;
rmm::device_uvector<int> outPerms;
int* outPerms_ptr = nullptr;
};

template <typename T, typename L>
::testing::AssertionResult devArrMatchRange(
const T* actual, size_t size, T start, L eq_compare, bool doSort = true, cudaStream_t stream = 0)
Expand Down Expand Up @@ -169,19 +255,38 @@ const std::vector<PermInputs<float>> inputsf = {
{100000, 32, true, true, false, 1234567890ULL},
{100001, 33, true, true, false, 1234567890ULL}};

typedef PermTest<float> PermTestF;
#define _PERMTEST_BODY(DATA_TYPE) \
do { \
if (params.needPerms) { \
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>())); \
} \
if (params.needShuffle) { \
ASSERT_TRUE(devArrMatchShuffle(outPerms_ptr, \
out_ptr, \
in_ptr, \
params.D, \
params.N, \
params.rowMajor, \
raft::Compare<DATA_TYPE>())); \
} \
} while (false)

using PermTestF = PermTest<float>;
TEST_P(PermTestF, Result)
{
if (params.needPerms) {
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>()));
}
if (params.needShuffle) {
ASSERT_TRUE(devArrMatchShuffle(
outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare<float>()));
}
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermTests, PermTestF, ::testing::ValuesIn(inputsf));

using PermMdspanTestF = PermMdspanTest<float>;
TEST_P(PermMdspanTestF, Result)
{
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestF, ::testing::ValuesIn(inputsf));

const std::vector<PermInputs<double>> inputsd = {
// only generate permutations
{32, 8, true, false, true, 1234ULL},
Expand Down Expand Up @@ -219,18 +324,22 @@ const std::vector<PermInputs<double>> inputsd = {
{100000, 32, true, true, false, 1234ULL},
{100000, 32, true, true, false, 1234567890ULL},
{100001, 33, true, true, false, 1234567890ULL}};
typedef PermTest<double> PermTestD;

using PermTestD = PermTest<double>;
TEST_P(PermTestD, Result)
{
if (params.needPerms) {
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>()));
}
if (params.needShuffle) {
ASSERT_TRUE(devArrMatchShuffle(
outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare<double>()));
}
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermTests, PermTestD, ::testing::ValuesIn(inputsd));

using PermMdspanTestD = PermMdspanTest<double>;
TEST_P(PermMdspanTestD, Result)
{
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestD, ::testing::ValuesIn(inputsd));

} // end namespace random
} // end namespace raft