Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KNN select-top-k variants #551

Merged
merged 43 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
179e1df
Integrate new select-top-k implementations
achirkin Mar 9, 2022
4749295
warpsort_topk: refactoring and fixing some bugs
achirkin Mar 9, 2022
8504d32
Allow passing indices along with keys (values).
achirkin Mar 9, 2022
cef3253
Adapt to the new bench
achirkin Mar 10, 2022
535fa0d
Use the pooled allocator helper
achirkin Mar 10, 2022
7d10507
Remove the step of calculating required buf size.
achirkin Mar 11, 2022
ba66efa
Remove unused code
achirkin Mar 11, 2022
3eab24b
Allow different types in select-k functions (float/double, int/size_t)
achirkin Mar 11, 2022
659bc18
More refactoring and comments
achirkin Mar 17, 2022
8b6351b
Update knn.cuh docs
achirkin Mar 17, 2022
a43e462
Add more comments
achirkin Mar 18, 2022
45f6a35
Use radix top-k as reference, because it supports larger k
achirkin Mar 18, 2022
0fe93d2
Add more comments and refactor vectorized_process
achirkin Mar 18, 2022
50800a4
Make bitonic sort use less template parameters for faster compile times
achirkin Mar 21, 2022
78805f0
Use gridDim.y for the batch dimension to simplify math and use less r…
achirkin Mar 21, 2022
9cf1f33
Update tests
achirkin Mar 21, 2022
90293dc
Allow larger batch sizes for radix_topk
achirkin Mar 21, 2022
b38c80e
Merge branch 'branch-22.04' into enh-knn-topk-variants
achirkin Mar 21, 2022
48ac5c7
Update docs
achirkin Mar 22, 2022
fa76a4d
More cosmetic refactoring
achirkin Mar 23, 2022
3285de5
Even more cosmetic refactoring
achirkin Mar 23, 2022
faecc32
Flip the ascending/descending flag for radix_topk
achirkin Mar 23, 2022
db24b10
Even more cosmetic refactoring
achirkin Mar 23, 2022
a30a2fc
Fix a typo
achirkin Mar 23, 2022
c722d9f
Rename one of the 'add' overloads to reflect it should be used only once
achirkin Mar 23, 2022
fe95ded
Refactor names and document radix_topk
achirkin Mar 24, 2022
00a62a4
Choose the batch size dynamically
achirkin Mar 24, 2022
52f863e
Rename the detail::topk folder
achirkin Mar 24, 2022
2a78c1f
Add the high-level algorithm description
achirkin Mar 24, 2022
d811f75
Rename the warpsort classes
achirkin Mar 24, 2022
fcab684
Fix a typo
achirkin Mar 24, 2022
dcb17fe
Merge remote-tracking branch 'rapidsai/branch-22.04' into enh-knn-top…
achirkin Mar 25, 2022
84de3f2
Clarify some parts of documentation for bitonic sort
achirkin Mar 25, 2022
25ff099
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
99f6feb
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
deb7e44
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
073d0f5
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
a811740
Address review comments
achirkin Mar 28, 2022
ff2d6e6
Slightly reduce the number of tests for faster CI
achirkin Mar 28, 2022
e2f7d86
Couple more comments
achirkin Mar 28, 2022
1936abd
Address more comments
achirkin Mar 28, 2022
6b3804c
Remove commented-out bench cases
achirkin Mar 29, 2022
bea83b3
Change some bench cases
achirkin Mar 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ set(RAFT_CPP_BENCH_TARGET "bench_raft")
# (please keep the filenames in alphabetical order)
add_executable(${RAFT_CPP_BENCH_TARGET}
bench/linalg/reduce.cu
bench/spatial/selection.cu
bench/main.cpp
)

Expand Down
150 changes: 150 additions & 0 deletions cpp/bench/spatial/selection.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/spatial/knn/knn.cuh>

#include <raft/random/rng.hpp>
#include <raft/sparse/detail/utils.h>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

