From 335236c705c0c53da8a4bf6a22835fdbe669f1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Wed, 20 Mar 2024 15:08:19 +0100 Subject: [PATCH] Performance optimization of IVF-flat / select_k (#2221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is a followup to #2169. To enable IVF-flat with k>256 we need an additional select_k invocation which was unexpectedly slow. There are two reasons for that: First problem is the data handed to select_k: The valid data length per row is much smaller than the conservative maximum that could be achieved by probing the N largest probes. Therefore each query row contains roughly ~50% dummy values. This is also the case for IVF-PQ, but did not show up as prominent due to the second reason. The second problem, and also a difference to the IVF-PQ algorithm - is that a 64bit payload data type is used for selectK. The performance of selectK with 64bit index type is significantly slower than with 32bit, especially when many elements are in the same range: ``` Benchmark Time CPU Iterations ----------------------------------------------------------------------------------------------------- SelectK/float/uint32_t/kRadix11bitsExtraPass/1/manual_time 1.68 ms 1.74 ms 413 1357#200000#512 SelectK/float/uint32_t/kRadix11bitsExtraPass/3/manual_time 2.31 ms 2.37 ms 302 1357#200000#512#same-leading-bits SelectK/float/int64_t/kRadix11bitsExtraPass/1/manual_time 5.92 ms 5.98 ms 116 1357#200000#512 SelectK/float/int64_t/kRadix11bitsExtraPass/3/manual_time 83.7 ms 83.8 ms 8 1357#200000#512#same-leading-bits ----------------------------------------------------------------------------------------------------- ``` The data distribution within a IVF-flat benchmark resulted in a select_k time of ~24ms. ### scope: * additional parameter added to select_k to optionally pass individual row lengths for every batch entry. This parameter is utilized by both IVF-Flat and IVF-PQ and results in a ~2x speedup (50 nodes out of 5000) of the final `select_k`. * refactor ivf-flat search to work with 32bit indices by storing positions instead of actual indices. This allows to utilize 32bit index type select_k for ~10x speedup in the final `select_k`. FYI @tfeher @achirkin ### not in scope: * General optimization of select_k: In the current implementation there is no difference in the type of the payload and the actual index type. Especially the type of the histogram has a large effect on performance (due to the atomics). Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2221 --- .../raft/matrix/detail/select_k-ext.cuh | 8 +- .../raft/matrix/detail/select_k-inl.cuh | 16 ++-- .../raft/matrix/detail/select_radix.cuh | 35 ++++++- .../raft/neighbors/detail/ivf_common.cuh | 20 ++-- .../detail/ivf_flat_interleaved_scan-ext.cuh | 4 +- .../detail/ivf_flat_interleaved_scan-inl.cuh | 25 ++--- .../neighbors/detail/ivf_flat_search-inl.cuh | 91 +++++++++++-------- .../raft/neighbors/detail/ivf_pq_search.cuh | 5 +- .../raft/neighbors/detail/refine_device.cuh | 36 +++++++- .../matrix/detail/select_k_double_int64_t.cu | 3 +- .../matrix/detail/select_k_double_uint32_t.cu | 3 +- cpp/src/matrix/detail/select_k_float_int32.cu | 3 +- .../matrix/detail/select_k_float_int64_t.cu | 3 +- .../matrix/detail/select_k_float_uint32_t.cu | 3 +- .../matrix/detail/select_k_half_int64_t.cu | 3 +- .../matrix/detail/select_k_half_uint32_t.cu | 3 +- ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...flat_interleaved_scan_half_half_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- cpp/test/neighbors/ann_cagra.cuh | 8 +- cpp/test/neighbors/ann_utils.cuh | 43 ++++++++- 22 files changed, 221 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 6a7847d8a0..506cbffcb9 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 8f40e6ae00..93d233152b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle, * whether to make sure selected pairs are sorted by value * @param[in] algo * the selection algorithm to use + * @param[in] len_i + * array of size (batch_size) providing lengths for each individual row + * only radix select-k supported */ template void select_k(raft::resources const& handle, @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - true // fused_last_filter - ); - + true, // fused_last_filter + len_i); } else { bool fused_last_filter = algo == SelectAlgo::kRadix11bits; detail::select::radix::select_k(handle, @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - fused_last_filter); + fused_last_filter, + len_i); } if (sorted) { auto offsets = make_device_mdarray( diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 82983b7cd2..36a346fda3 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in, Counter* counters, IdxT* histograms, const IdxT len, + const IdxT* len_i, const IdxT k, const bool select_min, const int pass) @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in, in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + // "current_len > buf_len" means current pass will skip writing buffer if (pass == 0 || current_len > buf_len) { out_buf = nullptr; @@ -829,6 +838,7 @@ void radix_topk(const T* in, IdxT* out_idx, bool select_min, bool fused_last_filter, + const IdxT* len_i, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, @@ -868,6 +878,7 @@ void radix_topk(const T* in, const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; T* chunk_out = out + offset * k; IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; @@ -905,6 +916,7 @@ void radix_topk(const T* in, counters.data(), histograms.data(), len, + chunk_len_i, k, select_min, pass); @@ -1007,6 +1019,7 @@ template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, + const IdxT* len_i, const IdxT k, T* out, IdxT* out_idx, @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_idx_buf = nullptr; } + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + filter_and_histogram_for_one_block(in_buf, in_idx_buf, out_buf, @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in, T* out, IdxT* out_idx, bool select_min, + const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); + int chunk_size = std::min(max_chunk_size, batch_size - offset); + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; kernel<<>>(in + offset * len, in_idx ? (in_idx + offset * len) : nullptr, len, + chunk_len_i, k, out + offset * k, out_idx + offset * k, @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in, * blocks is called. The later case is preferable when leading bits of input data are almost the * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. + * @param len_i + * optional array of size (batch_size) providing lengths for each individual row */ template void select_k(raft::resources const& res, @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res, T* out, IdxT* out_idx, bool select_min, - bool fused_last_filter) + bool fused_last_filter, + const IdxT* len_i) { auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { impl::radix_topk(in, in_idx, @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res, out_idx, select_min, fused_last_filter, + len_i, grid_dim, sm_cnt, stream, diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index ef7ae7c804..df0319e181 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT return ix_min; } -template +template __launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] + postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); const bool valid = chunk_ix < n_probes; neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; + valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; } /** @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL * probed clusters / defined by the `chunk_indices`. * We assume the searched sample sizes (for a single query) fit into `uint32_t`. */ -template -void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] +template +void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to { constexpr int kPNThreads = 256; const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel + postprocess_neighbors_kernel <<>>(neighbors_out, neighbors_in, db_indices, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 7c2d1d2157..140a9f17c8 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) RAFT_EXPLICIT; @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 6fc528e26b..9cd8b70148 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, const uint32_t queries_offset, @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t* chunk_indices, const uint32_t dim, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances) { extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; @@ -719,8 +718,8 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) distances += query_id * k * gridDim.x + blockIdx.x * k; } else { distances += query_id * uint64_t(max_samples); - chunk_indices += (n_probes * query_id); } + chunk_indices += (n_probes * query_id); coarse_index += query_id * n_probes; } @@ -728,7 +727,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - using local_topk_t = block_sort_t; + using local_topk_t = block_sort_t; local_topk_t queue(k); { using align_warp = Pow2; @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 uint32_t sample_offset = 0; - if constexpr (!kManageLocalTopK) { - if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } - assert(list_length == chunk_indices[probe_id] - sample_offset); - assert(sample_offset + list_length <= max_samples); - } + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); constexpr int kUnroll = WarpSize / Veclen; constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; if constexpr (kManageLocalTopK) { - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); + queue.add(val, sample_offset + vec_id); } else { if (vec_id < list_length) distances[sample_offset + vec_id] = val; } @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda, const uint32_t max_samples, const uint32_t* chunk_indices, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), queries_offset + query_offset, @@ -945,8 +940,8 @@ void launch_kernel(Lambda lambda, distances += grid_dim_y * grid_dim_x * k; } else { distances += grid_dim_y * max_samples; - chunk_indices += grid_dim_y * n_probes; } + chunk_indices += grid_dim_y * n_probes; coarse_index += grid_dim_y * n_probes; } } @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index& index, const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 98bdeda42f..441fb76b2f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -67,13 +67,16 @@ void search_impl(raft::resources const& handle, // Optional structures if postprocessing is required // The topk distance value of candidate vectors from each cluster(list) rmm::device_uvector distances_tmp_dev(0, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector indices_tmp_dev(0, stream, search_mr); // Number of samples for each query rmm::device_uvector num_samples(0, stream, search_mr); // Offsets per probe for each query rmm::device_uvector chunk_index(0, stream, search_mr); + // The topk index of candidate vectors from each cluster(list), local index offset + // also we might need additional storage for select_k + rmm::device_uvector indices_tmp_dev(0, stream, search_mr); + rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + size_t float_query_size; if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); @@ -175,23 +178,29 @@ void search_impl(raft::resources const& handle, grid_dim_x = 1; } + num_samples.resize(n_queries, stream); + chunk_index.resize(n_queries_probes, stream); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + auto distances_dev_ptr = distances; - auto indices_dev_ptr = neighbors; + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(neighbors); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + + uint32_t* indices_dev_ptr = nullptr; bool manage_local_topk = is_local_topk_feasible(k); if (!manage_local_topk || grid_dim_x > 1) { - if (!manage_local_topk) { - num_samples.resize(n_queries, stream); - chunk_index.resize(n_queries_probes, stream); - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); - } - auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); distances_tmp_dev.resize(target_size, stream); @@ -199,6 +208,8 @@ void search_impl(raft::resources const& handle, distances_dev_ptr = distances_tmp_dev.data(); indices_dev_ptr = indices_tmp_dev.data(); + } else { + indices_dev_ptr = neighbors_uint32; } ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( @@ -224,31 +235,33 @@ void search_impl(raft::resources const& handle, // Merge topk values from different blocks if (!manage_local_topk || grid_dim_x > 1) { - matrix::detail::select_k(handle, - distances_tmp_dev.data(), - indices_tmp_dev.data(), - n_queries, - manage_local_topk ? (k * grid_dim_x) : max_samples, - k, - distances, - neighbors, - select_min); - - if (!manage_local_topk) { - // post process distances && neighbor IDs - ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); - ivf::detail::postprocess_neighbors(neighbors, - neighbors, - index.inds_ptrs().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - n_queries, - n_probes, - k, - stream); - } + matrix::detail::select_k(handle, + distances_tmp_dev.data(), + indices_tmp_dev.data(), + n_queries, + manage_local_topk ? (k * grid_dim_x) : max_samples, + k, + distances, + neighbors_uint32, + select_min, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); + } + if (!manage_local_topk) { + // post process distances && neighbor IDs + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); } + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); } /** See raft::neighbors::ivf_flat::search docs */ diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d445f909e5..4c5da38092 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -447,7 +447,10 @@ void ivfpq_search_worker(raft::resources const& handle, topK, topk_dists.data(), neighbors_uint32, - true); + true, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); // Postprocessing ivf::detail::postprocess_distances( diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index e76e52657b..bdc29ca121 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -88,6 +88,27 @@ void refine_device(raft::resources const& handle, n_queries, n_candidates); uint32_t grid_dim_x = 1; + + // the neighbor ids will be computed in uint32_t as offset + rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); + // Offsets per probe for each query [n_queries] as n_probes = 1 + rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); + + // we know that each cluster has exactly n_candidates entries + thrust::fill(resource::get_thrust_policy(handle), + chunk_index.data(), + chunk_index.data() + n_queries, + uint32_t(n_candidates)); + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(indices.data_handle()); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), + resource::get_cuda_stream(handle)); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, @@ -100,13 +121,24 @@ void refine_device(raft::resources const& handle, 1, k, 0, - nullptr, + chunk_index.data(), raft::distance::is_min_close(metric), raft::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), + neighbors_uint32, distances.data_handle(), grid_dim_x, resource::get_cuda_stream(handle)); + + // postprocessing -- neighbors from position to actual id + ivf::detail::postprocess_neighbors(indices.data_handle(), + neighbors_uint32, + refinement_index.inds_ptrs().data_handle(), + fake_coarse_idx.data(), + chunk_index.data(), + n_queries, + 1, + k, + resource::get_cuda_stream(handle)); } } // namespace raft::neighbors::detail diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index e32b4ef6f0..bf234aacbf 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 21c954ca46..7f0511a76a 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -29,7 +29,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 7f163a0b0d..e68b1e32df 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 87b6525356..5aa40d8c9d 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index e698f811d8..9aba147edf 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index 0eee20b1fa..bc513e4aeb 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index f4e6bae21f..e46c7d46bb 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index def33e493e..5ac820e0dd 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu index e96600ee02..4d847cdeb1 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu @@ -35,7 +35,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 13c9d2e283..8a0e8f0118 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 51f02343fc..7cad992e2b 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a111de0762..7278f71a24 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -549,6 +549,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { EXPECT_FALSE(unacceptable_node); double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -556,7 +557,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), @@ -668,6 +670,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -675,7 +678,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index afd083d512..6be2ac7fc7 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -35,6 +35,7 @@ #include #include +#include namespace raft::neighbors { @@ -153,13 +154,40 @@ auto calc_recall(const std::vector& expected_idx, static_cast(match_count) / static_cast(total_count), match_count, total_count); } +/** check uniqueness of indices + */ +template +auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +{ + size_t max_count; + std::set unique_indices; + for (size_t i = 0; i < rows; ++i) { + unique_indices.clear(); + max_count = 0; + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + if (act_idx == std::numeric_limits::max()) { + max_count++; + } else if (unique_indices.find(act_idx) == unique_indices.end()) { + unique_indices.insert(act_idx); + } else { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } + } + } + return testing::AssertionSuccess(); +} + template auto eval_recall(const std::vector& expected_idx, const std::vector& actual_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, rows, cols); @@ -176,7 +204,10 @@ auto eval_recall(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } /** Overload of calc_recall to account for distances @@ -224,7 +255,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -241,7 +273,10 @@ auto eval_neighbours(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } template