Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Learn heuristic to pick fastest select_k algorithm #1523

Merged
merged 4 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ using namespace raft::bench; // NOLINT
template <typename KeyT, typename IdxT, select::Algo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
: fixture(true),
: fixture(p.use_memory_pool),
params_(p),
in_dists_(p.batch_size * p.len, stream),
in_ids_(p.batch_size * p.len, stream),
Expand Down Expand Up @@ -193,7 +193,8 @@ SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
std::stringstream name; \
name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \
<< input.len << "/" << input.k << "/" << input.use_index_input; \
<< input.len << "/" << input.k << "/" << input.use_index_input << "/" \
<< input.use_memory_pool; \
auto* b = ::benchmark::internal::RegisterBenchmarkInternal( \
new raft::bench::internal::Fixture<SelectK, select::params>(name.str(), input)); \
b->UseManualTime(); \
Expand Down Expand Up @@ -266,5 +267,14 @@ void add_select_k_dataset_benchmarks()
SELECTION_REGISTER_INPUT(float, int64_t, input);
SELECTION_REGISTER_INPUT(float, uint32_t, input);
}

// also try again without a memory pool to see if there are significant differences
for (auto input : inputs) {
input.use_memory_pool = false;
SELECTION_REGISTER_INPUT(double, int64_t, input);
SELECTION_REGISTER_INPUT(double, uint32_t, input);
SELECTION_REGISTER_INPUT(float, int64_t, input);
SELECTION_REGISTER_INPUT(float, uint32_t, input);
}
}
} // namespace raft::matrix
120 changes: 111 additions & 9 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,101 @@

#include <raft/core/nvtx.hpp>

#include <raft/neighbors/detail/selection_faiss.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
*
* The body of this method is automatically generated, using a DecisionTree
* to predict the fastest algorithm based off of thousands of trial runs
* on different values of rows/cols/k. The decision tree is converted to c++
* code, which is cut and paste below.
*
* The code to generate is in cpp/scripts/heuristics/select_k, running the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a tiny nitpick:

Suggested change
* The code to generate is in cpp/scripts/heuristics/select_k, running the
* NOTE: The code to generate is in cpp/scripts/heuristics/select_k, running the

* 'generate_heuristic' notebook there will replace the body of this function
* with the latest learned heuristic
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 134) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, we'd better use log_2(k) instead of k when constructing the heuristic, so that all values of k go in powers of two. For all warp-based algorithms, performance for non-powers of two is equal to their rounded-up powers of two (queue capacity parameter).

if (k > 256) {
if (k > 809) {
return Algo::kRadix11bits;
} else {
if (rows > 124) {
if (cols > 63488) {
return Algo::kFaissBlockSelect;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 678736) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 13776) {
if (rows > 335) {
if (k > 1) {
if (rows > 546) {
return Algo::kWarpDistributedShm;
} else {
if (k > 17) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kFaissBlockSelect;
}
}
} else {
return Algo::kFaissBlockSelect;
}
} else {
if (k > 44) {
if (cols > 1031051) {
return Algo::kWarpDistributedShm;
} else {
if (rows > 22) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
if (k > 1) {
if (rows > 188) {
return Algo::kWarpDistributedShm;
} else {
if (k > 72) {
return Algo::kRadix11bits;
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
return Algo::kFaissBlockSelect;
}
}
}
}

/**
* Select k smallest or largest key/values from each row in the input data.
*
Expand Down Expand Up @@ -77,15 +167,27 @@ void select_k(const T* in_val,
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
// TODO (achirkin): investigate the trade-off for a wider variety of inputs.
const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
} else {
select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);

auto algo = choose_select_k_algorithm(batch_size, len, k);
switch (algo) {
case Algo::kRadix11bits:
return detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream);
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
}
}

} // namespace raft::matrix::detail
1 change: 1 addition & 0 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct params {
bool select_min;
bool use_index_input = true;
bool use_same_leading_bits = false;
bool use_memory_pool = true;
};

inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream&
Expand Down
Loading