Skip to content

Commit

Permalink
Add public enum for select-k algorithm selection (#2046)
Browse files Browse the repository at this point in the history
Add an enum that controls which select-k algorithm is used. This takes the enum that was in raft_internal and exposes it in the public API. This lets users pick which select-k algorithm they want to use directly.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2046
  • Loading branch information
benfred authored Jan 10, 2024
1 parent 1484a03 commit 26d310b
Show file tree
Hide file tree
Showing 17 changed files with 306 additions and 269 deletions.
44 changes: 26 additions & 18 deletions cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -52,7 +52,7 @@ struct replace_with_mask {
}
};

template <typename KeyT, typename IdxT, select::Algo Algo>
template <typename KeyT, typename IdxT, SelectAlgo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
: fixture(p.use_memory_pool),
Expand Down Expand Up @@ -110,16 +110,24 @@ struct selection : public fixture {
int iter = 0;
loop_on_state(state, [&iter, this]() {
common::nvtx::range lap_scope("lap-", iter++);
select::select_k_impl<KeyT, IdxT>(handle,
Algo,
in_dists_.data(),
params_.use_index_input ? in_ids_.data() : NULL,
params_.batch_size,
params_.len,
params_.k,
out_dists_.data(),
out_ids_.data(),
params_.select_min);

std::optional<raft::device_matrix_view<const IdxT, int64_t, row_major>> in_ids_view;
if (params_.use_index_input) {
in_ids_view = raft::make_device_matrix_view<const IdxT, int64_t>(
in_ids_.data(), params_.batch_size, params_.len);
}

matrix::select_k<KeyT, IdxT>(handle,
raft::make_device_matrix_view<const KeyT, int64_t>(
in_dists_.data(), params_.batch_size, params_.len),
in_ids_view,
raft::make_device_matrix_view<KeyT, int64_t>(
out_dists_.data(), params_.batch_size, params_.k),
raft::make_device_matrix_view<IdxT, int64_t>(
out_ids_.data(), params_.batch_size, params_.k),
params_.select_min,
false,
Algo);
});
} catch (raft::exception& e) {
state.SkipWithError(e.what());
Expand Down Expand Up @@ -213,13 +221,13 @@ const std::vector<select::params> kInputs{
{1000, 10000, 256, true, false, false, true, 0.999},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
namespace BENCHMARK_PRIVATE_NAME(selection) { \
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
#define SELECTION_REGISTER(KeyT, IdxT, A) \
namespace BENCHMARK_PRIVATE_NAME(selection) { \
using SelectK = selection<KeyT, IdxT, raft::matrix::SelectAlgo::A>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT
Expand Down Expand Up @@ -252,7 +260,7 @@ SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT
// register other benchmarks
#define SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, A, input) \
{ \
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
using SelectK = selection<KeyT, IdxT, SelectAlgo::A>; \
std::stringstream name; \
name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \
<< input.len << "/" << input.k << "/" << input.use_index_input << "/" \
Expand Down
9 changes: 6 additions & 3 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
#include <cstdint> // uint32_t
#include <cuda_fp16.h> // __half
#include <raft/core/device_resources.hpp>
#include <raft/matrix/select_k_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::mr::device_memory_resource
Expand All @@ -38,7 +39,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx,
bool select_min,
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -54,7 +56,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx, \
bool select_min, \
rmm::mr::device_memory_resource* mr, \
bool sorted)
bool sorted, \
raft::matrix::SelectAlgo algo)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
90 changes: 58 additions & 32 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +24,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/matrix/init.cuh>
#include <raft/matrix/select_k_types.hpp>

#include <raft/core/resource/thrust_policy.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand All @@ -31,10 +33,6 @@

namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
*
Expand All @@ -47,31 +45,31 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bit
* 'generate_heuristic' notebook there will replace the body of this function
* with the latest learned heuristic
*/
// NOTE(review): this is a rendered diff without +/- markers. Each line that
// mentions `Algo` appears to be the removed version and the `SelectAlgo` line
// directly after it the replacement -- confirm against the real file.
//
// Decision tree over (rows, cols, k). Every leaf returns one of four concrete
// algorithms (kRadix11bits, kRadix11bitsExtraPass, kWarpDistributedShm,
// kWarpImmediate).
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
// k > 256: only the radix variants are returned on this branch.
if (k > 256) {
if (cols > 16862) {
if (rows > 1020) {
return Algo::kRadix11bitsExtraPass;
return SelectAlgo::kRadix11bitsExtraPass;
} else {
return Algo::kRadix11bits;
return SelectAlgo::kRadix11bits;
}
} else {
return Algo::kRadix11bitsExtraPass;
return SelectAlgo::kRadix11bitsExtraPass;
}
} else {
// k <= 256: only the warp-sort variants are returned on this branch.
if (k > 2) {
if (cols > 22061) {
return Algo::kWarpDistributedShm;
return SelectAlgo::kWarpDistributedShm;
} else {
if (rows > 198) {
return Algo::kWarpDistributedShm;
return SelectAlgo::kWarpDistributedShm;
} else {
return Algo::kWarpImmediate;
return SelectAlgo::kWarpImmediate;
}
}
} else {
// k <= 2 always selects kWarpImmediate, regardless of rows/cols.
return Algo::kWarpImmediate;
return SelectAlgo::kWarpImmediate;
}
}
}
Expand Down Expand Up @@ -239,31 +237,48 @@ void select_k(raft::resources const& handle,
IdxT* out_idx,
bool select_min,
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
auto stream = raft::resource::get_cuda_stream(handle);
auto algo = choose_select_k_algorithm(batch_size, len, k);

if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); }

auto stream = raft::resource::get_cuda_stream(handle);
switch (algo) {
case Algo::kRadix11bits:
case Algo::kRadix11bitsExtraPass: {
bool fused_last_filter = algo == Algo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
fused_last_filter,
stream,
mr);
case SelectAlgo::kRadix8bits:
case SelectAlgo::kRadix11bits:
case SelectAlgo::kRadix11bitsExtraPass: {
if (algo == SelectAlgo::kRadix8bits) {
detail::select::radix::select_k<T, IdxT, 8, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream,
mr);

} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
fused_last_filter,
stream,
mr);
}
if (sorted) {
auto offsets = raft::make_device_vector<IdxT, IdxT>(handle, (IdxT)(batch_size + 1));

Expand All @@ -283,14 +298,25 @@ void select_k(raft::resources const& handle,
}
return;
}
case Algo::kWarpDistributedShm:
case SelectAlgo::kWarpDistributed:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case SelectAlgo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case Algo::kWarpImmediate:
case SelectAlgo::kWarpAuto:
return detail::select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case SelectAlgo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case SelectAlgo::kWarpFiltered:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_filtered>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
Expand Down
11 changes: 8 additions & 3 deletions cpp/include/raft/matrix/select_k.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/select_k_types.hpp>

#include <optional>

Expand Down Expand Up @@ -76,6 +77,8 @@ namespace raft::matrix {
* whether to select k smallest (true) or largest (false) keys.
* @param[in] sorted
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -84,7 +87,8 @@ void select_k(raft::resources const& handle,
raft::device_matrix_view<T, int64_t, row_major> out_val,
raft::device_matrix_view<IdxT, int64_t, row_major> out_idx,
bool select_min,
bool sorted = false)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits<int>::max()),
"output k must fit the int type.");
Expand All @@ -109,7 +113,8 @@ void select_k(raft::resources const& handle,
out_idx.data_handle(),
select_min,
nullptr,
sorted);
sorted,
algo);
}

/** @} */ // end of group select_k
Expand Down
Loading

0 comments on commit 26d310b

Please sign in to comment.