Allow different types in select-k functions (float/double, int/size_t)
achirkin committed Mar 11, 2022
1 parent 9861601 commit 76b1a8b
Showing 6 changed files with 106 additions and 84 deletions.
13 changes: 9 additions & 4 deletions cpp/bench/spatial/selection.cu
@@ -109,9 +109,14 @@ SELECTION_REGISTER(float, int, RADIX_8_BITS);
SELECTION_REGISTER(float, int, RADIX_11_BITS);
SELECTION_REGISTER(float, int, WARP_SORT);

-// SELECTION_REGISTER(double, int, FAISS);
-// SELECTION_REGISTER(double, int, RADIX_8_BITS);
-// SELECTION_REGISTER(double, int, RADIX_11_BITS);
-// SELECTION_REGISTER(double, int, WARP_SORT);
+SELECTION_REGISTER(double, int, FAISS);
+SELECTION_REGISTER(double, int, RADIX_8_BITS);
+SELECTION_REGISTER(double, int, RADIX_11_BITS);
+SELECTION_REGISTER(double, int, WARP_SORT);
+
+SELECTION_REGISTER(double, size_t, FAISS);
+SELECTION_REGISTER(double, size_t, RADIX_8_BITS);
+SELECTION_REGISTER(double, size_t, RADIX_11_BITS);
+SELECTION_REGISTER(double, size_t, WARP_SORT);

} // namespace raft::bench::spatial
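The double/size_t benchmarks are new; a 64-bit payload type matters once indices can exceed the 32-bit range. A quick, purely hypothetical sanity check (not part of the commit) of when int payloads stop being enough:

#include <cstdint>

// Hypothetical illustration: once a dataset holds more than INT32_MAX
// entries, 32-bit payload indices overflow and size_t is required.
static_assert(int64_t{100000} * int64_t{100000} > int64_t{INT32_MAX},
              "1e10 candidate indices exceed the int range");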
18 changes: 17 additions & 1 deletion cpp/include/raft/cudart_utils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -430,4 +430,20 @@ IntType gcd(IntType a, IntType b)
return a;
}

+template <typename T>
+constexpr T lower_bound()
+{
+if constexpr (std::numeric_limits<T>::has_infinity && std::numeric_limits<T>::is_signed) {
+return -std::numeric_limits<T>::infinity();
+}
+return std::numeric_limits<T>::lowest();
+}
+
+template <typename T>
+constexpr T upper_bound()
+{
+if constexpr (std::numeric_limits<T>::has_infinity) { return std::numeric_limits<T>::infinity(); }
+return std::numeric_limits<T>::max();
+}

} // namespace raft
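The helpers move here (from warpsort_topk.cuh, see below) so other headers can share them; they pick ±infinity where it exists and the finite extremes otherwise. A few compile-time spot checks of the expected values (assuming standard <limits> semantics):

#include <limits>

// lower_bound/upper_bound behavior across the types this commit enables:
static_assert(raft::upper_bound<float>() == std::numeric_limits<float>::infinity());
static_assert(raft::lower_bound<double>() == -std::numeric_limits<double>::infinity());
static_assert(raft::upper_bound<int>() == std::numeric_limits<int>::max());
static_assert(raft::lower_bound<size_t>() == 0);  // unsigned: lowest() is 0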
13 changes: 7 additions & 6 deletions cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh
@@ -22,6 +22,7 @@
#include <cub/block/radix_rank_sort_operations.cuh>

#include <raft/cudart_utils.h>
+#include <raft/device_atomics.cuh>

/*
Two implementations:
@@ -176,7 +177,7 @@ __device__ void filter_and_histogram(const T* in_buf,
if (pass == 0) {
auto f = [greater, start_bit, mask](T value, IdxT) {
int bucket = calc_bucket<T, BITS_PER_PASS>(value, start_bit, mask, greater);
-atomicAdd(histogram_smem + bucket, 1);
+atomicAdd(histogram_smem + bucket, IdxT(1));
};
vectorized_process(in_buf, len, f);
} else {
@@ -208,11 +209,11 @@ __device__ void filter_and_histogram(const T* in_buf,
int prev_bucket =
calc_bucket<T, BITS_PER_PASS>(value, previous_start_bit, previous_mask, greater);
if (prev_bucket == want_bucket) {
-IdxT pos = atomicAdd(&filter_cnt, 1);
+IdxT pos = atomicAdd(&filter_cnt, IdxT(1));
out_buf[pos] = value;
if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; }
int bucket = calc_bucket<T, BITS_PER_PASS>(value, start_bit, mask, greater);
-atomicAdd(histogram_smem + bucket, 1);
+atomicAdd(histogram_smem + bucket, IdxT(1));

if (counter_len == 1) {
if (out) {
@@ -223,7 +224,7 @@ __device__ void filter_and_histogram(const T* in_buf,
}
}
} else if (out && prev_bucket < want_bucket) {
-IdxT pos = atomicAdd(&out_cnt, 1);
+IdxT pos = atomicAdd(&out_cnt, IdxT(1));
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
@@ -385,12 +386,12 @@ __global__ void radix_kernel(const T* in_buf,
const T value = out_buf[i];
int bucket = calc_bucket<T, BITS_PER_PASS>(value, start_bit, mask, greater);
if (bucket < want_bucket) {
-IdxT pos = atomicAdd(&out_cnt, 1);
+IdxT pos = atomicAdd(&out_cnt, IdxT(1));
out[pos] = value;
out_idx[pos] = out_idx_buf[i];
} else if (bucket == want_bucket) {
IdxT needed_num_of_kth = counter->k;
-IdxT back_pos = atomicAdd(&(counter->out_back_cnt), 1);
+IdxT back_pos = atomicAdd(&(counter->out_back_cnt), IdxT(1));
if (back_pos < needed_num_of_kth) {
IdxT pos = k - 1 - back_pos;
out[pos] = value;
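The systematic 1 → IdxT(1) change (together with the new device_atomics.cuh include) is what lets these kernels compile when IdxT is not int: with a template-style overload along the lines of `template <typename T> __device__ T atomicAdd(T*, T)`, a bare int literal makes template deduction conflict with the pointer argument. A minimal sketch of the pattern, assuming such an overload set:

// Sketch only: why the cast is needed for a generic index type.
template <typename IdxT>
__device__ IdxT bump(IdxT* counter)
{
  // atomicAdd(counter, 1);            // IdxT deduced as size_t from `counter`
  //                                   // but as int from `1` -> deduction fails
  return atomicAdd(counter, IdxT(1));  // both arguments agree on IdxT
}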
47 changes: 7 additions & 40 deletions cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh
@@ -121,32 +121,6 @@ static constexpr int kMaxCapacity = 512;

namespace {

-template <typename T>
-constexpr T get_lower_bound()
-{
-if (std::numeric_limits<T>::has_infinity && std::numeric_limits<T>::is_signed) {
-return -std::numeric_limits<T>::infinity();
-} else {
-return std::numeric_limits<T>::lowest();
-}
-}
-
-template <typename T>
-constexpr T get_upper_bound()
-{
-if (std::numeric_limits<T>::has_infinity) {
-return std::numeric_limits<T>::infinity();
-} else {
-return std::numeric_limits<T>::max();
-}
-}
-
-template <typename T>
-constexpr T get_dummy(bool greater)
-{
-return greater ? get_lower_bound<T>() : get_upper_bound<T>();
-}

template <bool greater, typename T>
__device__ inline bool is_greater_than(T val, T baseline)
{
@@ -217,7 +191,7 @@ class WarpSort {
template <int capacity, bool greater, typename T, typename IdxT>
class WarpSelect : public WarpSort<capacity, greater, T, IdxT> {
public:
-__device__ WarpSelect(IdxT k, T dummy)
+__device__ WarpSelect(int k, T dummy)
: WarpSort<capacity, greater, T, IdxT>(k, dummy),
buf_len_(0),
k_th_(dummy),
@@ -335,7 +309,7 @@ class WarpSelect : public WarpSort<capacity, greater, T, IdxT> {
template <int capacity, bool greater, typename T, typename IdxT>
class WarpBitonic : public WarpSort<capacity, greater, T, IdxT> {
public:
-__device__ WarpBitonic(IdxT k, T dummy)
+__device__ WarpBitonic(int k, T dummy)
: WarpSort<capacity, greater, T, IdxT>(k, dummy), buf_len_(0)
{
for (int i = 0; i < max_arr_len_; ++i) {
@@ -465,7 +439,7 @@ class WarpBitonic : public WarpSort<capacity, greater, T, IdxT> {
template <int capacity, bool greater, typename T, typename IdxT>
class WarpMerge : public WarpSort<capacity, greater, T, IdxT> {
public:
-__device__ WarpMerge(IdxT k, T dummy) : WarpSort<capacity, greater, T, IdxT>(k, dummy) {}
+__device__ WarpMerge(int k, T dummy) : WarpSort<capacity, greater, T, IdxT>(k, dummy) {}

__device__ void add(const T* in, const IdxT* in_idx, IdxT start, IdxT end)
{
@@ -507,7 +481,7 @@ template <template <int, bool, typename, typename> class WarpSortWarpWide,
typename IdxT>
class WarpSortBlockWide {
public:
-__device__ WarpSortBlockWide(IdxT k, T dummy, void* smem_buf)
+__device__ WarpSortBlockWide(int k, T dummy, void* smem_buf)
: queue_(k, dummy), k_(k), dummy_(dummy)
{
val_smem_ = static_cast<T*>(smem_buf);
@@ -601,14 +575,8 @@ template <template <int, bool, typename, typename> class WarpSortClass,
bool greater,
typename T,
typename IdxT>
-__global__ void block_kernel(const T* in,
-const IdxT* in_idx,
-IdxT batch_size,
-IdxT len,
-IdxT k,
-T* out,
-IdxT* out_idx,
-T dummy)
+__global__ void block_kernel(
+const T* in, const IdxT* in_idx, IdxT batch_size, IdxT len, int k, T* out, IdxT* out_idx, T dummy)
{
extern __shared__ __align__(sizeof(T) * 256) uint8_t smem_buf_bytes[];
T* smem_buf = (T*)smem_buf_bytes;
@@ -702,7 +670,7 @@ struct launch_setup {
}
}
ASSERT(capacity <= Capacity, "Requested k is too big (%d)", k);
-T dummy = get_dummy<T>(greater);
+T dummy = greater ? lower_bound<T>() : upper_bound<T>();
if (greater) {
block_kernel<WarpSortClass, Capacity, true>
<<<batch_size * num_blocks, block_dim, smem_size, stream>>>(
@@ -822,7 +790,6 @@ void warp_sort_topk_(int num_of_block,
rmm::device_uvector<IdxT> tmp_idx(num_of_block * k * batch_size, stream);

// printf("#block=%d, #warp=%d\n", num_of_block, num_of_warp);
-// T dummy = get_dummy<T>(greater);
int capacity = calc_capacity(k);

T* result_val = (num_of_block == 1) ? out : tmp_val.data();
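Two themes in this file: the local get_lower_bound/get_upper_bound/get_dummy helpers are replaced by the shared raft::lower_bound/raft::upper_bound, and k is narrowed from IdxT to int in the queue constructors and the kernel — k is capped by kMaxCapacity (512), so it never needs a 64-bit type even when the indices do. The inlined dummy expression is equivalent to the removed get_dummy; a sketch of the sentinel logic:

// Equivalent to the removed get_dummy<T>(greater), shown for clarity:
// the sentinel pads unfilled queue slots so it can never win a comparison.
template <typename T>
constexpr T sentinel(bool greater)  // hypothetical name, for illustration
{
  return greater ? raft::lower_bound<T>()   // largest-k: pad with the minimum
                 : raft::upper_bound<T>();  // smallest-k: pad with the maximum
}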
16 changes: 11 additions & 5 deletions cpp/include/raft/spatial/knn/detail/selection_faiss.cuh
@@ -31,6 +31,12 @@ namespace spatial {
namespace knn {
namespace detail {

+template <typename key_t, typename payload_t>
+constexpr int kFaissMaxK()
+{
+return (sizeof(key_t) + sizeof(payload_t) > 8) ? 512 : 1024;
+}

template <typename key_t, typename payload_t, bool select_min, int warp_q, int thread_q, int tpb>
__global__ void select_k_kernel(key_t* inK,
payload_t* inV,
@@ -100,8 +106,7 @@ inline void select_k_impl(key_t* inK,
constexpr int n_threads = (warp_q <= 1024) ? 128 : 64;
auto block = dim3(n_threads);

-auto kInit =
-select_min ? faiss::gpu::Limits<key_t>::getMax() : faiss::gpu::Limits<key_t>::getMin();
+auto kInit = select_min ? upper_bound<key_t>() : lower_bound<key_t>();
auto vInit = -1;
if (select_min) {
select_k_kernel<key_t, payload_t, false, warp_q, thread_q, n_threads>
@@ -138,6 +143,7 @@ inline void select_k(key_t* inK,
int k,
cudaStream_t stream)
{
+constexpr int max_k = kFaissMaxK<payload_t, key_t>();
if (k == 1)
select_k_impl<payload_t, key_t, 1, 1>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
@@ -156,11 +162,11 @@ inline void select_k(key_t* inK,
else if (k <= 512)
select_k_impl<payload_t, key_t, 512, 8>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
-else if (k <= 1024)
-select_k_impl<payload_t, key_t, 1024, 8>(
+else if (k <= 1024 && k <= max_k)
+select_k_impl<payload_t, key_t, max_k, 8>(
inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream);
else
-ASSERT(k <= 1024, "Current max k is 1024 (requested %d)", k);
+ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k);
}

}; // namespace detail
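kFaissMaxK makes the FAISS capacity limit explicit: key/payload pairs wider than 8 bytes halve the maximum selectable k, presumably a register/shared-memory constraint of the underlying block-select kernels. The dispatch chain and the final ASSERT now respect it, and kInit switches from faiss::gpu::Limits to the shared bound helpers. Spot checks that follow directly from the definition (assuming 4-byte int/float and 8-byte double/size_t):

// sizeof(key_t) + sizeof(payload_t) <= 8 bytes -> k up to 1024; wider -> 512.
using raft::spatial::knn::detail::kFaissMaxK;
static_assert(kFaissMaxK<float, int>() == 1024);     // 4 + 4 bytes
static_assert(kFaissMaxK<double, int>() == 512);     // 8 + 4 bytes
static_assert(kFaissMaxK<double, size_t>() == 512);  // 8 + 8 bytes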
83 changes: 55 additions & 28 deletions cpp/test/spatial/selection.cu
@@ -95,6 +95,8 @@ struct SelectInOutSimple {
template <typename KeyT, typename IdxT>
struct SelectInOutComputed {
public:
+bool not_supported = false;

SelectInOutComputed(const SelectTestSpec& spec,
knn::SelectKAlgo algo,
const std::vector<KeyT>& in_dists,
@@ -104,6 +106,23 @@
out_dists_(spec.n_inputs * spec.k),
out_ids_(spec.n_inputs * spec.k)
{
+// check if the size is supported by the algorithm
+switch (algo) {
+case knn::SelectKAlgo::WARP_SORT:
+if (spec.k > raft::spatial::knn::detail::ivf_flat::kMaxCapacity) {
+not_supported = true;
+return;
+}
+break;
+case knn::SelectKAlgo::FAISS:
+if (spec.k > raft::spatial::knn::detail::kFaissMaxK<IdxT, KeyT>()) {
+not_supported = true;
+return;
+}
+break;
+default: break;
+}

auto stream = rmm::cuda_stream_default;

rmm::device_uvector<KeyT> in_dists_d(in_dists_.size(), stream);
@@ -234,6 +253,7 @@ class SelectionTest : public testing::TestWithParam<typename ParamsReader<KeyT,

void run()
{
+if (res.not_supported) { GTEST_SKIP(); }
ASSERT_TRUE(hostArrMatch(ref.get_out_dists().data(),
res.get_out_dists().data(),
spec.n_inputs * spec.k,
@@ -325,39 +345,46 @@ struct with_ref {
};
};

-auto inputs_random_f = testing::Values(SelectTestSpec{20, 700, 1, true},
-SelectTestSpec{20, 700, 2, true},
-SelectTestSpec{20, 700, 3, true},
-SelectTestSpec{20, 700, 4, true},
-SelectTestSpec{20, 700, 5, true},
-SelectTestSpec{20, 700, 6, true},
-SelectTestSpec{20, 700, 7, true},
-SelectTestSpec{20, 700, 8, true},
-SelectTestSpec{20, 700, 9, true},
-SelectTestSpec{20, 700, 10, true},
-SelectTestSpec{20, 700, 11, true},
-SelectTestSpec{20, 700, 12, true},
-SelectTestSpec{20, 700, 16, true},
-SelectTestSpec{100, 1700, 17, true},
-SelectTestSpec{100, 1700, 31, true},
-SelectTestSpec{100, 1700, 32, false},
-SelectTestSpec{100, 1700, 33, false},
-SelectTestSpec{100, 1700, 63, false},
-SelectTestSpec{100, 1700, 64, false},
-SelectTestSpec{100, 1700, 65, false},
-SelectTestSpec{100, 1700, 255, true},
-SelectTestSpec{100, 1700, 256, true},
-SelectTestSpec{100, 1700, 511, false},
-SelectTestSpec{100, 1700, 512, true},
-SelectTestSpec{100, 1700, 1023, false},
-SelectTestSpec{100, 1700, 1024, true},
-SelectTestSpec{100, 1700, 1700, true});
+auto inputs_random = testing::Values(SelectTestSpec{20, 700, 1, true},
+SelectTestSpec{20, 700, 2, true},
+SelectTestSpec{20, 700, 3, true},
+SelectTestSpec{20, 700, 4, true},
+SelectTestSpec{20, 700, 5, true},
+SelectTestSpec{20, 700, 6, true},
+SelectTestSpec{20, 700, 7, true},
+SelectTestSpec{20, 700, 8, true},
+SelectTestSpec{20, 700, 9, true},
+SelectTestSpec{20, 700, 10, true},
+SelectTestSpec{20, 700, 11, true},
+SelectTestSpec{20, 700, 12, true},
+SelectTestSpec{20, 700, 16, true},
+SelectTestSpec{100, 1700, 17, true},
+SelectTestSpec{100, 1700, 31, true},
+SelectTestSpec{100, 1700, 32, false},
+SelectTestSpec{100, 1700, 33, false},
+SelectTestSpec{100, 1700, 63, false},
+SelectTestSpec{100, 1700, 64, false},
+SelectTestSpec{100, 1700, 65, false},
+SelectTestSpec{100, 1700, 255, true},
+SelectTestSpec{100, 1700, 256, true},
+SelectTestSpec{100, 1700, 511, false},
+SelectTestSpec{100, 1700, 512, true},
+SelectTestSpec{100, 1700, 1023, false},
+SelectTestSpec{100, 1700, 1024, true},
+SelectTestSpec{100, 1700, 1700, true});

typedef SelectionTest<float, int, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomFloatInt;
TEST_P(ReferencedRandomFloatInt, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomFloatInt,
-testing::Combine(inputs_random_f, selection_algos));
+testing::Combine(inputs_random, selection_algos));

+typedef SelectionTest<double, int, with_ref<knn::SelectKAlgo::FAISS>::params_random>
+ReferencedRandomDoubleInt;
+TEST_P(ReferencedRandomDoubleInt, Run) { run(); }
+INSTANTIATE_TEST_CASE_P(SelectionTest,
+ReferencedRandomDoubleInt,
+testing::Combine(inputs_random, selection_algos));

} // namespace raft::spatial::selection
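The not_supported guard ties the shared test inputs to each algorithm's capacity, so the new double/int instantiation can reuse inputs_random up to k = 1700 and simply skip the combinations an algorithm cannot handle. A hypothetical helper (not in the commit) summarizing the limits the guard encodes:

#include <limits>

// Hypothetical summary of the per-algorithm k limits checked above.
template <typename KeyT, typename IdxT>
constexpr int max_supported_k(raft::spatial::knn::SelectKAlgo algo)
{
  switch (algo) {
    case raft::spatial::knn::SelectKAlgo::WARP_SORT:
      return raft::spatial::knn::detail::ivf_flat::kMaxCapacity;    // 512
    case raft::spatial::knn::SelectKAlgo::FAISS:
      return raft::spatial::knn::detail::kFaissMaxK<IdxT, KeyT>();  // 512 or 1024
    default:  // the radix variants are not capped by this guard
      return std::numeric_limits<int>::max();
  }
}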
