From 179e1dff1b20866745e10d05422305cb502cad47 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 9 Mar 2022 09:57:54 +0100 Subject: [PATCH 01/41] Integrate new select-top-k implementations --- cpp/bench/CMakeLists.txt | 1 + cpp/bench/spatial/selection.cu | 132 +++ .../knn/detail/ivf_flat/bitonic_sort.cuh | 168 ++++ .../knn/detail/ivf_flat/radix_topk.cuh | 657 ++++++++++++++ .../knn/detail/ivf_flat/warpsort_topk.cuh | 850 ++++++++++++++++++ .../spatial/knn/detail/selection_faiss.cuh | 68 +- cpp/include/raft/spatial/knn/knn.cuh | 218 ++++- cpp/include/raft/spatial/knn/knn.hpp | 222 ++++- cpp/test/spatial/selection.cu | 386 ++++++-- cpp/test/test_utils.h | 26 + 10 files changed, 2502 insertions(+), 226 deletions(-) create mode 100644 cpp/bench/spatial/selection.cu create mode 100644 cpp/include/raft/spatial/knn/detail/ivf_flat/bitonic_sort.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 9f0a6096d9..5214047571 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -19,6 +19,7 @@ set(RAFT_CPP_BENCH_TARGET "bench_raft") # (please keep the filenames in alphabetical order) add_executable(${RAFT_CPP_BENCH_TARGET} bench/linalg/reduce.cu + bench/spatial/selection.cu bench/main.cpp ) diff --git a/cpp/bench/spatial/selection.cu b/cpp/bench/spatial/selection.cu new file mode 100644 index 0000000000..644b983a7e --- /dev/null +++ b/cpp/bench/spatial/selection.cu @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include + +#include +#include + +namespace raft::bench::spatial { + +struct params { + int n_inputs; + int input_len; + int k; + int select_min; +}; + +template +struct selection : public Fixture { + selection(const std::string& name, const params& p) : Fixture(name), params_(p) {} + + protected: + void allocateBuffers(const ::benchmark::State& state) override + { + auto in_len = params_.n_inputs * params_.input_len; + alloc(in_dists_, in_len, false); + alloc(in_ids_, in_len, false); + alloc(out_dists_, params_.n_inputs * params_.k, false); + alloc(out_ids_, params_.n_inputs * params_.k, false); + + raft::sparse::iota_fill(in_ids_, IdxT(params_.n_inputs), IdxT(params_.input_len), stream); + raft::random::Rng(42).uniform(in_dists_, in_len, KeyT(-1.0), KeyT(1.0), stream); + } + + void deallocateBuffers(const ::benchmark::State& state) override + { + dealloc(in_dists_, params_.n_inputs * params_.input_len); + dealloc(in_ids_, params_.n_inputs * params_.input_len); + dealloc(out_dists_, params_.n_inputs * params_.k); + dealloc(out_ids_, params_.n_inputs * params_.k); + } + + void runBenchmark(::benchmark::State& state) override + { + rmm::mr::cuda_memory_resource cuda_mr; + rmm::mr::pool_memory_resource pool_mr{ + &cuda_mr, size_t(1) << size_t(30), size_t(16) << size_t(30)}; + rmm::mr::set_current_device_resource(&pool_mr); + try { + std::ostringstream label_stream; + label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; + state.SetLabel(label_stream.str()); + loopOnState(state, [this]() { + raft::spatial::knn::select_k(in_dists_, + in_ids_, + params_.n_inputs, + params_.input_len, + out_dists_, + out_ids_, + params_.select_min, + params_.k, + stream, + Algo); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + rmm::mr::set_current_device_resource(nullptr); + } + + private: + params params_; + KeyT *in_dists_, *out_dists_; + IdxT *in_ids_, *out_ids_; +}; + +const std::vector kInputs{ + {10000, 10, 3, true}, {10000, 10, 10, true}, {10000, 700, 3, true}, + {10000, 700, 32, true}, {10000, 2000, 64, true}, {10000, 10000, 7, true}, + {10000, 10000, 19, true}, {10000, 10000, 127, true}, + + {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, + {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, + {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, + {1000, 10000, 512, true}, {1000, 10000, 1024, true}, {1000, 10000, 2048, true}, + + {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, + {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, + {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, + {100, 100000, 512, true}, {100, 100000, 1024, true}, {100, 100000, 2048, true}, + + {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, + {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, + {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, + {10, 1000000, 512, true}, {10, 1000000, 1024, true}, {10, 1000000, 2048, true}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ + namespace BENCHMARK_PRIVATE_NAME(selection) \ + { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(params, SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ + } + +SELECTION_REGISTER(float, int, FAISS); +SELECTION_REGISTER(float, int, RADIX_8_BITS); +SELECTION_REGISTER(float, int, RADIX_11_BITS); +SELECTION_REGISTER(float, int, WARP_SORT); + 
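+// For reference, a hedged usage sketch: assuming RAFT_BENCH_REGISTER forwards the
+// #KeyT "/" #IdxT "/" #Algo label above to Google Benchmark, individual cases can be
+// selected at run time with the standard regex filter of the `bench_raft` binary
+// (the name set via RAFT_CPP_BENCH_TARGET in cpp/bench/CMakeLists.txt), e.g.:
+//
+//   ./bench_raft --benchmark_filter='float/int/WARP_SORT'
+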
+// SELECTION_REGISTER(double, int, FAISS); +// SELECTION_REGISTER(double, int, RADIX_8_BITS); +// SELECTION_REGISTER(double, int, RADIX_11_BITS); +// SELECTION_REGISTER(double, int, WARP_SORT); + +} // namespace raft::bench::spatial diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/bitonic_sort.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/bitonic_sort.cuh new file mode 100644 index 0000000000..c99d9b0313 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/bitonic_sort.cuh @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::spatial::knn::detail::ivf_flat { + +namespace helpers { + +template +__device__ __forceinline__ void swap(T& x, T& y) +{ + T t = x; + x = y; + y = t; +} + +template +__device__ __forceinline__ void assign(bool cond, T* ptr, T x) +{ + if (cond) { *ptr = x; } +} + +template +__device__ __forceinline__ auto first(T x, Ts... xs) -> T +{ + return x; +} + +} // namespace helpers + +/** + * Bitonic merge at the warp level. + * + * @tparam Size is the number of elements (must be power of two). + * @tparam Ascending is the resulting order (true: ascending, false: descending). + */ +template +struct bitonic_merge { + static_assert(isPo2(Size)); + + /** How many contiguous elements are processed by one thread. */ + static constexpr int kArrLen = Size / WarpSize; + static constexpr int kStride = kArrLen / 2; + + template + using when_fits_in_warp = + std::enable_if_t<(Fits == (Size <= WarpSize)) && std::is_same_v, void>; + + template + static __device__ auto run(bool reverse, KeyT* keys, PayloadTs*... payloads) + -> when_fits_in_warp + { + for (int i = 0; i < kStride; ++i) { + const int other_i = i + kStride; + KeyT& key = keys[i]; + KeyT& other = keys[other_i]; + bool do_swap = Ascending != reverse ? key > other : key < other; + // Normally, we expect `payloads` to be the array of indices from 0 to len; + // in that case, the construct below makes the sorting stable. + if constexpr (sizeof...(payloads) > 0) { + if (key == other) { + do_swap = + reverse != (helpers::first(payloads...)[i] > helpers::first(payloads...)[other_i]); + } + } + if (do_swap) { + helpers::swap(key, other); + (helpers::swap(payloads[i], payloads[other_i]), ...); + } + } + + bitonic_merge::run(reverse, keys, payloads...); + bitonic_merge::run(reverse, keys + kStride, (payloads + kStride)...); + } + + template + static __device__ auto run(bool reverse, KeyT* keys, PayloadTs*... payloads) + -> when_fits_in_warp + { + const int lane = threadIdx.x % Size; + for (int stride = Size / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + KeyT& key = *keys; + KeyT other = shfl_xor(key, stride, Size); + + bool asc = Ascending != reverse; + bool do_assign = key != other && ((key > other) == (asc != is_second)); + // Normally, we expect `payloads` to be the array of indices from 0 to len; + // in that case, the construct below makes the sorting stable. 
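+      // (In other words, when two keys compare equal, the first payload acts as a secondary
+      // sort key, so the relative order of equal keys is decided by their original indices.)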
+ if constexpr (sizeof...(payloads) > 0) { + auto payload_this = *helpers::first(payloads...); + auto payload_that = shfl_xor(payload_this, stride, Size); + if (key == other) { do_assign = reverse != ((payload_this > payload_that) != is_second); } + } + + helpers::assign(do_assign, keys, other); + // NB: don't put shfl_xor in a conditional; it must be called by all threads in a warp. + (helpers::assign(do_assign, payloads, shfl_xor(*payloads, stride, Size)), ...); + } + } + + template + static __device__ __forceinline__ void run(KeyT* keys, PayloadTs*... payloads) + { + return run(false, keys, payloads...); + } +}; + +/** + * Bitonic sort at the warp level. + * + * @tparam Size is the number of elements (must be power of two). + * @tparam Ascending is the resulting order (true: ascending, false: descending). + */ +template +struct bitonic_sort { + static_assert(isPo2(Size)); + + static constexpr int kSize2 = Size / 2; + + template + static __device__ __forceinline__ void run(bool reverse, KeyT* keys, PayloadTs*... payloads) + { + if constexpr (Size > 2) { + // NB: the `reverse` expression here is always `0` (false) when `Size > WarpSize` + bitonic_sort::run(laneId() & kSize2, keys, payloads...); + } + if constexpr (Size > WarpSize) { + // NB: this part is executed only if the size of the input arrays is larger than the warp. + constexpr int kShift = kSize2 / WarpSize; + bitonic_sort::run(true, keys + kShift, (payloads + kShift)...); + } + bitonic_merge::run(reverse, keys, payloads...); + } + + /** + * Execute the sort. + * + * @param keys + * is a device pointer to a contiguous array of keys, unique per thread; + * @param payloads + * are zero or more associated arrays of the same size as keys, which are sorted together with + * the keys. + */ + template + static __device__ __forceinline__ void run(KeyT* keys, PayloadTs*... payloads) + { + return run(false, keys, payloads...); + } +}; + +} // namespace raft::spatial::knn::detail::ivf_flat diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh new file mode 100644 index 0000000000..a48f7a1e3c --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh @@ -0,0 +1,657 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +#include + +/* + Two implementations: + + (1) radix select (select + filter): + first select the k-th value by going through radix passes, + then filter out all wanted data from original data + + (2) radix topk: + filter out wanted data directly while going through radix passes +*/ + +namespace raft::spatial::knn::detail::ivf_flat { + +inline size_t calc_aligned_size(const std::vector& sizes) +{ + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + size_t total = 0; + for (auto sz : sizes) { + total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + return total + ALIGN_BYTES - 1; +} + +inline std::vector calc_aligned_pointers(const void* p, const std::vector& sizes) +{ + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + + char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + + std::vector aligned_pointers; + aligned_pointers.reserve(sizes.size()); + for (auto sz : sizes) { + aligned_pointers.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + + return aligned_pointers; +} + +constexpr int BLOCK_DIM = 512; +constexpr int ITEM_PER_THREAD = 32; + +template +__host__ __device__ constexpr int calc_num_buckets() +{ + return 1 << BITS_PER_PASS; +} + +template +__host__ __device__ constexpr int calc_num_passes() +{ + return (sizeof(T) * 8 - 1) / BITS_PER_PASS + 1; +} + +// bit 0 is the least significant (rightmost) bit +// this function works even when pass=-1, which is used in calc_mask() +template +__device__ constexpr int calc_start_bit(int pass) +{ + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BITS_PER_PASS; + if (start_bit < 0) { start_bit = 0; } + return start_bit; +} + +template +__device__ constexpr unsigned calc_mask(int pass) +{ + static_assert(BITS_PER_PASS <= 31); + int num_bits = + calc_start_bit(pass - 1) - calc_start_bit(pass); + return (1 << num_bits) - 1; +} + +template +__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +{ + auto bits = reinterpret_cast::UnsignedBits&>(key); + bits = cub::Traits::TwiddleIn(bits); + if (greater) { bits = ~bits; } + return bits; +} + +template +__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +{ + static_assert(BITS_PER_PASS <= sizeof(int) * 8 - 1); // so return type can be int + return (twiddle_in(x, greater) >> start_bit) & mask; +} + +template +__device__ void vectorized_process(const T* in, idxT len, Func f) +{ + using WideT = float4; + + const idxT stride = blockDim.x * gridDim.x; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = tid; i < len; i += stride) { + f(in[i], i); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + // TODO: it's UB + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) + ? 
((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { skip_cnt = len; } + const WideT* in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + for (idxT i = tid; i < len_cast; i += stride) { + wide.scalar = in_cast[i]; + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j); + } + } + + static_assert(WarpSize >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WarpSize > skip_cnt + // no need to use loop + if (tid < skip_cnt) { f(in[tid], tid); } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WarpSize + // no need to use loop + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + if (remain_i < len) { f(in[remain_i], remain_i); } + } +} + +template +struct Counter { + idxT k; + idxT len; + idxT previous_len; + int bucket; + + idxT filter_cnt; + unsigned int finished_block_cnt; + idxT out_cnt; + idxT out_back_cnt; + T kth_value; +}; + +template +__device__ void filter_and_histogram(const T* in_buf, + const idxT* in_idx_buf, + T* out_buf, + idxT* out_idx_buf, + T* out, + idxT* out_idx, + idxT len, + Counter* counter, + idxT* histogram, + bool greater, + int pass, + int k) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ idxT histogram_smem[num_buckets]; + for (idxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + auto f = [greater, start_bit, mask](T value, idxT) { + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, 1); + }; + vectorized_process(in_buf, len, f); + } else { + const idxT previous_len = counter->previous_len; + const int want_bucket = counter->bucket; + idxT& filter_cnt = counter->filter_cnt; + idxT& out_cnt = counter->out_cnt; + T& kth_value = counter->kth_value; + const idxT counter_len = counter->len; + const int previous_start_bit = calc_start_bit(pass - 1); + const unsigned previous_mask = calc_mask(pass - 1); + + auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + greater, + k, + start_bit, + mask, + previous_start_bit, + previous_mask, + want_bucket, + &filter_cnt, + &out_cnt, + &kth_value, + counter_len](T value, idxT i) { + int prev_bucket = + calc_bucket(value, previous_start_bit, previous_mask, greater); + if (prev_bucket == want_bucket) { + idxT pos = atomicAdd(&filter_cnt, 1); + out_buf[pos] = value; + if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, 1); + + if (counter_len == 1) { + if (out) { + out[k - 1] = value; + out_idx[k - 1] = in_idx_buf ? in_idx_buf[i] : i; + } else { + kth_value = value; + } + } + } else if (out && prev_bucket < want_bucket) { + idxT pos = atomicAdd(&out_cnt, 1); + out[pos] = value; + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + }; + + vectorized_process(in_buf, previous_len, f); + } + __syncthreads(); + + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } + } +} + +template +__device__ void scan(volatile idxT* histogram, + const int start, + const int num_buckets, + const idxT current) +{ + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + idxT thread_data = 0; + int index = start + threadIdx.x; + if (index < num_buckets) { thread_data = histogram[index]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + if (index < num_buckets) { histogram[index] = thread_data + current; } + __syncthreads(); // This sync is necessary, as the content of histogram needs to be + // read after +} + +template +__device__ void choose_bucket(Counter* counter, idxT* histogram, const idxT k) +{ + constexpr int num_buckets = calc_num_buckets(); + int index = threadIdx.x; + idxT current_value = 0; + int num_pass = 1; + if constexpr (num_buckets >= NUM_THREAD) { + static_assert(num_buckets % NUM_THREAD == 0); + num_pass = num_buckets / NUM_THREAD; + } + + for (int i = 0; i < num_pass && (current_value < k); i++) { + scan(histogram, i * NUM_THREAD, num_buckets, current_value); + if (index < num_buckets) { + idxT prev = (index == 0) ? 0 : histogram[index - 1]; + idxT cur = histogram[index]; + + // one and only one thread will satisfy this condition, so only write once + if (prev < k && cur >= k) { + counter->k = k - prev; + counter->previous_len = counter->len; + counter->len = cur - prev; + counter->bucket = index; + } + } + index += NUM_THREAD; + current_value = histogram[(i + 1) * NUM_THREAD - 1]; + } +} + +template +__global__ void radix_kernel(const T* in_buf, + const idxT* in_idx_buf, + T* out_buf, + idxT* out_idx_buf, + T* out, + idxT* out_idx, + Counter* counters, + idxT* histograms, + const idxT len, + const idxT k, + const bool greater, + const int pass) +{ + __shared__ bool isLastBlockDone; + + constexpr int num_buckets = calc_num_buckets(); + constexpr int num_passes = calc_num_passes(); + const int batch_id = blockIdx.y; + in_buf += batch_id * len; + out_buf += batch_id * len; + if (in_idx_buf) { in_idx_buf += batch_id * len; } + if (out_idx_buf) { out_idx_buf += batch_id * len; } + if (out) { + out += batch_id * k; + out_idx += batch_id * k; + } + auto counter = counters + batch_id; + auto histogram = histograms + batch_id * num_buckets; + + filter_and_histogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + len, + counter, + histogram, + greater, + pass, + k); + __threadfence(); + + if (threadIdx.x == 0) { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlockDone = (finished == (gridDim.x - 1)); + } + + // Synchronize to make sure that each thread reads the correct value of + // isLastBlockDone. 
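+  // (Note on the pattern above: __threadfence() makes this block's histogram and counter
+  // updates visible before the atomicInc() result is observed by other blocks, and passing
+  // gridDim.x - 1 to atomicInc() both identifies the last arriving block and wraps the
+  // counter back to zero, so it is ready for the next pass.)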
+ __syncthreads(); + if (isLastBlockDone) { + if (counter->len == 1 && threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + // init counter, other members of counter is initialized with 0 by cudaMemset() + if (pass == 0 && threadIdx.x == 0) { + counter->k = k; + counter->len = len; + if (out) { counter->out_back_cnt = 0; } + } + __syncthreads(); + + idxT ori_k = counter->k; + + if (counter->len > 0) { + choose_bucket(counter, histogram, ori_k); + } + + __syncthreads(); + if (pass == num_passes - 1) { + const idxT previous_len = counter->previous_len; + const int want_bucket = counter->bucket; + int start_bit = calc_start_bit(pass); + unsigned mask = calc_mask(pass); + + if (!out) { // radix select + for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = out_buf[i]; + int bucket = calc_bucket(value, start_bit, mask, greater); + if (bucket == want_bucket) { + // TODO: UB + // could use atomicExch, but it's not defined for T=half + counter->kth_value = value; + break; + } + } + } else { // radix topk + idxT& out_cnt = counter->out_cnt; + for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = out_buf[i]; + int bucket = calc_bucket(value, start_bit, mask, greater); + if (bucket < want_bucket) { + idxT pos = atomicAdd(&out_cnt, 1); + out[pos] = value; + out_idx[pos] = out_idx_buf[i]; + } else if (bucket == want_bucket) { + idxT needed_num_of_kth = counter->k; + idxT back_pos = atomicAdd(&(counter->out_back_cnt), 1); + if (back_pos < needed_num_of_kth) { + idxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = out_idx_buf[i]; + } + } + } + __syncthreads(); + } + } else { + // reset for next pass + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + if (threadIdx.x == 0) { counter->filter_cnt = 0; } + } + } +} + +template +__global__ void final_filter(const T* in, + const idxT len, + const idxT k, + Counter* counters, + T* out, + idxT* out_idx, + bool greater) +{ + const int batch_id = blockIdx.y; + const T kth_value = counters[batch_id].kth_value; + const idxT needed_num_of_kth = counters[batch_id].k; + idxT& out_cnt = counters[batch_id].out_cnt; + idxT& out_back_cnt = counters[batch_id].out_back_cnt; + + in = in + batch_id * len; + out = out + batch_id * k; + out_idx = out_idx + batch_id * k; + + auto f = [k, greater, kth_value, needed_num_of_kth, &out_cnt, &out_back_cnt, out, out_idx]( + T val, idxT i) { + if ((greater && val > kth_value) || (!greater && val < kth_value)) { + idxT pos = atomicAdd(&out_cnt, 1); + out[pos] = val; + out_idx[pos] = i; + } else if (val == kth_value) { + idxT back_pos = atomicAdd(&out_back_cnt, 1); + if (back_pos < needed_num_of_kth) { + idxT pos = k - 1 - back_pos; + out[pos] = val; + out_idx[pos] = i; + } + } + }; + vectorized_process(in, len, f); +} + +template +void radix_select_topk(void* buf, + size_t& buf_size, + const T* in, + idxT batch_size, + idxT len, + idxT k, + T* out, + idxT* out_idx, + bool greater, + cudaStream_t stream) +{ + // TODO: is it possible to relax this restriction? 
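+  // Worked example, assuming T = float (32 key bits) and the calc_* helpers defined above:
+  //   BITS_PER_PASS = 8  ->  num_passes = (32 - 1) / 8  + 1 = 4,  num_buckets = 1 << 8  = 256
+  //   BITS_PER_PASS = 11 ->  num_passes = (32 - 1) / 11 + 1 = 3,  num_buckets = 1 << 11 = 2048
+  // Both configurations satisfy the `num_passes > 1` requirement asserted below.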
+ static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + Counter* counters = nullptr; + idxT* histograms = nullptr; + T* buf1 = nullptr; + T* buf2 = nullptr; + { + std::vector sizes = {sizeof(*counters) * batch_size, + sizeof(*histograms) * num_buckets * batch_size, + sizeof(*buf1) * len * batch_size, + sizeof(*buf2) * len * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + counters = static_cast(aligned_pointers[0]); + histograms = static_cast(aligned_pointers[1]); + buf1 = static_cast(aligned_pointers[2]); + buf2 = static_cast(aligned_pointers[3]); + + RAFT_CUDA_TRY(cudaMemsetAsync( + buf, + 0, + static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), + stream)); + } + + const T* in_buf = nullptr; + T* out_buf = nullptr; + + dim3 blocks((len - 1) / (NUM_THREAD * ITEM_PER_THREAD) + 1, batch_size); + + constexpr int num_passes = calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + if (pass == 0) { + in_buf = in; + out_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + out_buf = buf1; + } else { + in_buf = (pass % 2 == 0) ? buf1 : buf2; + out_buf = (pass % 2 == 0) ? buf2 : buf1; + } + radix_kernel<<>>(in_buf, + nullptr, + out_buf, + nullptr, + nullptr, + nullptr, + counters, + histograms, + len, + k, + greater, + pass); + } + + constexpr int FILTER_BLOCK_DIM = 256; + constexpr int FILTER_ITEM_PER_THREAD = 32; + dim3 filter_blocks((len - 1) / (FILTER_BLOCK_DIM * FILTER_ITEM_PER_THREAD) + 1, batch_size); + final_filter<<>>( + in, len, k, counters, out, out_idx, greater); +} + +template +void radix_topk(void* buf, + size_t& buf_size, + const T* in, + const idxT* in_idx, + idxT batch_size, + idxT len, + idxT k, + T* out, + idxT* out_idx, + bool greater, + cudaStream_t stream) +{ + // TODO: is it possible to relax this restriction? + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + Counter* counters = nullptr; + idxT* histograms = nullptr; + T* buf1 = nullptr; + idxT* idx_buf1 = nullptr; + T* buf2 = nullptr; + idxT* idx_buf2 = nullptr; + { + std::vector sizes = {sizeof(*counters) * batch_size, + sizeof(*histograms) * num_buckets * batch_size, + sizeof(*buf1) * len * batch_size, + sizeof(*idx_buf1) * len * batch_size, + sizeof(*buf2) * len * batch_size, + sizeof(*idx_buf2) * len * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + counters = static_cast(aligned_pointers[0]); + histograms = static_cast(aligned_pointers[1]); + buf1 = static_cast(aligned_pointers[2]); + idx_buf1 = static_cast(aligned_pointers[3]); + buf2 = static_cast(aligned_pointers[4]); + idx_buf2 = static_cast(aligned_pointers[5]); + + RAFT_CUDA_TRY(cudaMemsetAsync( + buf, + 0, + static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), + stream)); + } + + const T* in_buf = nullptr; + const idxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + idxT* out_idx_buf = nullptr; + + dim3 blocks((len - 1) / (NUM_THREAD * ITEM_PER_THREAD) + 1, batch_size); + + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx ? 
in_idx : nullptr; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } + + radix_kernel<<>>(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + counters, + histograms, + len, + k, + greater, + pass); + } +} + +} // namespace raft::spatial::knn::detail::ivf_flat diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh new file mode 100644 index 0000000000..1ffeb7335f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh @@ -0,0 +1,850 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "bitonic_sort.cuh" + +#include +#include + +#include +#include +#include + +/* + Three APIs of different scope are provided: + 1. host function: warp_sort_topk() + 2. block-wide API: class WarpSortBlockWide + 3. warp-wide API: class WarpSelect and class WarpBitonic + + + 1. warp_sort_topk() + Like CUB functions, it should be called twice. + First for getting required buffer size, and a second for the real top-k computation. + For the first call, buf==nullptr should be passed, and required buffer + size is returned as parameter buf_size. + For the second call, pass allocated buffer of required size. + + Example: + void* buf = nullptr; + size_t buf_size; + warp_sort_topk(nullptr, buf_size, ...); // will set buf_size + cudaMalloc(&buf, buf_size); + warp_sort_topk(buf, buf_size, ...); + + + 2. class WarpSortBlockWide + It can be regarded as a fixed size priority queue for a thread block, + although the API is not typical. + class WarpSelect and WarpBitonic can be used to instantiate WarpSortBlockWide. + + It uses dynamic shared memory as intermediate buffer. + So the required shared memory size should be calculated using + calc_smem_size_for_block_wide() and passed as the 3rd kernel launch parameter. + + Two overloaded add() functions can be used to add items to the queue. + One is add(const T* in, idxT start, idxT end) and it adds a range of items, + namely [start, end) of in. The idx is inferred from start. + This function should be called only once to add all items, and should not be + used together with the second form of add(). + The second one is add(T val, idxT idx), and it adds only one item pair. + Note that the range [start, end) is for the whole block of threads, that is, + each thread in the same block should get the same start/end. + In contrast, the parameters of the second form are for only one thread, + so each thread must get different val/idx. + + After adding is finished, function done() should be called. And finally, + dump() is used to get the top-k result. 
+ + Example: + __global__ void kernel() { + WarpSortBlockWide queue(...); + + // way 1, [0, len) is same for the whole block + queue.add(in, 0, len); + // way 2, each thread gets its own val/idx pair + for (idxT i = threadIdx.x; i < len, i += blockDim.x) { + queue.add(in[i], idx[i]); + } + + queue.done(); + queue.dump(out, out_idx); + } + + int smem_size = calc_smem_size_for_block_wide(...); + kernel<>>(); + + + 3. class WarpSelect and class WarpBitonic + These two classes can be regarded as fixed sized priority queue for a warp. + Usage is similar to class WarpSortBlockWide. + Two types of add() functions are provided, and also note that [start, end) is + for a whole warp, while val/idx is for a thread. + No shared memory is needed. + + Example: + __global__ void kernel() { + WarpBitonic<...> queue(...); + int warp_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; + + // way 1, [0, len) is same for the whole warp + queue.add(in, 0, len); + // way 2, each thread gets its own val/idx pair + for (idxT i = lane_id; i < len, i += WarpSize) { + queue.add(in[i], idx[i]); + } + + queue.done(); + // each warp outputs to a different offset + queue.dump(out+ warp_id * k * sizeof(T), out_idx+ warp_id * k * sizeof(idxT)); + } + */ + +namespace raft::spatial::knn::detail::ivf_flat { + +namespace { + +template +constexpr T get_lower_bound() +{ + if (std::numeric_limits::has_infinity && std::numeric_limits::is_signed) { + return -std::numeric_limits::infinity(); + } else { + return std::numeric_limits::lowest(); + } +} + +template +constexpr T get_upper_bound() +{ + if (std::numeric_limits::has_infinity) { + return std::numeric_limits::infinity(); + } else { + return std::numeric_limits::max(); + } +} + +template +constexpr T get_dummy(bool greater) +{ + return greater ? get_lower_bound() : get_upper_bound(); +} + +template +__device__ inline bool is_greater_than(T val, T baseline) +{ + if constexpr (greater) { return val > baseline; } + if constexpr (!greater) { return val < baseline; } +} + +template +constexpr HDI T nextHighestPowerOf2(T v) +{ + /** + * TODO: Not entirely sure if this is what we need in the code of this file. + * It returns `r`, such that r > v, r <= v*2, and r is power of two. + */ + return isPo2(v) ? 
(v << (T)1) : ((T)1 << (log2(v) + 1)); +} + +int calc_capacity(int k) +{ + int capacity = nextHighestPowerOf2(k); + if (capacity < WarpSize) { capacity = WarpSize; } + return capacity; +} +} // namespace +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) : lane_(threadIdx.x % WarpSize), k_(k), dummy_(dummy) + { + static_assert(capacity >= WarpSize && isPo2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + } + } + + // load and merge k sorted values + __device__ void load_sorted(const T* in, const idxT* in_idx, idxT start) + { + idxT idx = start + WarpSize - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WarpSize) { + if (idx < start + k_) { + T t = in[idx]; + if (is_greater_than(t, val_arr_[i])) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + bitonic_merge::run(val_arr_, idx_arr_); + } + + __device__ void dump(T* out, idxT* out_idx) const + { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WarpSize + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + protected: + static constexpr int max_arr_len_ = capacity / WarpSize; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + const int lane_; + const idxT k_; + const T dummy_; +}; + +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + buf_len_(0), + k_th_(dummy), + k_th_lane_((k - 1) % WarpSize) + { + for (int i = 0; i < max_buf_len_; ++i) { + val_buf_[i] = dummy_; + } + } + + __device__ void add(const T* in, idxT start, idxT end) + { + const idxT end_for_fullwarp = Pow2::roundUp(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WarpSize) { + T val = (i < end) ? 
in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) + { + if (is_greater_than(val, k_th_)) { + for (int i = 0; i < max_buf_len_ - 1; ++i) { + val_buf_[i] = val_buf_[i + 1]; + idx_buf_[i] = idx_buf_[i + 1]; + } + val_buf_[max_buf_len_ - 1] = val; + idx_buf_[max_buf_len_ - 1] = idx; + + ++buf_len_; + } + + if (any(buf_len_ == max_buf_len_)) { merge_buf_(); } + } + + __device__ void done() + { + if (any(buf_len_ != 0)) { merge_buf_(); } + } + + private: + __device__ void set_k_th_() + { + // it's the best we can do, should use "val_arr_[k_th_row_]" + k_th_ = shfl(val_arr_[max_arr_len_ - 1], k_th_lane_); + } + + __device__ void merge_buf_() + { + bitonic_sort::run(val_buf_, idx_buf_); + + if (max_arr_len_ > max_buf_len_) { + for (int i = 0; i < max_buf_len_; ++i) { + T& val = val_arr_[max_arr_len_ - max_buf_len_ + i]; + T& buf = val_buf_[i]; + if (is_greater_than(buf, val)) { + val = buf; + idx_arr_[max_arr_len_ - max_buf_len_ + i] = idx_buf_[i]; + } + } + } else if (max_arr_len_ < max_buf_len_) { + for (int i = 0; i < max_arr_len_; ++i) { + T& val = val_arr_[i]; + T& buf = val_buf_[max_buf_len_ - max_arr_len_ + i]; + if (is_greater_than(buf, val)) { + val = buf; + idx_arr_[i] = idx_buf_[max_buf_len_ - max_arr_len_ + i]; + } + } + } else { + for (int i = 0; i < max_arr_len_; ++i) { + if (is_greater_than(val_buf_[i], val_arr_[i])) { + val_arr_[i] = val_buf_[i]; + idx_arr_[i] = idx_buf_[i]; + } + } + } + + bitonic_merge::run(val_arr_, idx_arr_); + + buf_len_ = 0; + set_k_th_(); // contains sync + for (int i = 0; i < max_buf_len_; ++i) { + val_buf_[i] = dummy_; + } + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + static constexpr int max_buf_len_ = (capacity <= 64) ? 2 : 4; + + T val_buf_[max_buf_len_]; + idxT idx_buf_[max_buf_len_]; + int buf_len_; + + T k_th_; + const int k_th_lane_; +}; + +template +class WarpBitonic : public WarpSort { + public: + __device__ WarpBitonic(idxT k, T dummy) + : WarpSort(k, dummy), buf_len_(0) + { + for (int i = 0; i < max_arr_len_; ++i) { + val_buf_[i] = dummy_; + } + } + + __device__ void add(const T* in, idxT start, idxT end) + { + add_first_(in, start, end); + start += capacity; + while (start < end) { + add_extra_(in, start, end); + merge_(); + start += capacity; + } + } + + __device__ void add(T val, idxT idx) + { + for (int i = 0; i < max_arr_len_; ++i) { + if (i == buf_len_) { + val_buf_[i] = val; + idx_buf_[i] = idx; + } + } + + ++buf_len_; + if (buf_len_ == max_arr_len_) { + bitonic_sort::run(val_buf_, idx_buf_); + merge_(); + + for (int i = 0; i < max_arr_len_; ++i) { + val_buf_[i] = dummy_; + } + buf_len_ = 0; + } + } + + __device__ void done() + { + if (buf_len_ != 0) { + bitonic_sort::run(val_buf_, idx_buf_); + merge_(); + } + } + + private: + __device__ void add_first_(const T* in, idxT start, idxT end) + { + idxT idx = start + lane_; + for (int i = 0; i < max_arr_len_; ++i, idx += WarpSize) { + if (idx < end) { + val_arr_[i] = in[idx]; + idx_arr_[i] = idx; + } + } + bitonic_sort::run(val_arr_, idx_arr_); + } + + __device__ void add_extra_(const T* in, idxT start, idxT end) + { + idxT idx = start + lane_; + for (int i = 0; i < max_arr_len_; ++i, idx += WarpSize) { + val_buf_[i] = (idx < end) ? 
in[idx] : dummy_; + idx_buf_[i] = idx; + } + bitonic_sort::run(val_buf_, idx_buf_); + } + + __device__ void merge_() + { + for (int i = 0; i < max_arr_len_; ++i) { + if (is_greater_than(val_buf_[i], val_arr_[i])) { + val_arr_[i] = val_buf_[i]; + idx_arr_[i] = idx_buf_[i]; + } + } + bitonic_merge::run(val_arr_, idx_arr_); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T val_buf_[max_arr_len_]; + idxT idx_buf_[max_arr_len_]; + int buf_len_; +}; + +template +class WarpMerge : public WarpSort { + public: + __device__ WarpMerge(idxT k, T dummy) : WarpSort(k, dummy) {} + + __device__ void add(const T* in, const idxT* in_idx, idxT start, idxT end) + { + idxT idx = start + lane_; + idxT first_end = (start + k_ < end) ? (start + k_) : end; + for (int i = 0; i < max_arr_len_; ++i, idx += WarpSize) { + if (idx < first_end) { + val_arr_[i] = in[idx]; + idx_arr_[i] = in_idx[idx]; + } + } + + for (start += k_; start < end; start += k_) { + load_sorted(in, in_idx, start); + } + } + + __device__ void done() {} + + private: + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; +}; + +template +int calc_smem_size_for_block_wide(int num_of_warp, idxT k) +{ + return Pow2<256>::roundUp(num_of_warp / 2 * sizeof(T) * k) + num_of_warp / 2 * sizeof(idxT) * k; +} + +template