diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index d157a57f52..010bd5aaac 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -38,8 +39,8 @@ constexpr int kFaissMaxK() } template -__global__ void select_k_kernel(key_t* inK, - payload_t* inV, +__global__ void select_k_kernel(const key_t* inK, + const payload_t* inV, size_t n_rows, size_t n_cols, key_t* outK, @@ -48,7 +49,8 @@ __global__ void select_k_kernel(key_t* inK, payload_t initV, int k) { - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + using align_warp = Pow2; + constexpr int kNumWarps = align_warp::div(tpb); __shared__ key_t smemK[kNumWarps * warp_q]; __shared__ payload_t smemV[kNumWarps * warp_q]; @@ -59,27 +61,21 @@ __global__ void select_k_kernel(key_t* inK, // Grid is exactly sized to rows available int row = blockIdx.x; - int i = threadIdx.x; + { + size_t i = size_t(threadIdx.x); - int idx = row * n_cols; - key_t* inKStart = inK + idx + i; - payload_t* inVStart = inV + idx + i; + inK += row * n_cols; + if (inV != nullptr) { inV += row * n_cols; } - // Whole warps must participate in the selection - int limit = faiss::gpu::utils::roundDown(n_cols, faiss::gpu::kWarpSize); + // Whole warps must participate in the selection + size_t limit = align_warp::roundDown(n_cols); - for (; i < limit; i += tpb) { - inKStart = inK + idx + i; - inVStart = inV + idx + i; + for (; i < limit; i += tpb) { + heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); + } - heap.add(*inKStart, *inVStart); - } - - // Handle last remainder fraction of a warp of elements - if (i < n_cols) { - inKStart = inK + idx + i; - inVStart = inV + idx + i; - heap.addThreadQ(*inKStart, *inVStart); + // Handle last remainder fraction of a warp of elements + if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); } } heap.reduce(); @@ -91,8 +87,8 @@ __global__ void select_k_kernel(key_t* inK, } template -inline void select_k_impl(key_t* inK, - payload_t* inV, +inline void select_k_impl(const key_t* inK, + const payload_t* inV, size_t n_rows, size_t n_cols, key_t* outK, @@ -133,8 +129,8 @@ inline void select_k_impl(key_t* inK, * @param[in] stream CUDA stream to use */ template -inline void select_k(key_t* inK, - payload_t* inV, +inline void select_k(const key_t* inK, + const payload_t* inV, size_t n_rows, size_t n_cols, key_t* outK, @@ -172,4 +168,4 @@ inline void select_k(key_t* inK, }; // namespace detail }; // namespace knn }; // namespace spatial -}; // namespace raft \ No newline at end of file +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 8765a7c30a..52e7e31cc2 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -98,6 +98,9 @@ enum class SelectKAlgo { * @param[in] in_values * contiguous device array of inputs of size (input_len * n_inputs); * typically, these are indices of the corresponding in_keys. + * You can pass `NULL` as an argument here; this would imply `in_values` is a homogeneous array + * of indices from `0` to `input_len - 1` for every input and reduce the usage of memory + * bandwidth. * @param[in] n_inputs * number of input rows, i.e. the batch size. * @param[in] input_len @@ -118,8 +121,8 @@ enum class SelectKAlgo { * the implementation of the algorithm */ template -inline void select_k(value_t* in_keys, - idx_t* in_values, +inline void select_k(const value_t* in_keys, + const idx_t* in_values, size_t n_inputs, size_t input_len, value_t* out_keys, diff --git a/cpp/test/spatial/selection.cu b/cpp/test/spatial/selection.cu index 862fad56b4..3f2738fc02 100644 --- a/cpp/test/spatial/selection.cu +++ b/cpp/test/spatial/selection.cu @@ -38,6 +38,7 @@ struct SelectTestSpec { int input_len; int k; int select_min; + bool use_index_input = true; }; std::ostream& operator<<(std::ostream& os, const SelectTestSpec& ss) @@ -129,7 +130,7 @@ struct SelectInOutComputed { update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream); raft::spatial::knn::select_k(in_dists_d.data(), - in_ids_d.data(), + spec.use_index_input ? in_ids_d.data() : nullptr, spec.n_inputs, spec.input_len, out_dists_d.data(), @@ -242,11 +243,6 @@ class SelectionTest : public testing::TestWithParam struct params_simple { using InOut = SelectInOutSimple; @@ -268,32 +264,54 @@ struct params_simple { auto inputs_simple_f = testing::Values( params_simple::Inputs( - {5, 5, 5, true}, + {5, 5, 5, true, true}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), params_simple::Inputs( - {5, 5, 3, true}, + {5, 5, 3, true, true}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), params_simple::Inputs( - {5, 7, 3, true}, + {5, 5, 5, true, false}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, + {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), + params_simple::Inputs( + {5, 5, 3, true, false}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, + {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), + params_simple::Inputs( + {5, 7, 3, true, true}, {5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2}, {1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2}, {4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}), params_simple::Inputs( - {1, 7, 3, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}), + {1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}), + params_simple::Inputs( + {1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}), params_simple::Inputs( - {1, 7, 3, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}), + {1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}), params_simple::Inputs( - {1, 7, 3, false}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}), + {1, 130, 5, false, true}, + {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + {20, 19, 18, 17, 16}, + {129, 0, 117, 116, 115}), params_simple::Inputs( - {1, 130, 15, false}, + {1, 130, 15, false, true}, {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, @@ -306,7 +324,11 @@ typedef SelectionTest SimpleFloatInt; TEST_P(SimpleFloatInt, Run) { run(); } INSTANTIATE_TEST_CASE_P(SelectionTest, SimpleFloatInt, - testing::Combine(inputs_simple_f, selection_algos)); + testing::Combine(inputs_simple_f, + testing::Values(knn::SelectKAlgo::FAISS, + knn::SelectKAlgo::RADIX_8_BITS, + knn::SelectKAlgo::RADIX_11_BITS, + knn::SelectKAlgo::WARP_SORT))); template struct with_ref { @@ -334,50 +356,93 @@ struct with_ref { }; }; -auto inputs_random = testing::Values(SelectTestSpec{1, 130, 15, false}, - SelectTestSpec{1, 128, 15, false}, - SelectTestSpec{20, 700, 1, true}, - SelectTestSpec{20, 700, 2, true}, - SelectTestSpec{20, 700, 3, true}, - SelectTestSpec{20, 700, 4, true}, - SelectTestSpec{20, 700, 5, true}, - SelectTestSpec{20, 700, 6, true}, - SelectTestSpec{20, 700, 7, true}, - SelectTestSpec{20, 700, 8, true}, - SelectTestSpec{20, 700, 9, true}, - SelectTestSpec{20, 700, 10, true}, - SelectTestSpec{20, 700, 11, true}, - SelectTestSpec{20, 700, 12, true}, - SelectTestSpec{20, 700, 16, true}, - SelectTestSpec{100, 1700, 17, true}, - SelectTestSpec{100, 1700, 31, true}, - SelectTestSpec{100, 1700, 32, false}, - SelectTestSpec{100, 1700, 33, false}, - SelectTestSpec{100, 1700, 63, false}, - SelectTestSpec{100, 1700, 64, false}, - SelectTestSpec{100, 1700, 65, false}, - SelectTestSpec{100, 1700, 255, true}, - SelectTestSpec{100, 1700, 256, true}, - SelectTestSpec{100, 1700, 511, false}, - SelectTestSpec{100, 1700, 512, true}, - SelectTestSpec{100, 1700, 1023, false}, - SelectTestSpec{100, 1700, 1024, true}, - SelectTestSpec{100, 1700, 1700, true}, - SelectTestSpec{10000, 100, 100, false}, - SelectTestSpec{10000, 200, 100, false}); - -typedef SelectionTest::params_random> +auto inputs_random_longlist = testing::Values(SelectTestSpec{1, 130, 15, false}, + SelectTestSpec{1, 128, 15, false}, + SelectTestSpec{20, 700, 1, true}, + SelectTestSpec{20, 700, 2, true}, + SelectTestSpec{20, 700, 3, true}, + SelectTestSpec{20, 700, 4, true}, + SelectTestSpec{20, 700, 5, true}, + SelectTestSpec{20, 700, 6, true}, + SelectTestSpec{20, 700, 7, true}, + SelectTestSpec{20, 700, 8, true}, + SelectTestSpec{20, 700, 9, true}, + SelectTestSpec{20, 700, 10, true, false}, + SelectTestSpec{20, 700, 11, true}, + SelectTestSpec{20, 700, 12, true}, + SelectTestSpec{20, 700, 16, true}, + SelectTestSpec{100, 1700, 17, true}, + SelectTestSpec{100, 1700, 31, true, false}, + SelectTestSpec{100, 1700, 32, false}, + SelectTestSpec{100, 1700, 33, false}, + SelectTestSpec{100, 1700, 63, false}, + SelectTestSpec{100, 1700, 64, false, false}, + SelectTestSpec{100, 1700, 65, false}, + SelectTestSpec{100, 1700, 255, true}, + SelectTestSpec{100, 1700, 256, true}, + SelectTestSpec{100, 1700, 511, false}, + SelectTestSpec{100, 1700, 512, true}, + SelectTestSpec{100, 1700, 1023, false, false}, + SelectTestSpec{100, 1700, 1024, true}, + SelectTestSpec{100, 1700, 1700, true}); + +auto inputs_random_largesize = testing::Values(SelectTestSpec{100, 100000, 1, true}, + SelectTestSpec{100, 100000, 2, true}, + SelectTestSpec{100, 100000, 3, true, false}, + SelectTestSpec{100, 100000, 7, true}, + SelectTestSpec{100, 100000, 16, true}, + SelectTestSpec{100, 100000, 31, true}, + SelectTestSpec{100, 100000, 32, true, false}, + SelectTestSpec{100, 100000, 60, true}, + SelectTestSpec{100, 100000, 100, true, false}, + SelectTestSpec{100, 100000, 200, true}, + SelectTestSpec{100000, 100, 100, false}, + SelectTestSpec{1, 1000000000, 1, true}, + SelectTestSpec{1, 1000000000, 16, false, false}, + SelectTestSpec{1, 1000000000, 64, false}, + SelectTestSpec{1, 1000000000, 128, true, false}, + SelectTestSpec{1, 1000000000, 256, false, false}); + +auto inputs_random_largek = testing::Values(SelectTestSpec{100, 100000, 1000, true}, + SelectTestSpec{100, 100000, 2000, true}, + SelectTestSpec{100, 100000, 100000, true, false}, + SelectTestSpec{100, 100000, 2048, false}, + SelectTestSpec{100, 100000, 1237, true}); + +typedef SelectionTest::params_random> ReferencedRandomFloatInt; TEST_P(ReferencedRandomFloatInt, Run) { run(); } INSTANTIATE_TEST_CASE_P(SelectionTest, ReferencedRandomFloatInt, - testing::Combine(inputs_random, selection_algos)); + testing::Combine(inputs_random_longlist, + testing::Values(knn::SelectKAlgo::RADIX_8_BITS, + knn::SelectKAlgo::RADIX_11_BITS, + knn::SelectKAlgo::WARP_SORT))); + +typedef SelectionTest::params_random> + ReferencedRandomDoubleSizeT; +TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } +INSTANTIATE_TEST_CASE_P(SelectionTest, + ReferencedRandomDoubleSizeT, + testing::Combine(inputs_random_longlist, + testing::Values(knn::SelectKAlgo::RADIX_8_BITS, + knn::SelectKAlgo::RADIX_11_BITS, + knn::SelectKAlgo::WARP_SORT))); -typedef SelectionTest::params_random> +typedef SelectionTest::params_random> ReferencedRandomDoubleInt; -TEST_P(ReferencedRandomDoubleInt, Run) { run(); } +TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } INSTANTIATE_TEST_CASE_P(SelectionTest, ReferencedRandomDoubleInt, - testing::Combine(inputs_random, selection_algos)); + testing::Combine(inputs_random_largesize, + testing::Values(knn::SelectKAlgo::WARP_SORT))); + +typedef SelectionTest::params_random> + ReferencedRandomFloatSizeT; +TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } +INSTANTIATE_TEST_CASE_P(SelectionTest, + ReferencedRandomFloatSizeT, + testing::Combine(inputs_random_largek, + testing::Values(knn::SelectKAlgo::RADIX_11_BITS))); } // namespace raft::spatial::selection