Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow nullptr as input-indices argument of select_k #618

Merged
merged 3 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions cpp/include/raft/spatial/knn/detail/selection_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/cudart_utils.h>
#include <raft/pow2_utils.cuh>
#include <raft/spatial/knn/faiss_mr.hpp>

#include <faiss/gpu/GpuDistance.h>
Expand All @@ -38,8 +39,8 @@ constexpr int kFaissMaxK()
}

template <typename key_t, typename payload_t, bool select_min, int warp_q, int thread_q, int tpb>
__global__ void select_k_kernel(key_t* inK,
payload_t* inV,
__global__ void select_k_kernel(const key_t* inK,
const payload_t* inV,
size_t n_rows,
size_t n_cols,
key_t* outK,
Expand All @@ -48,7 +49,8 @@ __global__ void select_k_kernel(key_t* inK,
payload_t initV,
int k)
{
constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize;
using align_warp = Pow2<WarpSize>;
constexpr int kNumWarps = align_warp::div(tpb);

__shared__ key_t smemK[kNumWarps * warp_q];
__shared__ payload_t smemV[kNumWarps * warp_q];
Expand All @@ -59,27 +61,21 @@ __global__ void select_k_kernel(key_t* inK,

// Grid is exactly sized to rows available
int row = blockIdx.x;
int i = threadIdx.x;
{
size_t i = size_t(threadIdx.x);

int idx = row * n_cols;
key_t* inKStart = inK + idx + i;
payload_t* inVStart = inV + idx + i;
inK += row * n_cols;
if (inV != nullptr) { inV += row * n_cols; }

// Whole warps must participate in the selection
int limit = faiss::gpu::utils::roundDown(n_cols, faiss::gpu::kWarpSize);
// Whole warps must participate in the selection
size_t limit = align_warp::roundDown(n_cols);

for (; i < limit; i += tpb) {
inKStart = inK + idx + i;
inVStart = inV + idx + i;
for (; i < limit; i += tpb) {
heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i));
}

heap.add(*inKStart, *inVStart);
}

// Handle last remainder fraction of a warp of elements
if (i < n_cols) {
inKStart = inK + idx + i;
inVStart = inV + idx + i;
heap.addThreadQ(*inKStart, *inVStart);
// Handle last remainder fraction of a warp of elements
if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); }
}

heap.reduce();
Expand All @@ -91,8 +87,8 @@ __global__ void select_k_kernel(key_t* inK,
}

template <typename payload_t = int, typename key_t = float, int warp_q, int thread_q>
inline void select_k_impl(key_t* inK,
payload_t* inV,
inline void select_k_impl(const key_t* inK,
const payload_t* inV,
size_t n_rows,
size_t n_cols,
key_t* outK,
Expand Down Expand Up @@ -133,8 +129,8 @@ inline void select_k_impl(key_t* inK,
* @param[in] stream CUDA stream to use
*/
template <typename payload_t = int, typename key_t = float>
inline void select_k(key_t* inK,
payload_t* inV,
inline void select_k(const key_t* inK,
const payload_t* inV,
size_t n_rows,
size_t n_cols,
key_t* outK,
Expand Down Expand Up @@ -172,4 +168,4 @@ inline void select_k(key_t* inK,
}; // namespace detail
}; // namespace knn
}; // namespace spatial
}; // namespace raft
}; // namespace raft
7 changes: 5 additions & 2 deletions cpp/include/raft/spatial/knn/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ enum class SelectKAlgo {
* @param[in] in_values
* contiguous device array of inputs of size (input_len * n_inputs);
* typically, these are indices of the corresponding in_keys.
* You may pass `nullptr` here; `in_values` is then treated as the implicit sequence of
* indices `0` to `input_len - 1` for every input row, which also reduces memory
* bandwidth usage.
* @param[in] n_inputs
* number of input rows, i.e. the batch size.
* @param[in] input_len
Expand All @@ -118,8 +121,8 @@ enum class SelectKAlgo {
* the implementation of the algorithm
*/
template <typename idx_t = int, typename value_t = float>
inline void select_k(value_t* in_keys,
idx_t* in_values,
inline void select_k(const value_t* in_keys,
const idx_t* in_values,
size_t n_inputs,
size_t input_len,
value_t* out_keys,
Expand Down
167 changes: 116 additions & 51 deletions cpp/test/spatial/selection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct SelectTestSpec {
int input_len;
int k;
int select_min;
bool use_index_input = true;
};

std::ostream& operator<<(std::ostream& os, const SelectTestSpec& ss)
Expand Down Expand Up @@ -129,7 +130,7 @@ struct SelectInOutComputed {
update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream);

raft::spatial::knn::select_k<IdxT, KeyT>(in_dists_d.data(),
in_ids_d.data(),
spec.use_index_input ? in_ids_d.data() : nullptr,
spec.n_inputs,
spec.input_len,
out_dists_d.data(),
Expand Down Expand Up @@ -242,11 +243,6 @@ class SelectionTest : public testing::TestWithParam<typename ParamsReader<KeyT,
}
};

auto selection_algos = testing::Values(knn::SelectKAlgo::FAISS,
knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT);

template <typename KeyT, typename IdxT>
struct params_simple {
using InOut = SelectInOutSimple<KeyT, IdxT>;
Expand All @@ -268,32 +264,54 @@ struct params_simple {

auto inputs_simple_f = testing::Values(
params_simple<float, int>::Inputs(
{5, 5, 5, true},
{5, 5, 5, true, true},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0,
4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0},
{4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}),
params_simple<float, int>::Inputs(
{5, 5, 3, true},
{5, 5, 3, true, true},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0},
{4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}),
params_simple<float, int>::Inputs(
{5, 7, 3, true},
{5, 5, 5, true, false},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0,
4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0},
{4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}),
params_simple<float, int>::Inputs(
{5, 5, 3, true, false},
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0,
1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0},
{1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0},
{4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}),
params_simple<float, int>::Inputs(
{5, 7, 3, true, true},
{5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0,
4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2},
{1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2},
{4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}),
params_simple<float, int>::Inputs(
{1, 7, 3, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}),
{1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}),
params_simple<float, int>::Inputs(
{1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}),
params_simple<float, int>::Inputs(
{1, 7, 3, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}),
{1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}),
params_simple<float, int>::Inputs(
{1, 7, 3, false}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}),
{1, 130, 5, false, true},
{19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20},
{20, 19, 18, 17, 16},
{129, 0, 117, 116, 115}),
params_simple<float, int>::Inputs(
{1, 130, 15, false},
{1, 130, 15, false, true},
{19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
Expand All @@ -306,7 +324,11 @@ typedef SelectionTest<float, int, params_simple> SimpleFloatInt;
TEST_P(SimpleFloatInt, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
SimpleFloatInt,
testing::Combine(inputs_simple_f, selection_algos));
testing::Combine(inputs_simple_f,
testing::Values(knn::SelectKAlgo::FAISS,
knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT)));

template <knn::SelectKAlgo RefAlgo>
struct with_ref {
Expand All @@ -333,50 +355,93 @@ struct with_ref {
};
};

auto inputs_random = testing::Values(SelectTestSpec{1, 130, 15, false},
SelectTestSpec{1, 128, 15, false},
SelectTestSpec{20, 700, 1, true},
SelectTestSpec{20, 700, 2, true},
SelectTestSpec{20, 700, 3, true},
SelectTestSpec{20, 700, 4, true},
SelectTestSpec{20, 700, 5, true},
SelectTestSpec{20, 700, 6, true},
SelectTestSpec{20, 700, 7, true},
SelectTestSpec{20, 700, 8, true},
SelectTestSpec{20, 700, 9, true},
SelectTestSpec{20, 700, 10, true},
SelectTestSpec{20, 700, 11, true},
SelectTestSpec{20, 700, 12, true},
SelectTestSpec{20, 700, 16, true},
SelectTestSpec{100, 1700, 17, true},
SelectTestSpec{100, 1700, 31, true},
SelectTestSpec{100, 1700, 32, false},
SelectTestSpec{100, 1700, 33, false},
SelectTestSpec{100, 1700, 63, false},
SelectTestSpec{100, 1700, 64, false},
SelectTestSpec{100, 1700, 65, false},
SelectTestSpec{100, 1700, 255, true},
SelectTestSpec{100, 1700, 256, true},
SelectTestSpec{100, 1700, 511, false},
SelectTestSpec{100, 1700, 512, true},
SelectTestSpec{100, 1700, 1023, false},
SelectTestSpec{100, 1700, 1024, true},
SelectTestSpec{100, 1700, 1700, true},
SelectTestSpec{10000, 100, 100, false},
SelectTestSpec{10000, 200, 100, false});

typedef SelectionTest<float, int, with_ref<knn::SelectKAlgo::RADIX_8_BITS>::params_random>
// Test specs for the randomized selection tests at small and medium problem sizes.
// Each entry reads as {n_inputs, input_len, k, select_min[, use_index_input]}; the name of the
// first field is inferred from how the specs are used (the struct header is outside this view).
// When the fifth field is omitted it takes its default, `true`. Entries that set it to `false`
// make the test pass nullptr instead of the input-indices buffer to select_k.
auto inputs_random_longlist = testing::Values(SelectTestSpec{1, 130, 15, false},
SelectTestSpec{1, 128, 15, false},
SelectTestSpec{20, 700, 1, true},
SelectTestSpec{20, 700, 2, true},
SelectTestSpec{20, 700, 3, true},
SelectTestSpec{20, 700, 4, true},
SelectTestSpec{20, 700, 5, true},
SelectTestSpec{20, 700, 6, true},
SelectTestSpec{20, 700, 7, true},
SelectTestSpec{20, 700, 8, true},
SelectTestSpec{20, 700, 9, true},
SelectTestSpec{20, 700, 10, true, false},
SelectTestSpec{20, 700, 11, true},
SelectTestSpec{20, 700, 12, true},
SelectTestSpec{20, 700, 16, true},
SelectTestSpec{100, 1700, 17, true},
SelectTestSpec{100, 1700, 31, true, false},
SelectTestSpec{100, 1700, 32, false},
SelectTestSpec{100, 1700, 33, false},
SelectTestSpec{100, 1700, 63, false},
SelectTestSpec{100, 1700, 64, false, false},
SelectTestSpec{100, 1700, 65, false},
SelectTestSpec{100, 1700, 255, true},
SelectTestSpec{100, 1700, 256, true},
SelectTestSpec{100, 1700, 511, false},
SelectTestSpec{100, 1700, 512, true},
SelectTestSpec{100, 1700, 1023, false, false},
SelectTestSpec{100, 1700, 1024, true},
// k equal to input_len: every element of each row is selected.
SelectTestSpec{100, 1700, 1700, true});

// Test specs with large inputs and small-to-moderate k (at most 256).
// Each entry reads as {n_inputs, input_len, k, select_min[, use_index_input]}; an omitted fifth
// field defaults to `true`, and `false` makes the test pass nullptr as the input indices.
// NOTE(review): the last five entries use input_len = 1,000,000,000 for a single row. This list
// is instantiated with double keys, so the key buffer alone is presumably about 8 GB on the
// device -- confirm the test machines have enough memory.
auto inputs_random_largesize = testing::Values(SelectTestSpec{100, 100000, 1, true},
SelectTestSpec{100, 100000, 2, true},
SelectTestSpec{100, 100000, 3, true, false},
SelectTestSpec{100, 100000, 7, true},
SelectTestSpec{100, 100000, 16, true},
SelectTestSpec{100, 100000, 31, true},
SelectTestSpec{100, 100000, 32, true, false},
SelectTestSpec{100, 100000, 60, true},
SelectTestSpec{100, 100000, 100, true, false},
SelectTestSpec{100, 100000, 200, true},
// Many short rows, with k equal to input_len.
SelectTestSpec{100000, 100, 100, false},
SelectTestSpec{1, 1000000000, 1, true},
SelectTestSpec{1, 1000000000, 16, false, false},
SelectTestSpec{1, 1000000000, 64, false},
SelectTestSpec{1, 1000000000, 128, true, false},
SelectTestSpec{1, 1000000000, 256, false, false});

// Test specs with large k (1000 up to the full row length of 100000) on 100 x 100000 inputs.
// Each entry reads as {n_inputs, input_len, k, select_min[, use_index_input]}; an omitted fifth
// field defaults to `true`, and `false` makes the test pass nullptr as the input indices.
auto inputs_random_largek = testing::Values(SelectTestSpec{100, 100000, 1000, true},
SelectTestSpec{100, 100000, 2000, true},
// k equal to input_len, without an explicit index input.
SelectTestSpec{100, 100000, 100000, true, false},
SelectTestSpec{100, 100000, 2048, false},
SelectTestSpec{100, 100000, 1237, true});

typedef SelectionTest<float, int, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomFloatInt;
TEST_P(ReferencedRandomFloatInt, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomFloatInt,
testing::Combine(inputs_random, selection_algos));
testing::Combine(inputs_random_longlist,
testing::Values(knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT)));