namespace raft::bench::spatial {

/**
 * Parameters of a single selection benchmark case.
 *
 * The struct is an aggregate: entries of `kInputs` initialize it positionally,
 * so the order of the fields is part of its interface.
 */
struct params {
// Number of inputs; the fixture allocates n_inputs * input_len input elements
// and n_inputs * k output elements (presumably the batch size — TODO confirm).
int n_inputs;
// Length of each input; passed to select_k right after n_inputs.
int input_len;
// Number of elements to select per input.
int k;
// Forwarded to select_k unchanged; presumably non-zero means "select the
// smallest keys" — TODO confirm against the select_k documentation.
int select_min;
};

/**
 * Benchmark fixture that times raft::spatial::knn::select_k for one combination
 * of key type, index type and selection algorithm.
 *
 * @tparam KeyT element type of the key buffers (in_dists_ / out_dists_)
 * @tparam IdxT element type of the index buffers (in_ids_ / out_ids_)
 * @tparam Algo SelectKAlgo value forwarded to select_k
 */
template <typename KeyT, typename IdxT, raft::spatial::knn::SelectKAlgo Algo>
achirkin marked this conversation as resolved.
Show resolved Hide resolved
struct selection : public fixture {
  /**
   * Allocate the input and output buffers on the fixture stream and fill the inputs.
   *
   * The initializers are listed in member declaration order, because that is the
   * order in which they are executed regardless of how the list is written.
   *
   * @param p benchmark case; a copy is stored in params_
   */
  explicit selection(const params& p)
    : params_(p),
      in_dists_(n_elems(p.n_inputs, p.input_len), stream),
      out_dists_(n_elems(p.n_inputs, p.k), stream),
      in_ids_(n_elems(p.n_inputs, p.input_len), stream),
      out_ids_(n_elems(p.n_inputs, p.k), stream)
  {
    raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not something we need to do in this PR but it would be nice to move this utility out of sparse if it's going to get used in other places.

    // The generator is always constructed with the same constant (42), presumably
    // the seed, so that every run benchmarks the same keys — TODO confirm.
    raft::random::Rng(42).uniform(
      in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream);
  }

  /**
   * Run select_k inside the benchmark loop for the stored case.
   *
   * The benchmark label is set to "<n_inputs>#<input_len>#<k>". If a
   * raft::exception is thrown, the case is reported through
   * state.SkipWithError() with the exception message instead of propagating.
   *
   * @param state benchmark state handed over by the benchmark framework
   */
  void run_benchmark(::benchmark::State& state) override
  {
    // NOTE(review): looks like a scoped switch to a pooled memory resource for
    // the duration of this call — confirm against common/benchmark.hpp.
    using_pool_memory_res res;
    try {
      std::ostringstream label_stream;
      label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k;
      state.SetLabel(label_stream.str());
      loop_on_state(state, [this]() {
        raft::spatial::knn::select_k<IdxT, KeyT>(in_dists_.data(),
                                                 in_ids_.data(),
                                                 params_.n_inputs,
                                                 params_.input_len,
                                                 out_dists_.data(),
                                                 out_ids_.data(),
                                                 params_.select_min,
                                                 params_.k,
                                                 stream,
                                                 Algo);
      });
    } catch (const raft::exception& e) {
      // The exception is only read, so it is caught by const reference.
      state.SkipWithError(e.what());
    }
  }

 private:
  /**
   * Element count of a rows x cols buffer, multiplied in size_t so that the
   * product cannot overflow `int` before it reaches the allocation.
   */
  static size_t n_elems(int rows, int cols)
  {
    return static_cast<size_t>(rows) * static_cast<size_t>(cols);
  }

