Skip to content

Commit

Permalink
Improve the performance of radix top-k (#1175)
Browse files Browse the repository at this point in the history
The main changes are:

- Add a one-block version. It uses a single thread block for one row of a batch and is used when `len` is relatively small (<= 16384).
- Avoid writing candidates to buffers when the number of candidates is larger than buffer length.
- Add a parameter to control whether to use a fused filter in the last pass or use a standalone filter kernel. The latter case is preferable when the leading bits of the inputs are almost the same.
- Early stopping: when the target bucket contains `k` values, we can stop the computation earlier
- Many implementation details are polished, like the initialization of `counter`, calculation of kernel launch parameters, and the scan step
- Tests and benchmarks are updated to include the new implementations. New benchmarks are added to demonstrate the advantage of the adaptive version.

Authors:
  - Yong Wang (https://github.com/yong-wang)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1175
  • Loading branch information
yong-wang authored Mar 24, 2023
1 parent 1b18d1f commit 8f1fa07
Show file tree
Hide file tree
Showing 6 changed files with 975 additions and 363 deletions.
128 changes: 93 additions & 35 deletions cpp/bench/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cstdint>
#include <cstring>
#include <type_traits>

namespace raft::matrix {

using namespace raft::bench; // NOLINT
Expand All @@ -50,7 +54,23 @@ struct selection : public fixture {
{
raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream);
raft::random::RngState state{42};
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0));

KeyT min_value = -1.0;
KeyT max_value = 1.0;
if (p.use_same_leading_bits) {
if constexpr (std::is_same_v<KeyT, float>) {
uint32_t min_bits = 0x3F800000; // 1.0
uint32_t max_bits = 0x3F8000FF; // 1.00003
memcpy(&min_value, &min_bits, sizeof(KeyT));
memcpy(&max_value, &max_bits, sizeof(KeyT));
} else if constexpr (std::is_same_v<KeyT, double>) {
uint64_t min_bits = 0x3FF0000000000000; // 1.0
uint64_t max_bits = 0x3FF0000FFFFFFFFF; // 1.000015
memcpy(&min_value, &min_bits, sizeof(KeyT));
memcpy(&max_value, &max_bits, sizeof(KeyT));
}
}
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value);
}

void run_benchmark(::benchmark::State& state) override // NOLINT
Expand All @@ -60,6 +80,7 @@ struct selection : public fixture {
try {
std::ostringstream label_stream;
label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; }
state.SetLabel(label_stream.str());
loop_on_state(state, [this, &handle]() {
select::select_k_impl<KeyT, IdxT>(handle,
Expand All @@ -85,21 +106,55 @@ struct selection : public fixture {
};

// Benchmark input configurations; this list is passed to RAFT_BENCH_REGISTER by the
// SELECTION_REGISTER macro, so every registered benchmark runs over all of these entries.
// NOTE(review): the first three fields look like {batch_size, len, k} (the benchmark label
// is built as "batch_size#len#k"); the fourth is presumably `select_min` — confirm the field
// order against the declaration of select::params.
const std::vector<select::params> kInputs{
{20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true},
{20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true},
{20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true},

{1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true},
{1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true},
{1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true},

{100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true},
{100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true},
{100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true},

{10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true},
{10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true},
{10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true},
{20000, 500, 1, true},
{20000, 500, 2, true},
{20000, 500, 4, true},
{20000, 500, 8, true},
{20000, 500, 16, true},
{20000, 500, 32, true},
{20000, 500, 64, true},
{20000, 500, 128, true},
{20000, 500, 256, true},

{1000, 10000, 1, true},
{1000, 10000, 2, true},
{1000, 10000, 4, true},
{1000, 10000, 8, true},
{1000, 10000, 16, true},
{1000, 10000, 32, true},
{1000, 10000, 64, true},
{1000, 10000, 128, true},
{1000, 10000, 256, true},

{100, 100000, 1, true},
{100, 100000, 2, true},
{100, 100000, 4, true},
{100, 100000, 8, true},
{100, 100000, 16, true},
{100, 100000, 32, true},
{100, 100000, 64, true},
{100, 100000, 128, true},
{100, 100000, 256, true},

{10, 1000000, 1, true},
{10, 1000000, 2, true},
{10, 1000000, 4, true},
{10, 1000000, 8, true},
{10, 1000000, 16, true},
{10, 1000000, 32, true},
{10, 1000000, 64, true},
{10, 1000000, 128, true},
{10, 1000000, 256, true},

// Six-field entries. NOTE(review): the trailing `true` presumably sets
// `use_same_leading_bits`, and the fifth field is not identifiable from this
// file — confirm both against select::params.
{10, 1000000, 1, true, false, true},
{10, 1000000, 2, true, false, true},
{10, 1000000, 4, true, false, true},
{10, 1000000, 8, true, false, true},
{10, 1000000, 16, true, false, true},
{10, 1000000, 32, true, false, true},
{10, 1000000, 64, true, false, true},
{10, 1000000, 128, true, false, true},
{10, 1000000, 256, true, false, true},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
Expand All @@ -109,24 +164,27 @@ const std::vector<select::params> kInputs{
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

// Benchmark registrations: one SELECTION_REGISTER(KeyT, IdxT, Algo) line per combination of
// key type, index type and selection algorithm. The macro registers each combination under
// the name "KeyT/IdxT/Algo" and runs it over every entry of kInputs.
SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT
SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT

} // namespace raft::matrix
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/detail/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void select_k(const T* in_val,
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
} else {
select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);
}
}

Expand Down
Loading

0 comments on commit 8f1fa07

Please sign in to comment.