// Randomized selection test with double keys and size_t indices.
// `with_ref` is parameterized with SelectKAlgo::FAISS as its RefAlgo; the algorithms under test
// are the two radix variants and WARP_SORT, each combined with every spec in
// inputs_random_longlist.
typedef SelectionTest<double, size_t, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomDoubleSizeT;
TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomDoubleSizeT,
testing::Combine(inputs_random_longlist,
testing::Values(knn::SelectKAlgo::RADIX_8_BITS,
knn::SelectKAlgo::RADIX_11_BITS,
knn::SelectKAlgo::WARP_SORT)));

typedef SelectionTest<double, int, with_ref<knn::SelectKAlgo::RADIX_8_BITS>::params_random>
typedef SelectionTest<double, int, with_ref<knn::SelectKAlgo::FAISS>::params_random>
ReferencedRandomDoubleInt;
TEST_P(ReferencedRandomDoubleInt, Run) { run(); }
TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomDoubleInt,
testing::Combine(inputs_random, selection_algos));
testing::Combine(inputs_random_largesize,
testing::Values(knn::SelectKAlgo::WARP_SORT)));

// Large-k selection test with float keys and size_t indices.
// `with_ref` is parameterized with SelectKAlgo::RADIX_8_BITS as its RefAlgo; the only algorithm
// under test is RADIX_11_BITS, combined with every spec in inputs_random_largek.
typedef SelectionTest<float, size_t, with_ref<knn::SelectKAlgo::RADIX_8_BITS>::params_random>
ReferencedRandomFloatSizeT;
TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); }
INSTANTIATE_TEST_CASE_P(SelectionTest,
ReferencedRandomFloatSizeT,
testing::Combine(inputs_random_largek,
testing::Values(knn::SelectKAlgo::RADIX_11_BITS)));

} // namespace raft::spatial::selection