  const params params_;
  rmm::device_uvector<KeyT> in_dists_, out_dists_;
  rmm::device_uvector<IdxT> in_ids_, out_ids_;
};

// Benchmark cases shared by every registered selection benchmark.
// Each entry is {n_inputs, input_len, k, select_min}, i.e. the fields of
// `params` in declaration order. The commented-out lines are disabled cases
// with k = 512, 1024 and 2048.
const std::vector<params> kInputs{
{10000, 10, 3, true},
{10000, 10, 10, true},
{10, 40, 15, true},
{10, 80, 15, true},
{10, 80, 1, true},
{10, 80, 7, true},
{10, 80, 8, true},
{10, 700, 3, true},
{10, 700, 32, true},
{10, 2000, 64, true},
{10, 10000, 7, true},
{10, 10000, 19, true},
{10, 10000, 127, true},

{1000, 10000, 1, true},
{1000, 10000, 2, true},
{1000, 10000, 4, true},
{1000, 10000, 8, true},
{1000, 10000, 16, true},
{1000, 10000, 32, true},
{1000, 10000, 64, true},
{1000, 10000, 128, true},
{1000, 10000, 256, true},
// {1000, 10000, 512, true}, {1000, 10000, 1024, true}, {1000, 10000, 2048, true},
achirkin marked this conversation as resolved.
Show resolved Hide resolved

{100, 100000, 1, true},
{100, 100000, 2, true},
{100, 100000, 4, true},
{100, 100000, 8, true},
{100, 100000, 16, true},
{100, 100000, 32, true},
{100, 100000, 64, true},
{100, 100000, 128, true},
{100, 100000, 256, true},
// {100, 100000, 512, true}, {100, 100000, 1024, true}, {100, 100000, 2048, true},

{10, 1000000, 1, true},
{10, 1000000, 2, true},
{10, 1000000, 4, true},
{10, 1000000, 8, true},
{10, 1000000, 16, true},
{10, 1000000, 32, true},
{10, 1000000, 64, true},
{10, 1000000, 128, true},
{10, 1000000, 256, true},
// {10, 1000000, 512, true}, {10, 1000000, 1024, true}, {10, 1000000, 2048, true},
};

/**
 * Register the `selection` benchmark for one (KeyT, IdxT, Algo) combination.
 *
 * `Algo` is the bare enumerator name; the macro prefixes it with
 * raft::spatial::knn::SelectKAlgo::. The benchmark is registered over `kInputs`
 * under the name "<KeyT>/<IdxT>/<Algo>", built by stringifying the arguments.
 * The `SelectK` alias is declared inside a namespace named by
 * BENCHMARK_PRIVATE_NAME(selection), presumably so that repeated expansions do
 * not redefine the same alias — TODO confirm that the macro yields unique names.
 */
#define SELECTION_REGISTER(KeyT, IdxT, Algo) \
namespace BENCHMARK_PRIVATE_NAME(selection) \
{ \
using SelectK = selection<KeyT, IdxT, raft::spatial::knn::SelectKAlgo::Algo>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \
}

// Every key/index type combination below is registered with the same four
// algorithms: FAISS, RADIX_8_BITS, RADIX_11_BITS and WARP_SORT.

// float keys, int indices
SELECTION_REGISTER(float, int, FAISS);
SELECTION_REGISTER(float, int, RADIX_8_BITS);
SELECTION_REGISTER(float, int, RADIX_11_BITS);
SELECTION_REGISTER(float, int, WARP_SORT);

// double keys, int indices
SELECTION_REGISTER(double, int, FAISS);
SELECTION_REGISTER(double, int, RADIX_8_BITS);
SELECTION_REGISTER(double, int, RADIX_11_BITS);
SELECTION_REGISTER(double, int, WARP_SORT);

// double keys, size_t indices
SELECTION_REGISTER(double, size_t, FAISS);
SELECTION_REGISTER(double, size_t, RADIX_8_BITS);
SELECTION_REGISTER(double, size_t, RADIX_11_BITS);
SELECTION_REGISTER(double, size_t, WARP_SORT);

} // namespace raft::bench::spatial
16 changes: 16 additions & 0 deletions cpp/include/raft/cudart_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,22 @@ IntType gcd(IntType a, IntType b)
return a;
}

/**
 * Smallest value of `T`, for use as an initial value that compares below
 * (or equal to) every other value of the type.
 *
 * @tparam T type described by std::numeric_limits
 * @return negative infinity when `T` is signed and has an infinity,
 *         std::numeric_limits<T>::lowest() otherwise
 */
template <typename T>
constexpr T lower_bound()
{
  using limits = std::numeric_limits<T>;
  if constexpr (limits::has_infinity && limits::is_signed) {
    return -limits::infinity();
  } else {
    return limits::lowest();
  }
}

/**
 * Largest value of `T`, for use as an initial value that compares above
 * (or equal to) every other value of the type.
 *
 * @tparam T type described by std::numeric_limits
 * @return positive infinity when `T` has an infinity,
 *         std::numeric_limits<T>::max() otherwise
 */
template <typename T>
constexpr T upper_bound()
{
  using limits = std::numeric_limits<T>;
  if constexpr (limits::has_infinity) {
    return limits::infinity();
  } else {
    return limits::max();
  }
}

} // namespace raft

