From 1d2863dc11d95f7f05a3042eca058b96f5a054bf Mon Sep 17 00:00:00 2001
From: achirkin
Date: Fri, 11 Mar 2022 08:56:35 +0100
Subject: [PATCH] Remove the step of calculating required buf size.

Allocate the scratch buffers of radix_topk internally with
rmm::device_uvector instead of making callers query the required buffer
size first and pass the storage in (a caller-side sketch follows the
diff).

---
 .../knn/detail/ivf_flat/radix_topk.cuh        | 283 ++++++++----------
 .../knn/detail/ivf_flat/warpsort_topk.cuh     | 129 +++-----
 cpp/include/raft/spatial/knn/knn.cuh          |  87 +-----
 cpp/include/raft/spatial/knn/knn.hpp          |  87 +-----
 4 files changed, 191 insertions(+), 395 deletions(-)

diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh
index a48f7a1e3c..d91c4c328d 100644
--- a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh
@@ -114,15 +114,15 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
   return (twiddle_in(x, greater) >> start_bit) & mask;
 }
 
-template <typename T, typename idxT, typename Func>
-__device__ void vectorized_process(const T* in, idxT len, Func f)
+template <typename T, typename IdxT, typename Func>
+__device__ void vectorized_process(const T* in, IdxT len, Func f)
 {
   using WideT = float4;
-  const idxT stride = blockDim.x * gridDim.x;
+  const IdxT stride = blockDim.x * gridDim.x;
   const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
   if constexpr (sizeof(T) >= sizeof(WideT)) {
-    for (idxT i = tid; i < len; i += stride) {
+    for (IdxT i = tid; i < len; i += stride) {
       f(in[i], i);
     }
   } else {
@@ -139,10 +139,10 @@ __device__ void vectorized_process(const T* in, idxT len, Func f)
          : 0;
     if (skip_cnt > len) { skip_cnt = len; }
     const WideT* in_cast = reinterpret_cast<const WideT*>(in + skip_cnt);
-    const idxT len_cast  = (len - skip_cnt) / items_per_scalar;
-    for (idxT i = tid; i < len_cast; i += stride) {
+    const IdxT len_cast  = (len - skip_cnt) / items_per_scalar;
+    for (IdxT i = tid; i < len_cast; i += stride) {
       wide.scalar = in_cast[i];
-      const idxT real_i = skip_cnt + i * items_per_scalar;
+      const IdxT real_i = skip_cnt + i * items_per_scalar;
 #pragma unroll
       for (int j = 0; j < items_per_scalar; ++j) {
         f(wide.array[j], real_i + j);
@@ -156,44 +156,44 @@ __device__ void vectorized_process(const T* in, idxT len, Func f)
     // because len_cast = (len - skip_cnt) / items_per_scalar,
     // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
    // and so
-    // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WarpSize
-    // no need to use loop
-    const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid;
+    // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
+    // WarpSize, so no loop is needed
+    const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid;
     if (remain_i < len) { f(in[remain_i], remain_i); }
   }
 }
 
-template <typename T, typename idxT>
+template <typename T, typename IdxT>
 struct Counter {
-  idxT k;
-  idxT len;
-  idxT previous_len;
+  IdxT k;
+  IdxT len;
+  IdxT previous_len;
   int bucket;
-  idxT filter_cnt;
+  IdxT filter_cnt;
   unsigned int finished_block_cnt;
-  idxT out_cnt;
-  idxT out_back_cnt;
+  IdxT out_cnt;
+  IdxT out_back_cnt;
   T kth_value;
 };
 
-template <typename T, typename idxT, int BITS_PER_PASS>
+template <typename T, typename IdxT, int BITS_PER_PASS>
 __device__ void filter_and_histogram(const T* in_buf,
-                                     const idxT* in_idx_buf,
+                                     const IdxT* in_idx_buf,
                                      T* out_buf,
-                                     idxT* out_idx_buf,
+                                     IdxT* out_idx_buf,
                                      T* out,
-                                     idxT* out_idx,
-                                     idxT len,
-                                     Counter<T, idxT>* counter,
-                                     idxT* histogram,
+                                     IdxT* out_idx,
+                                     IdxT len,
+                                     Counter<T, IdxT>* counter,
+                                     IdxT* histogram,
                                      bool greater,
                                      int pass,
                                      int k)
 {
   constexpr int num_buckets = calc_num_buckets<BITS_PER_PASS>();
-  __shared__ idxT histogram_smem[num_buckets];
-  for (idxT i = threadIdx.x; i < num_buckets; i += blockDim.x) {
+  __shared__ IdxT histogram_smem[num_buckets];
+  for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) {
     histogram_smem[i] = 0;
   }
   __syncthreads();
@@ -202,18 +202,18 @@ __device__ void filter_and_histogram(const T* in_buf,
   const unsigned mask = calc_mask<T, BITS_PER_PASS>(pass);
 
   if (pass == 0) {
-    auto f = [greater, start_bit, mask](T value, idxT) {
+    auto f = [greater, start_bit, mask](T value, IdxT) {
       int bucket = calc_bucket(value, start_bit, mask, greater);
       atomicAdd(histogram_smem + bucket, 1);
     };
     vectorized_process(in_buf, len, f);
   } else {
-    const idxT previous_len = counter->previous_len;
+    const IdxT previous_len = counter->previous_len;
     const int want_bucket   = counter->bucket;
-    idxT& filter_cnt        = counter->filter_cnt;
-    idxT& out_cnt           = counter->out_cnt;
+    IdxT& filter_cnt        = counter->filter_cnt;
+    IdxT& out_cnt           = counter->out_cnt;
     T& kth_value            = counter->kth_value;
-    const idxT counter_len  = counter->len;
+    const IdxT counter_len  = counter->len;
     const int previous_start_bit = calc_start_bit<T, BITS_PER_PASS>(pass - 1);
     const unsigned previous_mask = calc_mask<T, BITS_PER_PASS>(pass - 1);
@@ -232,11 +232,11 @@ __device__ void filter_and_histogram(const T* in_buf,
                &filter_cnt,
                &out_cnt,
                &kth_value,
-               counter_len](T value, idxT i) {
+               counter_len](T value, IdxT i) {
       int prev_bucket = calc_bucket(value, previous_start_bit, previous_mask, greater);
       if (prev_bucket == want_bucket) {
-        idxT pos     = atomicAdd(&filter_cnt, 1);
+        IdxT pos     = atomicAdd(&filter_cnt, 1);
         out_buf[pos] = value;
         if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; }
         int bucket = calc_bucket(value, start_bit, mask, greater);
@@ -251,7 +251,7 @@ __device__ void filter_and_histogram(const T* in_buf,
           }
         }
       } else if (out && prev_bucket < want_bucket) {
-        idxT pos     = atomicAdd(&out_cnt, 1);
+        IdxT pos     = atomicAdd(&out_cnt, 1);
         out[pos]     = value;
         out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
       }
@@ -266,32 +266,32 @@ __device__ void filter_and_histogram(const T* in_buf,
   }
 }
 
-template <typename idxT, int NUM_THREAD>
-__device__ void scan(volatile idxT* histogram,
+template <typename IdxT, int NUM_THREAD>
+__device__ void scan(volatile IdxT* histogram,
                      const int start,
                      const int num_buckets,
-                     const idxT current)
+                     const IdxT current)
 {
-  typedef cub::BlockScan<idxT, NUM_THREAD> BlockScan;
+  typedef cub::BlockScan<IdxT, NUM_THREAD> BlockScan;
   __shared__ typename BlockScan::TempStorage temp_storage;
 
-  idxT thread_data = 0;
+  IdxT thread_data = 0;
   int index        = start + threadIdx.x;
   if (index < num_buckets) { thread_data = histogram[index]; }
 
   BlockScan(temp_storage).InclusiveSum(thread_data, thread_data);
   __syncthreads();
 
   if (index < num_buckets) { histogram[index] = thread_data + current; }
-  __syncthreads();  // This sync is necessary, as the content of histogram needs to be
-                    // read after
+  __syncthreads();  // This sync is necessary, as the content of histogram
+                    // needs to be read afterwards
 }
 
-template <typename T, typename idxT, int BITS_PER_PASS, int NUM_THREAD>
-__device__ void choose_bucket(Counter<T, idxT>* counter, idxT* histogram, const idxT k)
+template <typename T, typename IdxT, int BITS_PER_PASS, int NUM_THREAD>
+__device__ void choose_bucket(Counter<T, IdxT>* counter, IdxT* histogram, const IdxT k)
 {
   constexpr int num_buckets = calc_num_buckets<BITS_PER_PASS>();
   int index                 = threadIdx.x;
-  idxT current_value        = 0;
+  IdxT current_value        = 0;
   int num_pass              = 1;
   if constexpr (num_buckets >= NUM_THREAD) {
     static_assert(num_buckets % NUM_THREAD == 0);
@@ -299,10 +299,10 @@ __device__ void choose_bucket(Counter<T, idxT>* counter, idxT* histogram, const
   }
 
   for (int i = 0; i < num_pass && (current_value < k); i++) {
-    scan<idxT, NUM_THREAD>(histogram, i * NUM_THREAD, num_buckets, current_value);
+    scan<IdxT, NUM_THREAD>(histogram, i * NUM_THREAD, num_buckets, current_value);
     if (index < num_buckets) {
-      idxT prev = (index == 0) ? 0 : histogram[index - 1];
-      idxT cur  = histogram[index];
+      IdxT prev = (index == 0) ? 0 : histogram[index - 1];
+      IdxT cur  = histogram[index];
       // one and only one thread will satisfy this condition, so only write once
       if (prev < k && cur >= k) {
@@ -317,17 +317,17 @@ __device__ void choose_bucket(Counter<T, idxT>* counter, idxT* histogram, const
   }
 }
 
-template <typename T, typename idxT, int BITS_PER_PASS, int NUM_THREAD>
+template <typename T, typename IdxT, int BITS_PER_PASS, int NUM_THREAD>
 __global__ void radix_kernel(const T* in_buf,
-                             const idxT* in_idx_buf,
+                             const IdxT* in_idx_buf,
                              T* out_buf,
-                             idxT* out_idx_buf,
+                             IdxT* out_idx_buf,
                              T* out,
-                             idxT* out_idx,
-                             Counter<T, idxT>* counters,
-                             idxT* histograms,
-                             const idxT len,
-                             const idxT k,
+                             IdxT* out_idx,
+                             Counter<T, IdxT>* counters,
+                             IdxT* histograms,
+                             const IdxT len,
+                             const int k,
                              const bool greater,
                              const int pass)
 {
@@ -347,7 +347,7 @@ __global__ void radix_kernel(const T* in_buf,
   auto counter   = counters + batch_id;
   auto histogram = histograms + batch_id * num_buckets;
 
-  filter_and_histogram<T, idxT, BITS_PER_PASS>(in_buf,
+  filter_and_histogram<T, IdxT, BITS_PER_PASS>(in_buf,
                                                in_idx_buf,
                                                out_buf,
                                                out_idx_buf,
@@ -374,7 +374,8 @@ __global__ void radix_kernel(const T* in_buf,
     counter->previous_len = 0;
     counter->len          = 0;
   }
-  // init counter, other members of counter is initialized with 0 by cudaMemset()
+  // init the counter; its other members are zero-initialized by
+  // cudaMemset()
   if (pass == 0 && threadIdx.x == 0) {
     counter->k   = k;
     counter->len = len;
@@ -382,21 +383,21 @@ __global__ void radix_kernel(const T* in_buf,
   }
   __syncthreads();
 
-  idxT ori_k = counter->k;
+  IdxT ori_k = counter->k;
 
   if (counter->len > 0) {
-    choose_bucket<T, idxT, BITS_PER_PASS, NUM_THREAD>(counter, histogram, ori_k);
+    choose_bucket<T, IdxT, BITS_PER_PASS, NUM_THREAD>(counter, histogram, ori_k);
   }
   __syncthreads();
 
   if (pass == num_passes - 1) {
-    const idxT previous_len = counter->previous_len;
+    const IdxT previous_len = counter->previous_len;
     const int want_bucket   = counter->bucket;
     int start_bit           = calc_start_bit<T, BITS_PER_PASS>(pass);
     unsigned mask           = calc_mask<T, BITS_PER_PASS>(pass);
 
     if (!out) {  // radix select
-      for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
+      for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
        const T value = out_buf[i];
         int bucket    = calc_bucket(value, start_bit, mask, greater);
         if (bucket == want_bucket) {
@@ -407,19 +408,19 @@ __global__ void radix_kernel(const T* in_buf,
         }
       }
     } else {  // radix topk
-      idxT& out_cnt = counter->out_cnt;
-      for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
+      IdxT& out_cnt = counter->out_cnt;
+      for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
         const T value = out_buf[i];
         int bucket    = calc_bucket(value, start_bit, mask, greater);
         if (bucket < want_bucket) {
-          idxT pos     = atomicAdd(&out_cnt, 1);
+          IdxT pos     = atomicAdd(&out_cnt, 1);
           out[pos]     = value;
           out_idx[pos] = out_idx_buf[i];
         } else if (bucket == want_bucket) {
-          idxT needed_num_of_kth = counter->k;
-          idxT back_pos          = atomicAdd(&(counter->out_back_cnt), 1);
+          IdxT needed_num_of_kth = counter->k;
+          IdxT back_pos          = atomicAdd(&(counter->out_back_cnt), 1);
           if (back_pos < needed_num_of_kth) {
-            idxT pos     = k - 1 - back_pos;
+            IdxT pos     = k - 1 - back_pos;
             out[pos]     = value;
             out_idx[pos] = out_idx_buf[i];
           }
@@ -437,35 +438,35 @@ __global__ void radix_kernel(const T* in_buf,
   }
 }
 
-template <typename T, typename idxT>
+template <typename T, typename IdxT>
 __global__ void final_filter(const T* in,
-                             const idxT len,
-                             const idxT k,
-                             Counter<T, idxT>* counters,
+                             const IdxT len,
+                             const IdxT k,
+                             Counter<T, IdxT>* counters,
                              T* out,
-                             idxT* out_idx,
+                             IdxT* out_idx,
                              bool greater)
 {
   const int batch_id           = blockIdx.y;
   const T kth_value            = counters[batch_id].kth_value;
-  const idxT needed_num_of_kth = counters[batch_id].k;
-  idxT& out_cnt                = counters[batch_id].out_cnt;
-  idxT& out_back_cnt           = counters[batch_id].out_back_cnt;
+  const IdxT needed_num_of_kth = counters[batch_id].k;
+  IdxT& out_cnt                = counters[batch_id].out_cnt;
+  IdxT& out_back_cnt           = counters[batch_id].out_back_cnt;
 
   in      = in + batch_id * len;
   out     = out + batch_id * k;
   out_idx = out_idx + batch_id * k;
 
   auto f = [k, greater, kth_value, needed_num_of_kth, &out_cnt, &out_back_cnt, out, out_idx](
-             T val, idxT i) {
+             T val, IdxT i) {
     if ((greater && val > kth_value) || (!greater && val < kth_value)) {
-      idxT pos     = atomicAdd(&out_cnt, 1);
+      IdxT pos     = atomicAdd(&out_cnt, 1);
      out[pos]     = val;
       out_idx[pos] = i;
     } else if (val == kth_value) {
-      idxT back_pos = atomicAdd(&out_back_cnt, 1);
+      IdxT back_pos = atomicAdd(&out_back_cnt, 1);
       if (back_pos < needed_num_of_kth) {
-        idxT pos     = k - 1 - back_pos;
+        IdxT pos     = k - 1 - back_pos;
         out[pos]     = val;
         out_idx[pos] = i;
       }
@@ -474,15 +475,15 @@ __global__ void final_filter(const T* in,
   vectorized_process(in, len, f);
 }
 
-template <typename T, typename idxT, int BITS_PER_PASS, int NUM_THREAD, int ITEM_PER_THREAD>
+template <typename T, typename IdxT, int BITS_PER_PASS, int NUM_THREAD, int ITEM_PER_THREAD>
 void radix_select_topk(void* buf,
                        size_t& buf_size,
                        const T* in,
-                       idxT batch_size,
-                       idxT len,
-                       idxT k,
+                       IdxT batch_size,
+                       IdxT len,
+                       IdxT k,
                        T* out,
-                       idxT* out_idx,
+                       IdxT* out_idx,
                        bool greater,
                        cudaStream_t stream)
 {
@@ -490,8 +491,8 @@ void radix_select_topk(void* buf,
   static_assert(calc_num_passes<T, BITS_PER_PASS>() > 1);
   constexpr int num_buckets = calc_num_buckets<BITS_PER_PASS>();
 
-  Counter<T, idxT>* counters = nullptr;
-  idxT* histograms           = nullptr;
+  Counter<T, IdxT>* counters = nullptr;
+  IdxT* histograms           = nullptr;
   T* buf1                    = nullptr;
   T* buf2                    = nullptr;
   {
@@ -535,7 +536,7 @@ void radix_select_topk(void* buf,
      in_buf  = (pass % 2 == 0) ? buf1 : buf2;
       out_buf = (pass % 2 == 0) ? buf2 : buf1;
     }
-    radix_kernel<T, idxT, BITS_PER_PASS, NUM_THREAD><<<blocks, NUM_THREAD, 0, stream>>>(in_buf,
+    radix_kernel<T, IdxT, BITS_PER_PASS, NUM_THREAD><<<blocks, NUM_THREAD, 0, stream>>>(in_buf,
                                                                                         nullptr,
                                                                                         out_buf,
                                                                                         nullptr,
@@ -556,61 +557,36 @@ void radix_select_topk(void* buf,
     in, len, k, counters, out, out_idx, greater);
 }
 
-template <typename T, typename idxT, int BITS_PER_PASS, int NUM_THREAD, int ITEM_PER_THREAD>
-void radix_topk(void* buf,
-                size_t& buf_size,
-                const T* in,
-                const idxT* in_idx,
-                idxT batch_size,
-                idxT len,
-                idxT k,
+template <typename T, typename IdxT, int BITS_PER_PASS, int NUM_THREAD, int ITEM_PER_THREAD>
+void radix_topk(const T* in,
+                const IdxT* in_idx,
+                size_t batch_size,
+                size_t len,
+                int k,
                 T* out,
-                idxT* out_idx,
+                IdxT* out_idx,
                 bool greater,
-                cudaStream_t stream)
+                rmm::cuda_stream_view stream)
 {
   // TODO: is it possible to relax this restriction?
   static_assert(calc_num_passes<T, BITS_PER_PASS>() > 1);
   constexpr int num_buckets = calc_num_buckets<BITS_PER_PASS>();
 
-  Counter<T, idxT>* counters = nullptr;
-  idxT* histograms           = nullptr;
-  T* buf1                    = nullptr;
-  idxT* idx_buf1             = nullptr;
-  T* buf2                    = nullptr;
-  idxT* idx_buf2             = nullptr;
-  {
-    std::vector<size_t> sizes = {sizeof(*counters) * batch_size,
-                                 sizeof(*histograms) * num_buckets * batch_size,
-                                 sizeof(*buf1) * len * batch_size,
-                                 sizeof(*idx_buf1) * len * batch_size,
-                                 sizeof(*buf2) * len * batch_size,
-                                 sizeof(*idx_buf2) * len * batch_size};
-    size_t total_size         = calc_aligned_size(sizes);
-    if (!buf) {
-      buf_size = total_size;
-      return;
-    }
-
-    std::vector<void*> aligned_pointers = calc_aligned_pointers(buf, sizes);
-    counters   = static_cast<Counter<T, idxT>*>(aligned_pointers[0]);
-    histograms = static_cast<idxT*>(aligned_pointers[1]);
-    buf1       = static_cast<T*>(aligned_pointers[2]);
-    idx_buf1   = static_cast<idxT*>(aligned_pointers[3]);
-    buf2       = static_cast<T*>(aligned_pointers[4]);
-    idx_buf2   = static_cast<idxT*>(aligned_pointers[5]);
+  rmm::device_uvector<Counter<T, IdxT>> counters(batch_size, stream);
+  rmm::device_uvector<IdxT> histograms(num_buckets * batch_size, stream);
+  rmm::device_uvector<T> buf1(len * batch_size, stream);
+  rmm::device_uvector<IdxT> idx_buf1(len * batch_size, stream);
+  rmm::device_uvector<T> buf2(len * batch_size, stream);
+  rmm::device_uvector<IdxT> idx_buf2(len * batch_size, stream);
 
-    RAFT_CUDA_TRY(cudaMemsetAsync(
-      buf,
-      0,
-      static_cast<char*>(aligned_pointers[2]) - static_cast<char*>(aligned_pointers[0]),
-      stream));
-  }
+  RAFT_CUDA_TRY(
    cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter<T, IdxT>), stream));
+  RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream));
 
   const T* in_buf        = nullptr;
-  const idxT* in_idx_buf = nullptr;
+  const IdxT* in_idx_buf = nullptr;
   T* out_buf             = nullptr;
-  idxT* out_idx_buf      = nullptr;
+  IdxT* out_idx_buf      = nullptr;
 
   dim3 blocks((len - 1) / (NUM_THREAD * ITEM_PER_THREAD) + 1, batch_size);
 
@@ -625,32 +601,33 @@ void radix_topk(void* buf,
     } else if (pass == 1) {
       in_buf      = in;
       in_idx_buf  = in_idx ? in_idx : nullptr;
-      out_buf     = buf1;
-      out_idx_buf = idx_buf1;
+      out_buf     = buf1.data();
+      out_idx_buf = idx_buf1.data();
     } else if (pass % 2 == 0) {
-      in_buf      = buf1;
-      in_idx_buf  = idx_buf1;
-      out_buf     = buf2;
-      out_idx_buf = idx_buf2;
+      in_buf      = buf1.data();
+      in_idx_buf  = idx_buf1.data();
+      out_buf     = buf2.data();
+      out_idx_buf = idx_buf2.data();
     } else {
-      in_buf      = buf2;
-      in_idx_buf  = idx_buf2;
-      out_buf     = buf1;
-      out_idx_buf = idx_buf1;
+      in_buf      = buf2.data();
+      in_idx_buf  = idx_buf2.data();
+      out_buf     = buf1.data();
+      out_idx_buf = idx_buf1.data();
     }
 
-    radix_kernel<T, idxT, BITS_PER_PASS, NUM_THREAD><<<blocks, NUM_THREAD, 0, stream>>>(in_buf,
-                                                                                        in_idx_buf,
-                                                                                        out_buf,
-                                                                                        out_idx_buf,
-                                                                                        out,
-                                                                                        out_idx,
-                                                                                        counters,
-                                                                                        histograms,
-                                                                                        len,
-                                                                                        k,
-                                                                                        greater,
-                                                                                        pass);
+    radix_kernel<T, IdxT, BITS_PER_PASS, NUM_THREAD>
+      <<<blocks, NUM_THREAD, 0, stream>>>(in_buf,
+                                          in_idx_buf,
+                                          out_buf,
+                                          out_idx_buf,
+                                          out,
+                                          out_idx,
+                                          counters.data(),
+                                          histograms.data(),
+                                          len,
+                                          k,
+                                          greater,
+                                          pass);
   }
 }
diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh
index 7cf0ff566d..47221b92bb 100644
--- a/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/warpsort_topk.cuh
@@ -117,7 +117,7 @@ namespace raft::spatial::knn::detail::ivf_flat {
 
-static constexpr int kMaxCapacity = 1024;
+static constexpr int kMaxCapacity = 512;
 
 namespace {
 
@@ -734,8 +734,7 @@ struct LaunchThreshold {
 };
 
 template
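A note on the caller-facing change above: radix_topk previously followed the common two-phase CUDA pattern (call once with buf == nullptr to learn buf_size, allocate, call again); with the scratch memory now owned by rmm::device_uvector inside the function, one call suffices. The sketch below illustrates this; it is not code from the patch, and the configuration template arguments (11 bits per pass, 512 threads, 32 items per thread) are illustrative placeholders, not values this patch prescribes:

#include <raft/spatial/knn/detail/ivf_flat/radix_topk.cuh>
#include <rmm/cuda_stream_view.hpp>

using raft::spatial::knn::detail::ivf_flat::radix_topk;

void select_smallest(const float* in, size_t batch_size, size_t len, int k,
                     float* out, int* out_idx, rmm::cuda_stream_view stream)
{
  // Before this patch: two calls plus an externally managed scratch buffer.
  //   size_t buf_size = 0;
  //   radix_topk<float, int, 11, 512, 32>(nullptr, buf_size, in, nullptr,
  //     batch_size, len, k, out, out_idx, /*greater=*/false, stream.value());
  //   rmm::device_buffer buf(buf_size, stream);
  //   radix_topk<float, int, 11, 512, 32>(buf.data(), buf_size, in, nullptr,
  //     batch_size, len, k, out, out_idx, /*greater=*/false, stream.value());

  // After this patch: one call; temporaries are allocated inside radix_topk
  // as rmm::device_uvectors and freed (stream-ordered) when it returns.
  radix_topk<float, int, 11, 512, 32>(
    in, nullptr, batch_size, len, k, out, out_idx, /*greater=*/false, stream);
}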
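For readers following the bucket logic: calc_bucket relies on twiddle_in mapping keys to unsigned integers whose order matches the desired value order, then slices out one BITS_PER_PASS-wide digit per pass. twiddle_in itself is outside the hunks shown, so the standalone sketch below uses the textbook float transform (flip all bits of negatives, flip only the sign bit of non-negatives), not necessarily the file's exact implementation:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Monotonic float -> uint32_t mapping: if a < b as floats, then
// twiddle(a, false) < twiddle(b, false) as unsigned integers.
uint32_t twiddle(float x, bool greater)
{
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
  return greater ? ~bits : bits;  // invert again to rank the largest first
}

// Same shape as calc_bucket in the diff: one digit of the twiddled key,
// taken at start_bit, indexes the histogram for the current pass.
int calc_bucket(float x, int start_bit, unsigned mask, bool greater)
{
  return (twiddle(x, greater) >> start_bit) & mask;
}

int main()
{
  // First pass over 8-bit digits of a 32-bit key: start_bit = 24, mask = 0xff.
  // -1.0f lands in a lower bucket than 1.0f, as ascending selection requires.
  printf("%d %d\n",
         calc_bucket(-1.0f, 24, 0xffu, false),
         calc_bucket(1.0f, 24, 0xffu, false));  // prints: 64 191
}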
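Likewise, vectorized_process in the first hunk turns per-element processing into float4-wide loads when sizeof(T) is small enough, handling the misaligned head and the tail remainder separately. A stripped-down CUDA sketch of just the wide-load core (head and tail handling omitted; assumes a 16-byte-aligned input whose length is a multiple of 4):

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void sum_wide(const float* in, int len, float* out)
{
  const int stride = blockDim.x * gridDim.x;
  const int tid    = blockIdx.x * blockDim.x + threadIdx.x;
  // Reinterpret as float4 so each iteration issues one 16-byte load
  // instead of four 4-byte loads (the same trick as vectorized_process).
  const float4* in4 = reinterpret_cast<const float4*>(in);
  float acc = 0.f;
  for (int i = tid; i < len / 4; i += stride) {
    float4 v = in4[i];
    acc += v.x + v.y + v.z + v.w;
  }
  atomicAdd(out, acc);
}

int main()
{
  const int len = 1 << 20;
  std::vector<float> h(len, 1.0f);
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, len * sizeof(float));
  cudaMalloc(&out, sizeof(float));
  cudaMemcpy(in, h.data(), len * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(out, 0, sizeof(float));
  sum_wide<<<128, 256>>>(in, len, out);
  float result = 0.f;
  cudaMemcpy(&result, out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.0f (expected %d)\n", result, len);
  cudaFree(in);
  cudaFree(out);
}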