Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add public enum for select-k algorithm selection #2046

Merged
merged 10 commits into from
Jan 10, 2024
42 changes: 25 additions & 17 deletions cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct replace_with_mask {
}
};

template <typename KeyT, typename IdxT, select::Algo Algo>
template <typename KeyT, typename IdxT, SelectAlgo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
: fixture(p.use_memory_pool),
Expand Down Expand Up @@ -110,16 +110,24 @@ struct selection : public fixture {
int iter = 0;
loop_on_state(state, [&iter, this]() {
common::nvtx::range lap_scope("lap-", iter++);
select::select_k_impl<KeyT, IdxT>(handle,
Algo,
in_dists_.data(),
params_.use_index_input ? in_ids_.data() : NULL,
params_.batch_size,
params_.len,
params_.k,
out_dists_.data(),
out_ids_.data(),
params_.select_min);

std::optional<raft::device_matrix_view<const IdxT, int64_t, row_major>> in_ids_view;
if (params_.use_index_input) {
in_ids_view = raft::make_device_matrix_view<const IdxT, int64_t>(
in_ids_.data(), params_.batch_size, params_.len);
}

matrix::select_k<KeyT, IdxT>(handle,
raft::make_device_matrix_view<const KeyT, int64_t>(
in_dists_.data(), params_.batch_size, params_.len),
in_ids_view,
raft::make_device_matrix_view<KeyT, int64_t>(
out_dists_.data(), params_.batch_size, params_.k),
raft::make_device_matrix_view<IdxT, int64_t>(
out_ids_.data(), params_.batch_size, params_.k),
params_.select_min,
false,
Algo);
});
} catch (raft::exception& e) {
state.SkipWithError(e.what());
Expand Down Expand Up @@ -213,13 +221,13 @@ const std::vector<select::params> kInputs{
{1000, 10000, 256, true, false, false, true, 0.999},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
namespace BENCHMARK_PRIVATE_NAME(selection) { \
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
#define SELECTION_REGISTER(KeyT, IdxT, A) \
namespace BENCHMARK_PRIVATE_NAME(selection) { \
using SelectK = selection<KeyT, IdxT, raft::matrix::SelectAlgo::A>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT
Expand Down Expand Up @@ -252,7 +260,7 @@ SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT
// register other benchmarks
#define SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, A, input) \
{ \
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
using SelectK = selection<KeyT, IdxT, SelectAlgo::A>; \
std::stringstream name; \
name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \
<< input.len << "/" << input.k << "/" << input.use_index_input << "/" \
Expand Down
7 changes: 5 additions & 2 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cstdint> // uint32_t
#include <cuda_fp16.h> // __half
#include <raft/core/device_resources.hpp>
#include <raft/matrix/select_k_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::mr::device_memory_resource
Expand All @@ -38,7 +39,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx,
bool select_min,
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -54,7 +56,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx, \
bool select_min, \
rmm::mr::device_memory_resource* mr, \
bool sorted)
bool sorted, \
raft::matrix::SelectAlgo algo)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
88 changes: 57 additions & 31 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/*

* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -23,6 +24,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/matrix/init.cuh>
#include <raft/matrix/select_k_types.hpp>

#include <raft/core/resource/thrust_policy.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand All @@ -31,10 +33,6 @@

namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
*
Expand All @@ -47,31 +45,31 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bit
* 'generate_heuristic' notebook there will replace the body of this function
* with the latest learned heuristic
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 256) {
if (cols > 16862) {
if (rows > 1020) {
return Algo::kRadix11bitsExtraPass;
return SelectAlgo::kRadix11bitsExtraPass;
} else {
return Algo::kRadix11bits;
return SelectAlgo::kRadix11bits;
}
} else {
return Algo::kRadix11bitsExtraPass;
return SelectAlgo::kRadix11bitsExtraPass;
}
} else {
if (k > 2) {
if (cols > 22061) {
return Algo::kWarpDistributedShm;
return SelectAlgo::kWarpDistributedShm;
} else {
if (rows > 198) {
return Algo::kWarpDistributedShm;
return SelectAlgo::kWarpDistributedShm;
} else {
return Algo::kWarpImmediate;
return SelectAlgo::kWarpImmediate;
}
}
} else {
return Algo::kWarpImmediate;
return SelectAlgo::kWarpImmediate;
}
}
}
Expand Down Expand Up @@ -239,31 +237,48 @@ void select_k(raft::resources const& handle,
IdxT* out_idx,
bool select_min,
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
auto stream = raft::resource::get_cuda_stream(handle);
auto algo = choose_select_k_algorithm(batch_size, len, k);

if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); }

auto stream = raft::resource::get_cuda_stream(handle);
switch (algo) {
case Algo::kRadix11bits:
case Algo::kRadix11bitsExtraPass: {
bool fused_last_filter = algo == Algo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
fused_last_filter,
stream,
mr);
case SelectAlgo::kRadix8bits:
case SelectAlgo::kRadix11bits:
case SelectAlgo::kRadix11bitsExtraPass: {
if (algo == SelectAlgo::kRadix8bits) {
detail::select::radix::select_k<T, IdxT, 8, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream,
mr);

} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
fused_last_filter,
stream,
mr);
}
if (sorted) {
auto offsets = raft::make_device_vector<IdxT, IdxT>(handle, (IdxT)(batch_size + 1));

Expand All @@ -283,14 +298,25 @@ void select_k(raft::resources const& handle,
}
return;
}
case Algo::kWarpDistributedShm:
case SelectAlgo::kWarpDistributed:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case SelectAlgo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case Algo::kWarpImmediate:
case SelectAlgo::kWarpAuto:
return detail::select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case SelectAlgo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case SelectAlgo::kWarpFiltered:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_filtered>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
Expand Down
9 changes: 7 additions & 2 deletions cpp/include/raft/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/select_k_types.hpp>

#include <optional>

Expand Down Expand Up @@ -76,6 +77,8 @@ namespace raft::matrix {
* whether to select k smallest (true) or largest (false) keys.
* @param[in] sorted
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we expand the algorithms here a bit to give a brief summary of what they mean? I expect this API to be used mostly by power users, but I can't imagine even power users (outside of RAFT immediate developers) would be able to use these options without a brief description and summary of why each might be used. It's also a great opportunity to add the "see also" for the air-top-k paper.

*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -84,7 +87,8 @@ void select_k(raft::resources const& handle,
raft::device_matrix_view<T, int64_t, row_major> out_val,
raft::device_matrix_view<IdxT, int64_t, row_major> out_idx,
bool select_min,
bool sorted = false)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits<int>::max()),
"output k must fit the int type.");
Expand All @@ -109,7 +113,8 @@ void select_k(raft::resources const& handle,
out_idx.data_handle(),
select_min,
nullptr,
sorted);
sorted,
algo);
}

/** @} */ // end of group select_k
Expand Down
60 changes: 60 additions & 0 deletions cpp/include/raft/matrix/select_k_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <type_traits>

namespace raft::matrix {

/**
* @defgroup select_k Batched-select k smallest or largest key/values
* @{
*/

enum class SelectAlgo : uint8_t {
kAuto = 0,
kRadix8bits = 1,
kRadix11bits = 2,
kRadix11bitsExtraPass = 3,
kWarpAuto = 4,
kWarpImmediate = 5,
kWarpFiltered = 6,
kWarpDistributed = 7,
kWarpDistributedShm = 8,
};

inline auto operator<<(std::ostream& os, const SelectAlgo& algo) -> std::ostream&
{
auto underlying_value = static_cast<std::underlying_type<SelectAlgo>::type>(algo);

switch (algo) {
case SelectAlgo::kAuto: return os << "kAuto=" << underlying_value;
case SelectAlgo::kRadix8bits: return os << "kRadix8bits=" << underlying_value;
case SelectAlgo::kRadix11bits: return os << "kRadix11bits=" << underlying_value;
case SelectAlgo::kRadix11bitsExtraPass:
return os << "kRadix11bitsExtraPass=" << underlying_value;
case SelectAlgo::kWarpAuto: return os << "kWarpAuto=" << underlying_value;
case SelectAlgo::kWarpImmediate: return os << "kWarpImmediate=" << underlying_value;
case SelectAlgo::kWarpFiltered: return os << "kWarpFiltered=" << underlying_value;
case SelectAlgo::kWarpDistributed: return os << "kWarpDistributed=" << underlying_value;
case SelectAlgo::kWarpDistributedShm: return os << "kWarpDistributedShm=" << underlying_value;
default: throw std::invalid_argument("invalid value for SelectAlgo");
}
}

/** @} */ // end of group select_k

} // namespace raft::matrix
Loading