#endif
78 changes: 43 additions & 35 deletions cpp/include/raft/spatial/knn/detail/selection_faiss.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,33 +31,39 @@ namespace spatial {
namespace knn {
namespace detail {

template <typename K, typename IndexType, bool select_min, int warp_q, int thread_q, int tpb>
__global__ void select_k_kernel(K* inK,
IndexType* inV,
/**
 * Largest `k` accepted for a given key/payload type pair.
 *
 * The limit depends only on the combined size of one key and one payload:
 * 1024 when the pair occupies at most 8 bytes, 512 when it is larger.
 * NOTE(review): presumably this reflects the per-block storage the selection
 * kernel needs per element — confirm against the kernel's resource limits.
 *
 * @tparam key_t     type of the keys
 * @tparam payload_t type of the values carried along with the keys
 */
template <typename key_t, typename payload_t>
constexpr int kFaissMaxK()
{
  constexpr auto pair_bytes = sizeof(key_t) + sizeof(payload_t);
  if (pair_bytes > 8) { return 512; }
  return 1024;
}

template <typename key_t, typename payload_t, bool select_min, int warp_q, int thread_q, int tpb>
__global__ void select_k_kernel(key_t* inK,
payload_t* inV,
size_t n_rows,
size_t n_cols,
K* outK,
IndexType* outV,
K initK,
IndexType initV,
key_t* outK,
payload_t* outV,
key_t initK,
payload_t initV,
int k)
{
constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize;

__shared__ K smemK[kNumWarps * warp_q];
__shared__ IndexType smemV[kNumWarps * warp_q];
__shared__ key_t smemK[kNumWarps * warp_q];
__shared__ payload_t smemV[kNumWarps * warp_q];

faiss::gpu::
BlockSelect<K, IndexType, select_min, faiss::gpu::Comparator<K>, warp_q, thread_q, tpb>
BlockSelect<key_t, payload_t, select_min, faiss::gpu::Comparator<key_t>, warp_q, thread_q, tpb>
heap(initK, initV, smemK, smemV, k);

// Grid is exactly sized to rows available
int row = blockIdx.x;
int i = threadIdx.x;

int idx = row * n_cols;
K* inKStart = inK + idx + i;
IndexType* inVStart = inV + idx + i;
key_t* inKStart = inK + idx + i;
payload_t* inVStart = inV + idx + i;

// Whole warps must participate in the selection
int limit = faiss::gpu::utils::roundDown(n_cols, faiss::gpu::kWarpSize);
Expand All @@ -84,13 +90,13 @@ __global__ void select_k_kernel(K* inK,
}
}

template <typename value_idx = int, typename value_t = float, int warp_q, int thread_q>
inline void select_k_impl(value_t* inK,
value_idx* inV,
template <typename payload_t = int, typename key_t = float, int warp_q, int thread_q>
inline void select_k_impl(key_t* inK,
payload_t* inV,
size_t n_rows,
size_t n_cols,
value_t* outK,
value_idx* outV,
key_t* outK,
payload_t* outV,
bool select_min,
int k,
cudaStream_t stream)
Expand All @@ -100,14 +106,13 @@ inline void select_k_impl(value_t* inK,
constexpr int n_threads = (warp_q <= 1024) ? 128 : 64;
auto block = dim3(n_threads);

auto kInit =
select_min ? faiss::gpu::Limits<value_t>::getMax() : faiss::gpu::Limits<value_t>::getMin();
auto kInit = select_min ? upper_bound<key_t>() : lower_bound<key_t>();
auto vInit = -1;
if (select_min) {
select_k_kernel<value_t, value_idx, false, warp_q, thread_q, n_threads>
select_k_kernel<key_t, payload_t, false, warp_q, thread_q, n_threads>
<<<grid, block, 0, stream>>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k);
} else {
select_k_kernel<value_t, value_idx, true, warp_q, thread_q, n_threads>
select_k_kernel<key_t, payload_t, true, warp_q, thread_q, n_threads>
<<<grid, block, 0, stream>>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k);
}
RAFT_CUDA_TRY(cudaGetLastError());
Expand All @@ -127,38 +132,41 @@ inline void select_k_impl(value_t* inK,
* @param[in] k number of neighbors per partition (also number of merged neighbors)
* @param[in] stream CUDA stream to use
*/
template <typename value_idx = int, typename value_t = float>
inline void select_k(value_t* inK,
value_idx* inV,
template <typename payload_t = int, typename key_t = float>
inline void select_k(key_t* inK,
payload_t* inV,
size_t n_rows,
size_t n_cols,
value_t* outK,
value_idx* outV,
key_t* outK,
payload_t* outV,
bool select_min,
int k,
cudaStream_t stream)
{
constexpr int max_k = kFaissMaxK<payload_t, key_t>();
if (k == 1)
select_k_impl<value_idx, value_t, 1, 1>(
select_k_impl<payload_t, key_t, 1, 1>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else if (k <= 32)
select_k_impl<value_idx, value_t, 32, 2>(
select_k_impl<payload_t, key_t, 32, 2>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else if (k <= 64)
select_k_impl<value_idx, value_t, 64, 3>(
select_k_impl<payload_t, key_t, 64, 3>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else if (k <= 128)
select_k_impl<value_idx, value_t, 128, 3>(
select_k_impl<payload_t, key_t, 128, 3>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else if (k <= 256)
select_k_impl<value_idx, value_t, 256, 4>(
select_k_impl<payload_t, key_t, 256, 4>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else if (k <= 512)
select_k_impl<value_idx, value_t, 512, 8>(
select_k_impl<payload_t, key_t, 512, 8>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else if (k <= 1024)
select_k_impl<value_idx, value_t, 1024, 8>(
else if (k <= 1024 && k <= max_k)
select_k_impl<payload_t, key_t, max_k, 8>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else
ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k);
achirkin marked this conversation as resolved.
Show resolved Hide resolved
}

}; // namespace detail
Expand Down
Loading