Skip to content

Commit

Permalink
Allow nullptr as input-indices argument of select_k (#618)
Browse files Browse the repository at this point in the history
Allow passing `nullptr` as the input-indices-argument of `select_k`. This would imply the input indices are a range from `0` to `input_len - 1` and improve the performance by doing less global memory reads.

Also makes input pointers `const` as they should be.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #618
  • Loading branch information
achirkin authored May 17, 2022
1 parent ed81462 commit 136c77a
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 78 deletions.
46 changes: 21 additions & 25 deletions cpp/include/raft/spatial/knn/detail/selection_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/cudart_utils.h>
#include <raft/pow2_utils.cuh>
#include <raft/spatial/knn/faiss_mr.hpp>

#include <faiss/gpu/GpuDistance.h>
Expand All @@ -38,8 +39,8 @@ constexpr int kFaissMaxK()
}

template <typename key_t, typename payload_t, bool select_min, int warp_q, int thread_q, int tpb>
__global__ void select_k_kernel(key_t* inK,
payload_t* inV,
__global__ void select_k_kernel(const key_t* inK,
const payload_t* inV,
size_t n_rows,
size_t n_cols,
key_t* outK,
Expand All @@ -48,7 +49,8 @@ __global__ void select_k_kernel(key_t* inK,
payload_t initV,
int k)
{
constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize;
using align_warp = Pow2<WarpSize>;
constexpr int kNumWarps = align_warp::div(tpb);

__shared__ key_t smemK[kNumWarps * warp_q];
__shared__ payload_t smemV[kNumWarps * warp_q];
Expand All @@ -59,27 +61,21 @@ __global__ void select_k_kernel(key_t* inK,

// Grid is exactly sized to rows available
int row = blockIdx.x;
int i = threadIdx.x;
{
size_t i = size_t(threadIdx.x);

int idx = row * n_cols;
key_t* inKStart = inK + idx + i;
payload_t* inVStart = inV + idx + i;
inK += row * n_cols;
if (inV != nullptr) { inV += row * n_cols; }

// Whole warps must participate in the selection
int limit = faiss::gpu::utils::roundDown(n_cols, faiss::gpu::kWarpSize);
// Whole warps must participate in the selection
size_t limit = align_warp::roundDown(n_cols);

for (; i < limit; i += tpb) {
inKStart = inK + idx + i;
inVStart = inV + idx + i;
for (; i < limit; i += tpb) {
heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i));
}

heap.add(*inKStart, *inVStart);
}

// Handle last remainder fraction of a warp of elements
if (i < n_cols) {
inKStart = inK + idx + i;
inVStart = inV + idx + i;
heap.addThreadQ(*inKStart, *inVStart);
// Handle last remainder fraction of a warp of elements
if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); }
}

heap.reduce();
Expand All @@ -91,8 +87,8 @@ __global__ void select_k_kernel(key_t* inK,
}

template <typename payload_t = int, typename key_t = float, int warp_q, int thread_q>
inline void select_k_impl(key_t* inK,
payload_t* inV,
inline void select_k_impl(const key_t* inK,
const payload_t* inV,
size_t n_rows,
size_t n_cols,
key_t* outK,
Expand Down Expand Up @@ -133,8 +129,8 @@ inline void select_k_impl(key_t* inK,
* @param[in] stream CUDA stream to use
*/
template <typename payload_t = int, typename key_t = float>
inline void select_k(key_t* inK,
payload_t* inV,
inline void select_k(const key_t* inK,
const payload_t* inV,
size_t n_rows,
size_t n_cols,
key_t* outK,
Expand Down Expand Up @@ -172,4 +168,4 @@ inline void select_k(key_t* inK,
}; // namespace detail
}; // namespace knn
}; // namespace spatial
}; // namespace raft
}; // namespace raft
7 changes: 5 additions & 2 deletions cpp/include/raft/spatial/knn/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ enum class SelectKAlgo {
* @param[in] in_values
* contiguous device array of inputs of size (input_len * n_inputs);
* typically, these are indices of the corresponding in_keys.
* You can pass `NULL` as an argument here; this would imply `in_values` is a homogeneous array
* of indices from `0` to `input_len - 1` for every input and reduce the usage of memory
* bandwidth.
* @param[in] n_inputs
* number of input rows, i.e. the batch size.
* @param[in] input_len
Expand All @@ -118,8 +121,8 @@ enum class SelectKAlgo {
* the implementation of the algorithm
*/
template <typename idx_t = int, typename value_t = float>
inline void select_k(value_t* in_keys,
idx_t* in_values,
inline void select_k(const value_t* in_keys,
const idx_t* in_values,
size_t n_inputs,
size_t input_len,
value_t* out_keys,
Expand Down
167 changes: 116 additions & 51 deletions cpp/test/spatial/selection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct SelectTestSpec {
int input_len;
int k;
int select_min;
bool use_index_input = true;
};

std::ostream& operator<<(std::ostream& os, const SelectTestSpec& ss)
Expand Down Expand Up @@ -129,7 +130,7 @@ struct SelectInOutComputed {
update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream);

raft::spatial::knn::select_k<IdxT, KeyT>(in_dists_d.data(),
in_ids_d.data(),
spec.use_index_input ? in_ids_d.data() : nullptr,
spec.n_inputs,
spec.input_len,
out_dists_d.data(),
Expand Down Expand Up @@ -242,11 +243,6 @@ class SelectionTest : public testing::TestWithParam<typename ParamsReader<KeyT,
}
};

auto selection_algos = testing::Values(knn::SelectKAlgo::FAISS,
knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT);

template <typename KeyT, typename IdxT>
struct params_simple {
using InOut = SelectInOutSimple<KeyT, IdxT>;
Expand All @@ -268,32 +264,54 @@ struct params_simple {

auto inputs_simple_f = testing::Values(
params_simple<float, int>::Inputs(
{5, 5, 5, true},
{5, 5, 5, true, true},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0,
4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0},
{4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}),
params_simple<float, int>::Inputs(
{5, 5, 3, true},
{5, 5, 3, true, true},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0},
{4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}),
params_simple<float, int>::Inputs(
{5, 7, 3, true},
{5, 5, 5, true, false},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0,
4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0},
{4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}),
params_simple<float, int>::Inputs(
{5, 5, 3, true, false},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0},
{4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}),
params_simple<float, int>::Inputs(
{5, 7, 3, true, true},
{5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0,
4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2},
{1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2},
{4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}),
params_simple<float, int>::Inputs(
{1, 7, 3, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}),
{1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}),
params_simple<float, int>::Inputs(
{1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}),
params_simple<float, int>::Inputs(
{1, 7, 3, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}),
{1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}),
params_simple<float, int>::Inputs(
{1, 7, 3, false}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}),
{1, 130, 5, false, true},
{19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20},
{20, 19, 18, 17, 16},
{129, 0, 117, 116, 115}),
params_simple<float, int>::Inputs(
{1, 130, 15, false},
{1, 130, 15, false, true},
{19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
Expand All @@ -306,7 +324,11 @@ typedef SelectionTest<float, int, params_simple> SimpleFloatInt;
TEST_P(SimpleFloatInt, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
SimpleFloatInt,
testing::Combine(inputs_simple_f, selection_algos));
testing::Combine(inputs_simple_f,
testing::Values(knn::SelectKAlgo::FAISS,
knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT)));

template <knn::SelectKAlgo RefAlgo>
struct with_ref {
Expand Down Expand Up @@ -334,50 +356,93 @@ struct with_ref {
};
};

auto inputs_random = testing::Values(SelectTestSpec{1, 130, 15, false},
SelectTestSpec{1, 128, 15, false},
SelectTestSpec{20, 700, 1, true},
SelectTestSpec{20, 700, 2, true},
SelectTestSpec{20, 700, 3, true},
SelectTestSpec{20, 700, 4, true},
SelectTestSpec{20, 700, 5, true},
SelectTestSpec{20, 700, 6, true},
SelectTestSpec{20, 700, 7, true},
SelectTestSpec{20, 700, 8, true},
SelectTestSpec{20, 700, 9, true},
SelectTestSpec{20, 700, 10, true},
SelectTestSpec{20, 700, 11, true},
SelectTestSpec{20, 700, 12, true},
SelectTestSpec{20, 700, 16, true},
SelectTestSpec{100, 1700, 17, true},
SelectTestSpec{100, 1700, 31, true},
SelectTestSpec{100, 1700, 32, false},
SelectTestSpec{100, 1700, 33, false},
SelectTestSpec{100, 1700, 63, false},
SelectTestSpec{100, 1700, 64, false},
SelectTestSpec{100, 1700, 65, false},
SelectTestSpec{100, 1700, 255, true},
SelectTestSpec{100, 1700, 256, true},
SelectTestSpec{100, 1700, 511, false},
SelectTestSpec{100, 1700, 512, true},
SelectTestSpec{100, 1700, 1023, false},
SelectTestSpec{100, 1700, 1024, true},
SelectTestSpec{100, 1700, 1700, true},
SelectTestSpec{10000, 100, 100, false},
SelectTestSpec{10000, 200, 100, false});

typedef SelectionTest<float, int, with_ref<knn::SelectKAlgo::RADIX_8_BITS>::params_random>
auto inputs_random_longlist = testing::Values(SelectTestSpec{1, 130, 15, false},
SelectTestSpec{1, 128, 15, false},
SelectTestSpec{20, 700, 1, true},
SelectTestSpec{20, 700, 2, true},
SelectTestSpec{20, 700, 3, true},
SelectTestSpec{20, 700, 4, true},
SelectTestSpec{20, 700, 5, true},
SelectTestSpec{20, 700, 6, true},
SelectTestSpec{20, 700, 7, true},
SelectTestSpec{20, 700, 8, true},
SelectTestSpec{20, 700, 9, true},
SelectTestSpec{20, 700, 10, true, false},
SelectTestSpec{20, 700, 11, true},
SelectTestSpec{20, 700, 12, true},
SelectTestSpec{20, 700, 16, true},
SelectTestSpec{100, 1700, 17, true},
SelectTestSpec{100, 1700, 31, true, false},
SelectTestSpec{100, 1700, 32, false},
SelectTestSpec{100, 1700, 33, false},
SelectTestSpec{100, 1700, 63, false},
SelectTestSpec{100, 1700, 64, false, false},
SelectTestSpec{100, 1700, 65, false},
SelectTestSpec{100, 1700, 255, true},
SelectTestSpec{100, 1700, 256, true},
SelectTestSpec{100, 1700, 511, false},
SelectTestSpec{100, 1700, 512, true},
SelectTestSpec{100, 1700, 1023, false, false},
SelectTestSpec{100, 1700, 1024, true},
SelectTestSpec{100, 1700, 1700, true});

auto inputs_random_largesize = testing::Values(SelectTestSpec{100, 100000, 1, true},
SelectTestSpec{100, 100000, 2, true},
SelectTestSpec{100, 100000, 3, true, false},
SelectTestSpec{100, 100000, 7, true},
SelectTestSpec{100, 100000, 16, true},
SelectTestSpec{100, 100000, 31, true},
SelectTestSpec{100, 100000, 32, true, false},
SelectTestSpec{100, 100000, 60, true},
SelectTestSpec{100, 100000, 100, true, false},
SelectTestSpec{100, 100000, 200, true},
SelectTestSpec{100000, 100, 100, false},
SelectTestSpec{1, 1000000000, 1, true},
SelectTestSpec{1, 1000000000, 16, false, false},
SelectTestSpec{1, 1000000000, 64, false},
SelectTestSpec{1, 1000000000, 128, true, false},
SelectTestSpec{1, 1000000000, 256, false, false});

auto inputs_random_largek = testing::Values(SelectTestSpec{100, 100000, 1000, true},
SelectTestSpec{100, 100000, 2000, true},
SelectTestSpec{100, 100000, 100000, true, false},
SelectTestSpec{100, 100000, 2048, false},
SelectTestSpec{100, 100000, 1237, true});

typedef SelectionTest<float, int, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomFloatInt;
TEST_P(ReferencedRandomFloatInt, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomFloatInt,
testing::Combine(inputs_random, selection_algos));
testing::Combine(inputs_random_longlist,
testing::Values(knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT)));

typedef SelectionTest<double, size_t, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomDoubleSizeT;
TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomDoubleSizeT,
testing::Combine(inputs_random_longlist,
testing::Values(knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT)));

typedef SelectionTest<double, int, with_ref<knn::SelectKAlgo::RADIX_8_BITS>::params_random>
typedef SelectionTest<double, int, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomDoubleInt;
TEST_P(ReferencedRandomDoubleInt, Run) { run(); }
TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomDoubleInt,
testing::Combine(inputs_random, selection_algos));
testing::Combine(inputs_random_largesize,
testing::Values(knn::SelectKAlgo::WARP_SORT)));

typedef SelectionTest<float, size_t, with_ref<knn::SelectKAlgo::RADIX_8_BITS>::params_random>
ReferencedRandomFloatSizeT;
TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomFloatSizeT,
testing::Combine(inputs_random_largek,
testing::Values(knn::SelectKAlgo::RADIX_11_BITS)));

} // namespace raft::spatial::selection

0 comments on commit 136c77a

Please sign in to comment.