Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the performance of radix top-k #1175

Merged
merged 45 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
765f4d4
Improve performance of radix top-k
yong-wang Dec 8, 2022
bc5df8d
Merge remote-tracking branch 'origin/branch-23.02' into fea-update-ra…
yong-wang Jan 24, 2023
d085e58
radix top-k: conform to RAFT code style
yong-wang Jan 5, 2023
5f63cbd
radix top-k: add extra input parameter in_idx
yong-wang Jan 11, 2023
6086454
radix top-k: replace greater with select_min
yong-wang Jan 12, 2023
dd63770
radix top-k: make it compiled
yong-wang Jan 11, 2023
aea2bd0
radix top-k: polish style
yong-wang Jan 11, 2023
99924bd
radix top-k: polish code
yong-wang Jan 13, 2023
2746173
radix top-k: remove Store classes
yong-wang Jan 18, 2023
381c075
radix top-k: polish code comments
yong-wang Jan 19, 2023
d20a480
radix top-k: change dynamic to adaptive
yong-wang Jan 19, 2023
c408245
modify radix top-k so that it conforms the latest select_k code
yong-wang Jan 23, 2023
53ebcb8
fix the case when k equals len
yong-wang Jan 24, 2023
fdd30e9
radix top-k: update tests and benchmarks
yong-wang Jan 23, 2023
9ff84d5
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 3, 2023
9dc49eb
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 3, 2023
677e2e7
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 9, 2023
87cd66e
add comments and revise code
yong-wang Feb 12, 2023
f83e7ec
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 16, 2023
e4891a8
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 17, 2023
d66c3ee
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 19, 2023
6add0b8
radix one-block: enable vectorized loading when pass==0
yong-wang Feb 12, 2023
204b370
radix: use one-block version when calculated gridDim.x==1
yong-wang Feb 12, 2023
b482e29
radix: add chunking
yong-wang Feb 12, 2023
834dc4b
radix: reduce buf size for adaptive mode
yong-wang Feb 12, 2023
117f94d
make implementation adaptive
yong-wang Feb 19, 2023
6555bbe
polish code comments
yong-wang Feb 19, 2023
48c4faf
fix potential mul overflow
yong-wang Feb 19, 2023
415798a
fix launch conf of last_filter_kernel
yong-wang Feb 20, 2023
77a14d9
update test and benchmark
yong-wang Feb 20, 2023
532cb4f
remove managed_memory_resource and refine code comments
yong-wang Feb 21, 2023
a4eee75
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Feb 21, 2023
63c5fd0
replace select_radix.cuh with select_radix_updated.cuh
yong-wang Feb 22, 2023
5e7addf
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Mar 1, 2023
f72b3e8
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Mar 2, 2023
f43a523
Merge remote-tracking branch 'origin/branch-23.04' into fea-update-ra…
yong-wang Mar 10, 2023
f7061eb
Add missing fused_last_filter arg while dispatching select_k
tfeher Mar 10, 2023
7543bbf
Merge branch 'branch-23.04' into fea-update-radix-topk
tfeher Mar 11, 2023
01ac6dd
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Mar 15, 2023
45b809f
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Mar 17, 2023
9d7a687
Merge branch 'branch-23.04' into fea-update-radix-topk
cjnolet Mar 23, 2023
d447fcb
Checking in fix for select_k based on offline conversation w/ Yong Wang.
cjnolet Mar 23, 2023
a11fb8e
minor polish
yong-wang Mar 24, 2023
dd6ae51
adjust the place of volatile
yong-wang Mar 24, 2023
f1e281b
Merge branch 'branch-23.04' into fea-update-radix-topk
yong-wang Mar 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 93 additions & 35 deletions cpp/bench/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cstdint>
#include <cstring>
#include <type_traits>

namespace raft::matrix {

using namespace raft::bench; // NOLINT
Expand All @@ -46,7 +50,23 @@ struct selection : public fixture {
{
raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream);
raft::random::RngState state{42};
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0));

KeyT min_value = -1.0;
KeyT max_value = 1.0;
if (p.use_same_leading_bits) {
if constexpr (std::is_same_v<KeyT, float>) {
uint32_t min_bits = 0x3F800000; // 1.0
uint32_t max_bits = 0x3F8000FF; // 1.00003
memcpy(&min_value, &min_bits, sizeof(KeyT));
memcpy(&max_value, &max_bits, sizeof(KeyT));
} else if constexpr (std::is_same_v<KeyT, double>) {
uint64_t min_bits = 0x3FF0000000000000; // 1.0
uint64_t max_bits = 0x3FF0000FFFFFFFFF; // 1.000015
memcpy(&min_value, &min_bits, sizeof(KeyT));
memcpy(&max_value, &max_bits, sizeof(KeyT));
}
}
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value);
}

void run_benchmark(::benchmark::State& state) override // NOLINT
Expand All @@ -56,6 +76,7 @@ struct selection : public fixture {
try {
std::ostringstream label_stream;
label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; }
state.SetLabel(label_stream.str());
loop_on_state(state, [this, &handle]() {
select::select_k_impl<KeyT, IdxT>(handle,
Expand All @@ -81,21 +102,55 @@ struct selection : public fixture {
};

const std::vector<select::params> kInputs{
{20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true},
{20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true},
{20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true},

{1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true},
{1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true},
{1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true},

{100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true},
{100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true},
{100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true},

{10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true},
{10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true},
{10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true},
{20000, 500, 1, true},
{20000, 500, 2, true},
{20000, 500, 4, true},
{20000, 500, 8, true},
{20000, 500, 16, true},
{20000, 500, 32, true},
{20000, 500, 64, true},
{20000, 500, 128, true},
{20000, 500, 256, true},

{1000, 10000, 1, true},
{1000, 10000, 2, true},
{1000, 10000, 4, true},
{1000, 10000, 8, true},
{1000, 10000, 16, true},
{1000, 10000, 32, true},
{1000, 10000, 64, true},
{1000, 10000, 128, true},
{1000, 10000, 256, true},

{100, 100000, 1, true},
{100, 100000, 2, true},
{100, 100000, 4, true},
{100, 100000, 8, true},
{100, 100000, 16, true},
{100, 100000, 32, true},
{100, 100000, 64, true},
{100, 100000, 128, true},
{100, 100000, 256, true},

{10, 1000000, 1, true},
{10, 1000000, 2, true},
{10, 1000000, 4, true},
{10, 1000000, 8, true},
{10, 1000000, 16, true},
{10, 1000000, 32, true},
{10, 1000000, 64, true},
{10, 1000000, 128, true},
{10, 1000000, 256, true},

{10, 1000000, 1, true, false, true},
{10, 1000000, 2, true, false, true},
{10, 1000000, 4, true, false, true},
{10, 1000000, 8, true, false, true},
{10, 1000000, 16, true, false, true},
{10, 1000000, 32, true, false, true},
{10, 1000000, 64, true, false, true},
{10, 1000000, 128, true, false, true},
{10, 1000000, 256, true, false, true},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
Expand All @@ -105,24 +160,27 @@ const std::vector<select::params> kInputs{
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

SELECTION_REGISTER(float, int, kPublicApi); // NOLINT
SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT
SELECTION_REGISTER(float, int, kPublicApi); // NOLINT
SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, int, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, size_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT

} // namespace raft::matrix
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/detail/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void select_k(const T* in_val,
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
} else {
select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);
}
}

Expand Down
Loading