Skip to content

Commit

Permalink
Learn heuristic to pick fastest select_k algorithm
Browse files Browse the repository at this point in the history
This uses the select_k dataset from rapidsai#1497 to
learn a heuristic of the fastest select_k variant based off the rows/ cols/k
of the input. This heuristic is modelled as a DecisionTree, which is automatically
exported in C++ code that is compiled into RAFT. This lets us learn a function to
pick the fastest select_k method - which requires only a few if statements in C++
code, making it very cheap to evaluate.
  • Loading branch information
benfred committed May 17, 2023
1 parent d99d249 commit fa2877c
Show file tree
Hide file tree
Showing 5 changed files with 1,606 additions and 9 deletions.
120 changes: 111 additions & 9 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,101 @@

#include <raft/core/nvtx.hpp>

#include <raft/neighbors/detail/selection_faiss.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
*
* The body of this method is automatically generated, using a DecisionTree
* to predict the fastest algorithm based off of thousands of trial runs
* on different values of rows/cols/k. The decision tree is converted to c++
* code, which is cut and paste below.
*
* The code to generate is in cpp/scripts/heuristics/select_k, running the
* 'generate_heuristic' notebook there will replace the body of this function
* with the latest learned heuristic
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 134) {
if (k > 256) {
if (k > 809) {
return Algo::kRadix11bits;
} else {
if (rows > 124) {
if (cols > 63488) {
return Algo::kFaissBlockSelect;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 678736) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 13776) {
if (rows > 335) {
if (k > 1) {
if (rows > 546) {
return Algo::kWarpDistributedShm;
} else {
if (k > 17) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kFaissBlockSelect;
}
}
} else {
return Algo::kFaissBlockSelect;
}
} else {
if (k > 44) {
if (cols > 1031051) {
return Algo::kWarpDistributedShm;
} else {
if (rows > 22) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
if (k > 1) {
if (rows > 188) {
return Algo::kWarpDistributedShm;
} else {
if (k > 72) {
return Algo::kRadix11bits;
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
return Algo::kFaissBlockSelect;
}
}
}
}

/**
* Select k smallest or largest key/values from each row in the input data.
*
Expand Down Expand Up @@ -77,15 +167,27 @@ void select_k(const T* in_val,
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
// TODO (achirkin): investigate the trade-off for a wider variety of inputs.
const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
} else {
select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);

auto algo = choose_select_k_algorithm(batch_size, len, k);
switch (algo) {
case Algo::kRadix11bits:
return detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream);
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
}
}

} // namespace raft::matrix::detail
Loading

0 comments on commit fa2877c

Please sign in to comment.