From 85589acac80abc87eed065ce3da575316d163c84 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 23 Jan 2023 17:10:23 +0100 Subject: [PATCH 1/4] Squash-merge enh-matrix-topk --- cpp/bench/CMakeLists.txt | 6 +- cpp/bench/matrix/select_k.cu | 133 +++++ cpp/bench/neighbors/selection.cu | 123 ----- .../topk.cuh => matrix/detail/select_k.cuh} | 58 +-- .../detail/select_radix.cuh} | 113 +++-- .../detail/select_warpsort.cuh} | 415 +++++++++++----- cpp/include/raft/matrix/select_k.cuh | 110 +++++ cpp/include/raft/neighbors/detail/refine.cuh | 4 +- .../spatial/knn/detail/ivf_flat_search.cuh | 75 +-- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 79 ++- cpp/include/raft/spatial/knn/knn.cuh | 38 +- .../knn/detail/topk => util}/bitonic_sort.cuh | 83 ++-- cpp/include/raft/util/integer_utils.hpp | 34 +- cpp/test/CMakeLists.txt | 5 +- cpp/test/matrix/select_k.cu | 459 ++++++++++++++++++ cpp/test/matrix/select_k.cuh | 127 +++++ cpp/test/neighbors/ann_ivf_flat.cu | 8 +- cpp/test/neighbors/ann_utils.cuh | 23 +- cpp/test/neighbors/selection.cu | 2 +- cpp/test/util/bitonic_sort.cu | 200 ++++++++ docs/source/cpp_api/matrix_ordering.rst | 12 + 21 files changed, 1631 insertions(+), 476 deletions(-) create mode 100644 cpp/bench/matrix/select_k.cu delete mode 100644 cpp/bench/neighbors/selection.cu rename cpp/include/raft/{spatial/knn/detail/topk.cuh => matrix/detail/select_k.cuh} (59%) rename cpp/include/raft/{spatial/knn/detail/topk/radix_topk.cuh => matrix/detail/select_radix.cuh} (87%) rename cpp/include/raft/{spatial/knn/detail/topk/warpsort_topk.cuh => matrix/detail/select_warpsort.cuh} (71%) create mode 100644 cpp/include/raft/matrix/select_k.cuh rename cpp/include/raft/{spatial/knn/detail/topk => util}/bitonic_sort.cuh (68%) create mode 100644 cpp/test/matrix/select_k.cu create mode 100644 cpp/test/matrix/select_k.cuh create mode 100644 cpp/test/util/bitonic_sort.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 8dcdb325e9..6b985acfc3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -103,7 +103,10 @@ if(BUILD_BENCH) bench/main.cpp ) - ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp) + ConfigureBench( + NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu + bench/main.cpp + ) ConfigureBench( NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu @@ -127,7 +130,6 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu bench/neighbors/refine.cu - bench/neighbors/selection.cu bench/main.cpp OPTIONAL DIST diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu new file mode 100644 index 0000000000..452a50ba50 --- /dev/null +++ b/cpp/bench/matrix/select_k.cu @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * TODO: reconsider how to organize shared test+bench files better + * Related Issue: https://github.com/rapidsai/raft/issues/1153 + * (although this header does not depend on any gtest headers) + */ +#include "../../test/matrix/select_k.cuh" + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace raft::matrix { + +using namespace raft::bench; // NOLINT + +template +struct selection : public fixture { + explicit selection(const select::params& p) + : params_(p), + in_dists_(p.batch_size * p.len, stream), + in_ids_(p.batch_size * p.len, stream), + out_dists_(p.batch_size * p.k, stream), + out_ids_(p.batch_size * p.k, stream) + { + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); + raft::random::RngState state{42}; + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + } + + void run_benchmark(::benchmark::State& state) override // NOLINT + { + handle_t handle{stream}; + using_pool_memory_res res; + try { + std::ostringstream label_stream; + label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + state.SetLabel(label_stream.str()); + loop_on_state(state, [this, &handle]() { + select::select_k_impl(handle, + Algo, + in_dists_.data(), + in_ids_.data(), + params_.batch_size, + params_.len, + params_.k, + out_dists_.data(), + out_ids_.data(), + params_.select_min); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + } + + private: + const select::params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector in_ids_, out_ids_; +}; + +const std::vector kInputs{ + {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, + {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, + {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, + + {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, + {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, + {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, + + {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, + {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, + {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, + + {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, + {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, + {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) \ + { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ + } + +SELECTION_REGISTER(float, int, kPublicApi); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT 
+SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT + +} // namespace raft::matrix diff --git a/cpp/bench/neighbors/selection.cu b/cpp/bench/neighbors/selection.cu deleted file mode 100644 index 1f116c199f..0000000000 --- a/cpp/bench/neighbors/selection.cu +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -#include -#include - -#include -#include - -namespace raft::bench::spatial { - -struct params { - int n_inputs; - int input_len; - int k; - int select_min; -}; - -template -struct selection : public fixture { - explicit selection(const params& p) - : params_(p), - in_dists_(p.n_inputs * p.input_len, stream), - in_ids_(p.n_inputs * p.input_len, stream), - out_dists_(p.n_inputs * p.k, stream), - out_ids_(p.n_inputs * p.k, stream) - { - raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); - raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); - } - - void run_benchmark(::benchmark::State& state) override - { - using_pool_memory_res res; - try { - std::ostringstream label_stream; - label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; - state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { - raft::spatial::knn::select_k(in_dists_.data(), - in_ids_.data(), - params_.n_inputs, - params_.input_len, - out_dists_.data(), - out_ids_.data(), - params_.select_min, - params_.k, - stream, - Algo); - }); - } catch (raft::exception& e) { - state.SkipWithError(e.what()); - } - } - - private: - const params params_; - rmm::device_uvector in_dists_, out_dists_; - rmm::device_uvector in_ids_, out_ids_; -}; - -const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, 
true}, -}; - -#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ - namespace BENCHMARK_PRIVATE_NAME(selection) \ - { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ - } - -SELECTION_REGISTER(float, int, FAISS); -SELECTION_REGISTER(float, int, RADIX_8_BITS); -SELECTION_REGISTER(float, int, RADIX_11_BITS); -SELECTION_REGISTER(float, int, WARP_SORT); - -SELECTION_REGISTER(double, int, FAISS); -SELECTION_REGISTER(double, int, RADIX_8_BITS); -SELECTION_REGISTER(double, int, RADIX_11_BITS); -SELECTION_REGISTER(double, int, WARP_SORT); - -SELECTION_REGISTER(double, size_t, FAISS); -SELECTION_REGISTER(double, size_t, RADIX_8_BITS); -SELECTION_REGISTER(double, size_t, RADIX_11_BITS); -SELECTION_REGISTER(double, size_t, WARP_SORT); - -} // namespace raft::bench::spatial diff --git a/cpp/include/raft/spatial/knn/detail/topk.cuh b/cpp/include/raft/matrix/detail/select_k.cuh similarity index 59% rename from cpp/include/raft/spatial/knn/detail/topk.cuh rename to cpp/include/raft/matrix/detail/select_k.cuh index f4dcb53088..ac1ba3dfa3 100644 --- a/cpp/include/raft/spatial/knn/detail/topk.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,34 @@ #pragma once -#include "topk/radix_topk.cuh" -#include "topk/warpsort_topk.cuh" +#include "select_radix.cuh" +#include "select_warpsort.cuh" #include #include #include -namespace raft::spatial::knn::detail { +namespace raft::matrix::detail { /** * Select k smallest or largest key/values from each row in the input data. * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). * * @tparam T * the type of the keys (what is being compared). * @tparam IdxT * the index type (what is being selected together with the keys). * - * @param[in] in + * @param[in] in_val * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. * @param[in] in_idx * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. + * typically, these are indices of the corresponding in_val. * @param batch_size * number of input rows, i.e. the batch size. * @param len @@ -51,12 +51,12 @@ namespace raft::spatial::knn::detail { * Invariant: len >= k. * @param k * the number of outputs to select in each input row. - * @param[out] out + * @param[out] out_val * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. + * the k smallest/largest values from each row of the `in_val`. * @param[out] out_idx * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. + * the payload selected together with `out_val`. * @param select_min * whether to select k smallest (true) or largest (false) keys. 
* @param stream @@ -64,28 +64,28 @@ namespace raft::spatial::knn::detail { * memory pool here to avoid memory allocations within the call). */ template -void select_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { common::nvtx::range fun_scope( - "matrix::select_topk(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); // TODO (achirkin): investigate the trade-off for a wider variety of inputs. const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; - if (k <= raft::spatial::knn::detail::topk::kMaxCapacity && !radix_faster) { - topk::warp_sort_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + if (k <= select::warpsort::kMaxCapacity && !radix_faster) { + select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { - topk::radix_topk= 4 ? 11 : 8), 512>( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + select::radix::select_k= 4 ? 11 : 8), 512>( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } } -} // namespace raft::spatial::knn::detail +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh similarity index 87% rename from cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh rename to cpp/include/raft/matrix/detail/select_radix.cuh index 9c0f20b706..de19e63a4c 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -27,29 +28,29 @@ #include #include -#include +#include #include -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::radix { constexpr int ITEM_PER_THREAD = 32; constexpr int VECTORIZED_READ_SIZE = 16; template -__host__ __device__ constexpr int calc_num_buckets() +_RAFT_HOST_DEVICE constexpr int calc_num_buckets() { return 1 << BitsPerPass; } template -__host__ __device__ constexpr int calc_num_passes() +_RAFT_HOST_DEVICE constexpr int calc_num_passes() { return ceildiv(sizeof(T) * 8, BitsPerPass); } // Minimum reasonable block size for the given radix size. template -__host__ __device__ constexpr int calc_min_block_size() +_RAFT_HOST_DEVICE constexpr int calc_min_block_size() { return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1); } @@ -62,7 +63,7 @@ __host__ __device__ constexpr int calc_min_block_size() * NB: Use pass=-1 for calc_mask(). 
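// A hedged usage sketch of the renamed detail API `raft::matrix::detail::select_k`
// above (not taken from the patch; buffer names and shapes are illustrative assumptions):
//
//   size_t batch_size = 100, len = 100000;
//   int k = 64;
//   rmm::device_uvector<float> in_val(batch_size * len, stream);
//   rmm::device_uvector<int>   in_idx(batch_size * len, stream);
//   rmm::device_uvector<float> out_val(batch_size * k, stream);
//   rmm::device_uvector<int>   out_idx(batch_size * k, stream);
//   // ... fill in_val / in_idx ...
//   raft::matrix::detail::select_k<float, int>(
//     in_val.data(), in_idx.data(), batch_size, len, k,
//     out_val.data(), out_idx.data(), /*select_min=*/true, stream);
//
// With k = 64 <= select::warpsort::kMaxCapacity (256) and len < 102400, the
// heuristic above dispatches to the warp-sort path; k > 256, or a batch of
// >= 64 long rows with k >= 128, would take the radix path instead.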
*/ template -__device__ constexpr int calc_start_bit(int pass) +_RAFT_DEVICE constexpr int calc_start_bit(int pass) { int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; if (start_bit < 0) { start_bit = 0; } @@ -70,7 +71,7 @@ __device__ constexpr int calc_start_bit(int pass) } template -__device__ constexpr unsigned calc_mask(int pass) +_RAFT_DEVICE constexpr unsigned calc_mask(int pass) { static_assert(BitsPerPass <= 31); int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); @@ -82,7 +83,7 @@ __device__ constexpr unsigned calc_mask(int pass) * as of integers. */ template -__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); @@ -91,7 +92,7 @@ __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) } template -__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int return (twiddle_in(x, greater) >> start_bit) & mask; @@ -112,7 +113,7 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @param f the lambda taking two arguments (T x, IdxT idx) */ template -__device__ void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) { const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -167,18 +168,18 @@ struct Counter { * (see steps 4-1 in `radix_kernel` description). */ template -__device__ void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT len, - Counter* counter, - IdxT* histogram, - bool greater, - int pass, - int k) +_RAFT_DEVICE void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + Counter* counter, + IdxT* histogram, + bool greater, + int pass, + int k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -260,10 +261,10 @@ __device__ void filter_and_histogram(const T* in_buf, * (step 2 in `radix_kernel` description) */ template -__device__ void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram, + const int start, + const int num_buckets, + const IdxT current) { typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; @@ -284,7 +285,7 @@ __device__ void scan(volatile IdxT* histogram, * (steps 2-3 in `radix_kernel` description) */ template -__device__ void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) { constexpr int num_buckets = calc_num_buckets(); int index = threadIdx.x; @@ -547,21 +548,21 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * memory pool here to avoid memory allocations within the call). 
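// Worked example of the pass arithmetic defined earlier in this file
// (float keys, BitsPerPass = 11; the numbers below are derived, not in the patch):
//   calc_num_passes<float, 11>() == ceildiv(32, 11) == 3
//   pass 0: start_bit = 32 - 11 = 21  -> mask covers bits [21, 31] (11 bits)
//   pass 1: start_bit = 32 - 22 = 10  -> mask covers bits [10, 20] (11 bits)
//   pass 2: start_bit = max(32 - 33, 0) = 0 -> mask covers bits [0, 9] (10 bits)
// The most significant bits are inspected first, so each pass narrows the
// candidate bucket chosen by choose_bucket() until only k elements remain.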
*/ template -void radix_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { // reduce the block size if the input length is too small. if constexpr (BlockSize > calc_min_block_size()) { if (BlockSize * ITEM_PER_THREAD > len) { - return radix_topk( + return select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } } @@ -573,23 +574,33 @@ void radix_topk(const T* in, dim3 blocks = get_optimal_grid_size(batch_size, len); size_t max_chunk_size = blocks.y; - auto pool_guard = raft::get_pool_memory_resource( - mr, - max_chunk_size * (sizeof(Counter) // counters - + sizeof(IdxT) * (num_buckets + 2) // histograms and IdxT bufs - + sizeof(T) * 2 // T bufs - )); + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf; + size_t mem_free, mem_total; + RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); + std::optional managed_memory; + rmm::mr::device_memory_resource* mr_buf = nullptr; + if (mem_req > mem_free) { + // if there's not enough memory for buffers on the device, resort to the managed memory. + mem_req = req_aux; + managed_memory.emplace(); + mr_buf = &managed_memory.value(); + } + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); if (pool_guard) { - RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } + if (mr_buf == nullptr) { mr_buf = mr; } rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(num_buckets * max_chunk_size, stream, mr); - rmm::device_uvector buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector buf2(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { blocks.y = std::min(max_chunk_size, batch_size - offset); @@ -646,4 +657,4 @@ void radix_topk(const T* in, } } -} // namespace raft::spatial::knn::detail::topk +} // namespace raft::matrix::detail::select::radix diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh similarity index 71% rename from cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh rename to cpp/include/raft/matrix/detail/select_warpsort.cuh index c06aa04aea..d362b73792 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -16,10 +16,11 @@ #pragma once -#include "bitonic_sort.cuh" - +#include #include +#include #include +#include #include 
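// Sizing note for the managed-memory fallback in select_radix.cuh above
// (numbers are illustrative, not from the patch): the candidate buffers dominate,
//   req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)),
// e.g. max_chunk_size = 64, len = 1'000'000, T = float, IdxT = uint32_t gives
//   64 * 1e6 * 2 * 8 B ~= 1.0 GB,
// while req_aux (counters plus histograms) stays in the KiB-to-MiB range. Hence
// only buf1/idx_buf1/buf2/idx_buf2 are demoted to the managed memory resource
// when req_aux + req_buf exceeds the free device memory reported by
// cudaMemGetInfo; the small auxiliary arrays remain on `mr`.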
#include @@ -31,12 +32,12 @@ /* Three APIs of different scopes are provided: - 1. host function: warp_sort_topk() + 1. host function: select_k() 2. block-wide API: class block_sort 3. warp-wide API: several implementations of warp_sort_* - 1. warp_sort_topk() + 1. select_k() (see the docstring) 2. class block_sort @@ -74,7 +75,7 @@ These two classes can be regarded as fixed size priority queue for a warp. Usage is similar to class block_sort. No shared memory is needed. - The host function (warp_sort_topk) uses a heuristic to choose between these two classes for + The host function (select_k) uses a heuristic to choose between these two classes for sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small (see the usage of LaunchThreshold::len_factor_for_choosing). @@ -94,7 +95,7 @@ } */ -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::warpsort { static constexpr int kMaxCapacity = 256; @@ -102,18 +103,12 @@ namespace { /** Whether 'left` should indeed be on the left w.r.t. `right`. */ template -__device__ __forceinline__ auto is_ordered(T left, T right) -> bool +_RAFT_DEVICE _RAFT_FORCEINLINE auto is_ordered(T left, T right) -> bool { if constexpr (Ascending) { return left < right; } if constexpr (!Ascending) { return left > right; } } -constexpr auto calc_capacity(int k) -> int -{ - int capacity = isPo2(k) ? k : (1 << (log2(k) + 1)); - return capacity; -} - } // namespace /** @@ -134,7 +129,7 @@ constexpr auto calc_capacity(int k) -> int */ template class warp_sort { - static_assert(isPo2(Capacity)); + static_assert(is_a_power_of_two(Capacity)); static_assert(std::is_default_constructible_v); public: @@ -148,13 +143,16 @@ class warp_sort { /** The number of elements to select. */ const int k; + /** Extra memory required per-block for keeping the state (shared or global). */ + constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; } + /** * Construct the warp_sort empty queue. * * @param k * number of elements to select. */ - __device__ warp_sort(int k) : k(k) + _RAFT_DEVICE warp_sort(int k) : k(k) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -182,7 +180,7 @@ class warp_sort { * It serves as a conditional; when `false` the function does nothing. * We need it to ensure threads within a full warp don't diverge calling `bitonic::merge()`. */ - __device__ void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) + _RAFT_DEVICE void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) { if (do_merge) { int idx = Pow2::mod(laneId()) ^ Pow2::Mask; @@ -198,7 +196,7 @@ class warp_sort { } } if (kWarpWidth < WarpSize || do_merge) { - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } } @@ -211,15 +209,23 @@ class warp_sort { * @param[out] out_idx * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` * (length: k <= kWarpWidth * kMaxArrLen). 
+ * @param valF (optional) postprocess values (T -> OutT) + * @param idxF (optional) postprocess indices (IdxT -> OutIdxT) */ - template - __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { int idx = Pow2::mod(laneId()); #pragma unroll kMaxArrLen for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { - out[idx] = post_process(val_arr_[i]); - out_idx[idx] = idx_arr_[i]; + out[idx] = valF(val_arr_[i]); + out_idx[idx] = idxF(idx_arr_[i]); } } @@ -246,8 +252,8 @@ class warp_sort { * the associated indices of the elements in the same format as `keys_in`. */ template - __device__ __forceinline__ void merge_in(const T* __restrict__ keys_in, - const IdxT* __restrict__ ids_in) + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_in(const T* __restrict__ keys_in, + const IdxT* __restrict__ ids_in) { #pragma unroll for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) { @@ -258,7 +264,7 @@ class warp_sort { idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i]; } } - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } }; @@ -276,8 +282,9 @@ class warp_sort_filtered : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_filtered(int k, T limit) + explicit _RAFT_DEVICE warp_sort_filtered(int k, T limit = kDummy) : warp_sort(k), buf_len_(0), k_th_(limit) { #pragma unroll @@ -287,12 +294,14 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ explicit warp_sort_filtered(int k) - : warp_sort_filtered(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_filtered{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // comparing for k_th should reduce the total amount of updates: // `false` means the input value is surely not in the top-k values. @@ -310,22 +319,22 @@ class warp_sort_filtered : public warp_sort { if (do_add) { add_to_buf_(val, idx); } } - __device__ void done() + _RAFT_DEVICE void done() { if (any(buf_len_ != 0)) { merge_buf_(); } } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); buf_len_ = 0; set_k_th_(); // contains warp sync @@ -335,7 +344,7 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ void add_to_buf_(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE void add_to_buf_(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. 
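// A hedged trace of warp_sort_filtered::add() above (values illustrative):
// with k = 4 and a queue currently holding {0.1, 0.2, 0.3, 0.4} (ascending),
// k_th_ == 0.4. An input of 0.9 fails is_ordered(0.9, 0.4) and is dropped by
// the cheap single comparison; an input of 0.25 is appended to the per-lane
// buffer instead. Once any lane fills its buffer, merge_buf_() bitonic-sorts
// the buffered candidates, merges them into the queue, and refreshes k_th_
// via a shuffle of the k-th queue element, so the filter keeps tightening as
// better values arrive.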
@@ -374,8 +383,9 @@ class warp_sort_distributed : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_distributed(int k, T limit) + explicit _RAFT_DEVICE warp_sort_distributed(int k, T limit = kDummy) : warp_sort(k), buf_val_(kDummy), buf_idx_(IdxT{}), @@ -384,12 +394,14 @@ class warp_sort_distributed : public warp_sort { { } - __device__ __forceinline__ explicit warp_sort_distributed(int k) - : warp_sort_distributed(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_distributed{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // mask tells which lanes in the warp have valid items to be added uint32_t mask = ballot(is_ordered(val, k_th_)); @@ -429,7 +441,7 @@ class warp_sort_distributed : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { merge_buf_(); @@ -438,16 +450,16 @@ class warp_sort_distributed : public warp_sort { } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); this->merge_in<1>(&buf_val_, &buf_idx_); set_k_th_(); // contains warp sync buf_val_ = kDummy; @@ -464,6 +476,117 @@ class warp_sort_distributed : public warp_sort { T k_th_; }; +/** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). 
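// Worked example of the ballot/popcount compaction used in add() below
// (lane numbers illustrative): suppose lanes 2, 5 and 9 of a warp pass the
// k_th_ filter, so ballot(...) returns mask = 0b0000'0010'0010'0100.
//   lane 2: __popc(mask & 0b0000'0000'0000'0011) == 0 -> slot buf_len_ + 0
//   lane 5: __popc(mask & 0b0000'0000'0001'1111) == 1 -> slot buf_len_ + 1
//   lane 9: __popc(mask & 0b0000'0001'1111'1111) == 2 -> slot buf_len_ + 2
// Each passing lane thus gets a unique, densely packed buffer slot without
// atomics; buf_len_ then advances by __popc(mask) == 3 on all lanes.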
+ */ +template +class warp_sort_distributed_ext : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + + constexpr static auto mem_required(uint32_t block_size) -> size_t + { + return (sizeof(T) + sizeof(IdxT)) * block_size; + } + + _RAFT_DEVICE warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy) + : warp_sort(k), + val_buf_(val_buf), + idx_buf_(idx_buf), + buf_len_(0), + k_th_(limit) + { + val_buf_[laneId()] = kDummy; + } + + _RAFT_DEVICE static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy) + { + T* val_buf = nullptr; + IdxT* idx_buf = nullptr; + if constexpr (alignof(T) >= alignof(IdxT)) { + val_buf = reinterpret_cast(shmem); + idx_buf = reinterpret_cast(val_buf + blockDim.x); + } else { + idx_buf = reinterpret_cast(shmem); + val_buf = reinterpret_cast(idx_buf + blockDim.x); + } + auto warp_offset = Pow2::roundDown(threadIdx.x); + val_buf += warp_offset; + idx_buf += warp_offset; + return warp_sort_distributed_ext{k, val_buf, idx_buf, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + bool do_add = is_ordered(val, k_th_); + // mask tells which lanes in the warp have valid items to be added + uint32_t mask = ballot(do_add); + if (mask == 0) { return; } + // where to put the element in the tmp buffer + int dst_ix = buf_len_ + __popc(mask & ((1u << laneId()) - 1u)); + // put all elements, which fit into the current tmp buffer + if (do_add && dst_ix < WarpSize) { + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + do_add = false; + } + // Total number of elements to be added + buf_len_ += __popc(mask); + // If the buffer is still not full, we can return + if (buf_len_ < WarpSize) { return; } + // Otherwise, merge the warp tmp buffer into the queue + merge_buf_(); // implies warp sync + buf_len_ -= WarpSize; + // save the inputs that couldn't fit before the merge + if (do_add) { + dst_ix -= WarpSize; + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + merge_buf_(); + buf_len_ = 0; + } + __syncthreads(); + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + __syncwarp(); // make sure the threads are aware of the data written by others + T buf_val = val_buf_[laneId()]; + IdxT buf_idx = idx_buf_[laneId()]; + val_buf_[laneId()] = kDummy; + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val, buf_idx); + this->merge_in<1>(&buf_val, &buf_idx); + set_k_th_(); // contains warp sync + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T* val_buf_; + IdxT* idx_buf_; + uint32_t buf_len_; // 0 <= buf_len_ < WarpSize + + T k_th_; +}; + /** * This version of warp_sort adds every input element into the intermediate sorting * buffer, and thus does the sorting step every `Capacity` input elements. 
@@ -476,8 +599,10 @@ class warp_sort_immediate : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_immediate(int k) : warp_sort(k), buf_len_(0) + explicit _RAFT_DEVICE warp_sort_immediate(int k) + : warp_sort(k), buf_len_(0) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -486,7 +611,12 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, uint8_t* = nullptr) + { + return warp_sort_immediate{k}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. @@ -500,7 +630,7 @@ class warp_sort_immediate : public warp_sort { ++buf_len_; if (buf_len_ == kMaxArrLen) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -510,10 +640,10 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); } } @@ -545,15 +675,11 @@ class block_sort { using queue_t = WarpSortWarpWide; template - __device__ block_sort(int k, uint8_t* smem_buf, Args... args) : queue_(k, args...) + _RAFT_DEVICE block_sort(int k, Args... args) : queue_(queue_t::init_blockwide(k, args...)) { - val_smem_ = reinterpret_cast(smem_buf); - const int num_of_warp = subwarp_align::div(blockDim.x); - idx_smem_ = reinterpret_cast( - smem_buf + Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k)); } - __device__ void add(T val, IdxT idx) { queue_.add(val, idx); } + _RAFT_DEVICE void add(T val, IdxT idx) { queue_.add(val, idx); } /** * At the point of calling this function, the warp-level queues consumed all input @@ -561,22 +687,26 @@ class block_sort { * * Here we tree-merge the results using the shared memory and block sync. */ - __device__ void done() + _RAFT_DEVICE void done(uint8_t* smem_buf) { queue_.done(); + int nwarps = subwarp_align::div(blockDim.x); + auto val_smem = reinterpret_cast(smem_buf); + auto idx_smem = reinterpret_cast( + smem_buf + Pow2<256>::roundUp(ceildiv(nwarps, 2) * sizeof(T) * queue_.k)); + const int warp_id = subwarp_align::div(threadIdx.x); // NB: there is no need for the second __synchthreads between .load_sorted and .store: // we shift the pointers every iteration, such that individual warps either access the same // locations or do not overlap with any of the other warps. The access patterns within warps // are different for the two functions, but .load_sorted implies warp sync at the end, so // there is no need for __syncwarp either. 
- for (int shift_mask = ~0, nwarps = subwarp_align::div(blockDim.x), split = (nwarps + 1) >> 1; - nwarps > 1; + for (int shift_mask = ~0, split = (nwarps + 1) >> 1; nwarps > 1; nwarps = split, split = (nwarps + 1) >> 1) { if (warp_id < nwarps && warp_id >= split) { int dst_warp_shift = (warp_id - (split & shift_mask)) * queue_.k; - queue_.store(val_smem_ + dst_warp_shift, idx_smem_ + dst_warp_shift); + queue_.store(val_smem + dst_warp_shift, idx_smem + dst_warp_shift); } __syncthreads(); @@ -586,23 +716,27 @@ class block_sort { // The last argument serves as a condition for loading // -- to make sure threads within a full warp do not diverge on `bitonic::merge()` queue_.load_sorted( - val_smem_ + src_warp_shift, idx_smem_ + src_warp_shift, warp_id < nwarps - split); + val_smem + src_warp_shift, idx_smem + src_warp_shift, warp_id < nwarps - split); } } } /** Save the content by the pointer location. */ - template - __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { - if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, post_process); } + if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, valF, idxF); } } private: using subwarp_align = Pow2; queue_t queue_; - T* val_smem_; - IdxT* idx_smem_; }; /** @@ -620,7 +754,10 @@ __launch_bounds__(256) __global__ void block_kernel(const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx) { extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; - block_sort queue(k, smem_buf_bytes); + using bq_t = block_sort; + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(k, warp_smem); + in += blockIdx.y * len; if (in_idx != nullptr) { in_idx += blockIdx.y * len; } @@ -631,7 +768,7 @@ __launch_bounds__(256) __global__ (i < len && in_idx != nullptr) ? __ldcs(in_idx + i) : i); } - queue.done(); + queue.done(smem_buf_bytes); const int block_id = blockIdx.x + gridDim.x * blockIdx.y; queue.store(out + block_id * k, out_idx + block_id * k); } @@ -658,7 +795,7 @@ struct launch_setup { int* min_grid_size, int block_size_limit = 0) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::calc_optimal_params( @@ -691,7 +828,7 @@ struct launch_setup { IdxT* out_idx, rmm::cuda_stream_view stream) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::kernel(k, @@ -742,6 +879,18 @@ struct LaunchThreshold { static constexpr int len_factor_for_single_block = 32; }; +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + template <> struct LaunchThreshold { static constexpr int len_factor_for_choosing = 4; @@ -753,7 +902,7 @@ template
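// A hedged end-to-end sketch of the reworked block-wide API, reconstructing the
// usage pattern of block_kernel above (the template parameter lists here are
// assumptions, as they are not fully visible in the patch as rendered):
//
// template <int Capacity, bool Ascending, typename T, typename IdxT>
// __global__ void example_select_kernel(const T* in, IdxT len, int k, T* out, IdxT* out_idx)
// {
//   extern __shared__ __align__(256) uint8_t smem[];
//   using bq_t = block_sort<warp_sort_distributed_ext, Capacity, Ascending, T, IdxT>;
//   // per-warp scratch is only handed to queue types declaring mem_required(...) > 0
//   uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem : nullptr;
//   bq_t queue(k, warp_smem);
//   for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
//     queue.add(in[i], i);   // feed all candidates; each warp keeps its best k
//   }
//   queue.done(smem);        // tree-merge the per-warp queues; smem is reused here
//   queue.store(out, out_idx);  // only the first sub-warp writes the final k results
// }
//
// The kernel would be launched with dynamic shared memory sized for both the
// per-warp scratch and the tree-merge buffers used by done().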