From 8d742fdb96cae5dc24444413a43fb033ae4b882c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 23 Nov 2022 17:00:15 +0100 Subject: [PATCH] Add ANN refinement method (#1038) This PR implements refinement for approximate nearest neighbor search. Refinement is a post processing step for ANN search, it follows an ANN search that returned `k0` neighbor candidates, and select `k` out of these candidates. The selection by calculating exact distances from the original dataset. Refinement can increase accuracy. It is useful for ANN methods that quantize the dataset and therefore loose accuracy during distance calculation (e.g. IVF-PQ). Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1038 --- cpp/bench/CMakeLists.txt | 1 + cpp/bench/neighbors/refine.cu | 122 +++++++++ cpp/include/raft/neighbors/detail/refine.cuh | 232 ++++++++++++++++++ cpp/include/raft/neighbors/refine.cuh | 98 ++++++++ .../spatial/knn/detail/ivf_flat_build.cuh | 84 ++++++- .../spatial/knn/detail/ivf_flat_search.cuh | 45 ++-- cpp/test/CMakeLists.txt | 1 + cpp/test/neighbors/ann_utils.cuh | 3 + cpp/test/neighbors/refine.cu | 129 ++++++++++ cpp/test/neighbors/refine_helper.cuh | 140 +++++++++++ 10 files changed, 837 insertions(+), 18 deletions(-) create mode 100644 cpp/bench/neighbors/refine.cu create mode 100644 cpp/include/raft/neighbors/detail/refine.cuh create mode 100644 cpp/include/raft/neighbors/refine.cuh create mode 100644 cpp/test/neighbors/refine.cu create mode 100644 cpp/test/neighbors/refine_helper.cuh diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 1d729d728b..4e6b6ceb40 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -124,6 +124,7 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_pq_float_uint32_t.cu bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu + bench/neighbors/refine.cu bench/neighbors/selection.cu bench/main.cpp OPTIONAL diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu new file mode 100644 index 0000000000..a038905ace --- /dev/null +++ b/cpp/bench/neighbors/refine.cu @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + +#if defined RAFT_NN_COMPILED +#include +#endif + +#include +#include +#include + +#include "../../test/neighbors/refine_helper.cuh" + +#include +#include + +using namespace raft::neighbors::detail; + +namespace raft::bench::neighbors { + +template +inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::ostream& +{ + os << p.n_rows << "#" << p.dim << "#" << p.n_queries << "#" << p.k0 << "#" << p.k << "#" + << (p.host_data ? "host" : "device"); + return os; +} + +RefineInputs p; + +template +class RefineAnn : public fixture { + public: + RefineAnn(RefineInputs p) : data(handle_, p) {} + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << data.p; + state.SetLabel(label_stream.str()); + + auto old_mr = rmm::mr::get_current_device_resource(); + rmm::mr::pool_memory_resource pool_mr(old_mr); + rmm::mr::set_current_device_resource(&pool_mr); + + if (data.p.host_data) { + loop_on_state(state, [this]() { + raft::neighbors::refine(handle_, + data.dataset_host.view(), + data.queries_host.view(), + data.candidates_host.view(), + data.refined_indices_host.view(), + data.refined_distances_host.view(), + data.p.metric); + }); + } else { + loop_on_state(state, [&]() { + raft::neighbors::refine(handle_, + data.dataset.view(), + data.queries.view(), + data.candidates.view(), + data.refined_indices.view(), + data.refined_distances.view(), + data.p.metric); + }); + } + rmm::mr::set_current_device_resource(old_mr); + } + + private: + raft::handle_t handle_; + RefineHelper data; +}; + +std::vector> getInputs() +{ + std::vector> out; + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; + for (bool host_data : {true, false}) { + for (int64_t n_queries : {1000, 10000}) { + for (int64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + } + } + } + return out; +} + +using refine_float_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); + +using refine_uint8_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +} // namespace raft::bench::neighbors diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh new file mode 100644 index 0000000000..c838af85d6 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::neighbors::detail { + +/** Checks whether the input data extents are compatible. */ +template +void check_input(extents_t dataset, + extents_t queries, + extents_t candidates, + extents_t indices, + extents_t distances, + distance::DistanceType metric) +{ + auto n_queries = queries.extent(0); + auto k = distances.extent(1); + + RAFT_EXPECTS(k <= raft::spatial::knn::detail::topk::kMaxCapacity, + "k must be lest than topk::kMaxCapacity (%d).", + raft::spatial::knn::detail::topk::kMaxCapacity); + + RAFT_EXPECTS(indices.extent(0) == n_queries && distances.extent(0) == n_queries && + candidates.extent(0) == n_queries, + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + + RAFT_EXPECTS(indices.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), + "Number of columns must be equal for dataset and queries"); + + RAFT_EXPECTS(candidates.extent(1) >= k, + "Number of neighbor candidates must not be smaller than k (%d vs %d)", + static_cast(candidates.extent(1)), + static_cast(k)); +} + +/** + * See raft::neighbors::refine for docs. + */ +template +void refine_device(raft::handle_t const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + matrix_idx n_candidates = neighbor_candidates.extent(1); + matrix_idx n_queries = queries.extent(0); + matrix_idx dim = dataset.extent(1); + uint32_t k = static_cast(indices.extent(1)); + + common::nvtx::range fun_scope( + "neighbors::refine(%zu, %u)", size_t(n_queries), uint32_t(n_candidates)); + + check_input(dataset.extents(), + queries.extents(), + neighbor_candidates.extents(), + indices.extents(), + distances.extents(), + metric); + + // The refinement search can be mapped to an IVF flat search: + // - We consider that the candidate vectors form a cluster, separately for each query. + // - In other words, the n_queries * n_candidates vectors form n_queries clusters, each with + // n_candidates elements. + // - We consider that the coarse level search is already performed and assigned a single cluster + // to search for each query (the cluster formed from the corresponding candidates). + // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. + rmm::device_uvector fake_coarse_idx(n_queries, handle.get_stream()); + + thrust::sequence( + handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries); + + raft::neighbors::ivf_flat::index refinement_index( + handle, metric, n_queries, false, dim); + + raft::spatial::knn::ivf_flat::detail::fill_refinement_index(handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + n_queries, + n_candidates); + + uint32_t grid_dim_x = 1; + raft::spatial::knn::ivf_flat::detail::ivfflat_interleaved_scan< + data_t, + typename raft::spatial::knn::detail::utils::config::value_t, + idx_t>(refinement_index, + queries.data_handle(), + fake_coarse_idx.data(), + static_cast(n_queries), + refinement_index.metric(), + 1, + k, + raft::spatial::knn::ivf_flat::detail::is_min_close(metric), + indices.data_handle(), + distances.data_handle(), + grid_dim_x, + handle.get_stream()); +} + +/** Helper structure for naive CPU implementation of refine. */ +typedef struct { + uint64_t id; + float distance; +} struct_for_refinement; + +int _postprocessing_qsort_compare(const void* v1, const void* v2) +{ + // sort in ascending order + if (((struct_for_refinement*)v1)->distance > ((struct_for_refinement*)v2)->distance) { + return 1; + } else if (((struct_for_refinement*)v1)->distance < ((struct_for_refinement*)v2)->distance) { + return -1; + } else { + return 0; + } +} + +/** + * Naive CPU implementation of refine operation + * + * All pointers are expected to be accessible on the host. + */ +template +void refine_host(raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + check_input(dataset.extents(), + queries.extents(), + neighbor_candidates.extents(), + indices.extents(), + distances.extents(), + metric); + + switch (metric) { + case raft::distance::DistanceType::L2Expanded: break; + case raft::distance::DistanceType::InnerProduct: break; + default: throw raft::logic_error("Unsopported metric"); + } + + size_t numDataset = dataset.extent(0); + size_t numQueries = queries.extent(0); + size_t dimDataset = dataset.extent(1); + const data_t* dataset_ptr = dataset.data_handle(); + const data_t* queries_ptr = queries.data_handle(); + const idx_t* neighbors = neighbor_candidates.data_handle(); + idx_t topK = neighbor_candidates.extent(1); + idx_t refinedTopK = indices.extent(1); + idx_t* refinedNeighbors = indices.data_handle(); + distance_t* refinedDistances = distances.data_handle(); + + common::nvtx::range fun_scope( + "neighbors::refine_host(%zu, %u)", size_t(numQueries), uint32_t(topK)); + +#pragma omp parallel + { + struct_for_refinement* sfr = + (struct_for_refinement*)malloc(sizeof(struct_for_refinement) * topK); + for (size_t i = omp_get_thread_num(); i < numQueries; i += omp_get_num_threads()) { + // compute distance with original dataset vectors + const data_t* cur_query = queries_ptr + ((uint64_t)dimDataset * i); + for (size_t j = 0; j < (size_t)topK; j++) { + idx_t id = neighbors[j + (topK * i)]; + const data_t* cur_dataset = dataset_ptr + ((uint64_t)dimDataset * id); + float distance = 0.0; + for (size_t k = 0; k < (size_t)dimDataset; k++) { + float val_q = (float)(cur_query[k]); + float val_d = (float)(cur_dataset[k]); + if (metric == raft::distance::DistanceType::InnerProduct) { + distance += -val_q * val_d; // Negate because we sort in scending order. + } else { + distance += (val_q - val_d) * (val_q - val_d); + } + } + sfr[j].id = id; + sfr[j].distance = distance; + } + + qsort(sfr, topK, sizeof(struct_for_refinement), _postprocessing_qsort_compare); + + for (size_t j = 0; j < (size_t)refinedTopK; j++) { + refinedNeighbors[j + (refinedTopK * i)] = sfr[j].id; + if (refinedDistances == NULL) continue; + if (metric == raft::distance::DistanceType::InnerProduct) { + refinedDistances[j + (refinedTopK * i)] = -sfr[j].distance; + } else { + refinedDistances[j + (refinedTopK * i)] = -sfr[j].distance; + } + } + } + free(sfr); + } +} + +} // namespace raft::neighbors::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/refine.cuh b/cpp/include/raft/neighbors/refine.cuh new file mode 100644 index 0000000000..7b6708f18c --- /dev/null +++ b/cpp/include/raft/neighbors/refine.cuh @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors { + +/** + * @brief Refine nearest neighbor search. + * + * Refinement is an operation that follows an approximate NN search. The approximate search has + * already selected n_candidates neighbor candidates for each query. We narrow it down to k + * neighbors. For each query, we calculate the exact distance between the query and its + * n_candidates neighbor candidate, and select the k nearest ones. + * + * The k nearest neighbors and distances are returned. + * + * Example usage + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search m = 4 * k nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, N, 4 * k, neighbor_candidates, + * out_dists_tmp); + * // refine it to the k nearest one + * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, + * index.metric()); + * @endcode + * + * + * @param[in] handle the raft handle + * @param[in] dataset device matrix that stores the dataset [n_rows, dims] + * @param[in] queries device matrix of the queries [n_queris, dims] + * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where + * n_candidates >= k + * @param[out] indices device matrix that stores the refined indices [n_queries, k] + * @param[out] distances device matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::handle_t const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_device(handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +/** Same as above, but all input and out data is in host memory. + * @param[in] handle the raft handle + * @param[in] dataset host matrix that stores the dataset [n_rows, dims] + * @param[in] queries host matrix of the queries [n_queris, dims] + * @param[in] neighbor_candidates host matrix with indices of candidate vectors [n_queries, + * n_candidates], where n_candidates >= k + * @param[out] indices host matrix that stores the refined indices [n_queries, k] + * @param[out] distances host matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::handle_t const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_host(dataset, queries, neighbor_candidates, indices, distances, metric); +} +} // namespace raft::neighbors diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index e9af97b547..14c4dd85f1 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -50,6 +51,8 @@ using namespace raft::spatial::knn::detail; // NOLINT * @tparam T element type. * @tparam IdxT type of the indices in the source source_vecs * @tparam LabelT label type + * @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise + * we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1. * * @param[in] labels device pointer to the cluster ids for each row [n_rows] * @param[in] list_offsets device pointer to the cluster offsets in the output (index) [n_lists] @@ -64,7 +67,7 @@ using namespace raft::spatial::knn::detail; // NOLINT * @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`. * */ -template +template __global__ void build_index_kernel(const LabelT* labels, const IdxT* list_offsets, const T* source_vecs, @@ -95,8 +98,11 @@ __global__ void build_index_kernel(const LabelT* labels, list_data += (list_offset + group_offset) * dim; // Point to the source vector - source_vecs += i * dim; - + if constexpr (gather_src) { + source_vecs += source_ixs[i] * dim; + } else { + source_vecs += i * dim; + } // Interleave dimensions of the source vector while recording it. // NB: such `veclen` is selected, that `dim % veclen == 0` for (uint32_t l = 0; l < dim; l += veclen) { @@ -299,4 +305,76 @@ inline auto build( } } +/** + * Build an index that can be used in refinement operation. + * + * See raft::neighbors::refine for details on the refinement operation. + * + * The returned index cannot be used for a regular ivf_flat::search. The index misses information + * about coarse clusters. Instead, the neighbor candidates are assumed to form clusters, one for + * each query. The candidate vectors are gathered into the index dataset, that can be later used + * in ivfflat_interleaved_scan. + * + * @param[in] handle the raft handle + * @param[inout] refinement_index + * @param[in] dataset device pointer to dataset vectors, size [n_rows, dim]. Note that n_rows is + * not known to this function, but each candidate_idx has to be smaller than n_rows. + * @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates] + * @param[in] n_candidates of neighbor_candidates + */ +template +inline void fill_refinement_index(const handle_t& handle, + index* refinement_index, + const T* dataset, + const IdxT* candidate_idx, + IdxT n_queries, + uint32_t n_candidates) +{ + using LabelT = uint32_t; + + auto stream = handle.get_stream(); + uint32_t n_lists = n_queries; + common::nvtx::range fun_scope( + "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); + + rmm::device_uvector new_labels(n_queries * n_candidates, stream); + linalg::writeOnlyUnaryOp( + new_labels.data(), + n_queries * n_candidates, + [n_candidates] __device__(LabelT * out, uint32_t i) { *out = i / n_candidates; }, + stream); + + auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); + auto list_offsets_ptr = refinement_index->list_offsets().data_handle(); + // We do not fill centers and center norms, since we will not run coarse search. + + // Calculate new offsets + uint32_t n_roundup = Pow2::roundUp(n_candidates); + linalg::writeOnlyUnaryOp( + refinement_index->list_offsets().data_handle(), + refinement_index->list_offsets().size(), + [n_roundup] __device__(IdxT * out, uint32_t i) { *out = i * n_roundup; }, + stream); + + IdxT index_size = n_roundup * n_lists; + refinement_index->allocate( + handle, index_size, refinement_index->metric() == raft::distance::DistanceType::L2Expanded); + + RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); + + const dim3 block_dim(256); + const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); + build_index_kernel + <<>>(new_labels.data(), + list_offsets_ptr, + dataset, + candidate_idx, + refinement_index->data().data_handle(), + refinement_index->indices().data_handle(), + list_sizes_ptr, + n_queries * n_candidates, + refinement_index->dim(), + refinement_index->veclen()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} } // namespace raft::spatial::knn::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index f6b9e62008..94f4dc96c6 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1212,6 +1212,26 @@ void search_impl(const handle_t& handle, } } +/** + * Whether minimal distance corresponds to similar elements (using the given metric). + */ +inline bool is_min_close(distance::DistanceType metric) +{ + bool select_min; + switch (metric) { + case raft::distance::DistanceType::InnerProduct: + case raft::distance::DistanceType::CosineExpanded: + case raft::distance::DistanceType::CorrelationExpanded: + // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger + // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 + // {perform_k_selection}) + select_min = false; + break; + default: select_min = true; + } + return select_min; +} + /** See raft::spatial::knn::ivf_flat::search docs */ template inline void search(const handle_t& handle, @@ -1231,27 +1251,22 @@ inline void search(const handle_t& handle, "n_probes (number of clusters to probe in the search) must be positive."); auto n_probes = std::min(params.n_probes, index.n_lists()); - bool select_min; - switch (index.metric()) { - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: - // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger - // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 - // {perform_k_selection}) - select_min = false; - break; - default: select_min = true; - } - auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); if (pool_guard) { RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } - return search_impl( - handle, index, queries, n_queries, k, n_probes, select_min, neighbors, distances, mr); + return search_impl(handle, + index, + queries, + n_queries, + k, + n_probes, + is_min_close(index.metric()), + neighbors, + distances, + mr); } } // namespace raft::spatial::knn::ivf_flat::detail diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3e8f944a5b..dae0f6f6b1 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -245,6 +245,7 @@ if(BUILD_TESTS) test/neighbors/ball_cover.cu test/neighbors/epsilon_neighborhood.cu test/neighbors/faiss_mr.cu + test/neighbors/refine.cu test/neighbors/selection.cu OPTIONAL DIST diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 3e3735719f..07ef410d36 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -25,6 +25,9 @@ #include #include +#include "../test_utils.h" +#include + namespace raft::neighbors { struct print_dtype { diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu new file mode 100644 index 0000000000..e1700e44b3 --- /dev/null +++ b/cpp/test/neighbors/refine.cu @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "ann_utils.cuh" + +#include "refine_helper.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#if defined RAFT_NN_COMPILED +#include +#endif + +#include + +namespace raft::neighbors { + +template +class RefineTest : public ::testing::TestWithParam> { + public: + RefineTest() + : stream_(handle_.get_stream()), + data(handle_, ::testing::TestWithParam>::GetParam()) + { + } + + protected: + public: // tamas remove + void testRefine() + { + std::vector indices(data.p.n_queries * data.p.k); + std::vector distances(data.p.n_queries * data.p.k); + + if (data.p.host_data) { + raft::neighbors::refine(handle_, + data.dataset_host.view(), + data.queries_host.view(), + data.candidates_host.view(), + data.refined_indices_host.view(), + data.refined_distances_host.view(), + data.p.metric); + raft::copy(indices.data(), + data.refined_indices_host.data_handle(), + data.refined_indices_host.size(), + stream_); + raft::copy(distances.data(), + data.refined_distances_host.data_handle(), + data.refined_distances_host.size(), + stream_); + + } else { + raft::neighbors::refine(handle_, + data.dataset.view(), + data.queries.view(), + data.candidates.view(), + data.refined_indices.view(), + data.refined_distances.view(), + data.p.metric); + update_host(distances.data(), + data.refined_distances.data_handle(), + data.refined_distances.size(), + stream_); + update_host( + indices.data(), data.refined_indices.data_handle(), data.refined_indices.size(), stream_); + } + handle_.sync_stream(stream_); + + double min_recall = 1; + + ASSERT_TRUE(raft::neighbors::eval_neighbours(data.true_refined_indices_host, + indices, + data.true_refined_distances_host, + distances, + data.p.n_queries, + data.p.k, + 0.001, + min_recall)); + } + + public: + raft::handle_t handle_; + rmm::cuda_stream_view stream_; + detail::RefineHelper data; +}; + +const std::vector> inputs = + raft::util::itertools::product>( + {137}, + {1000}, + {16}, + {1, 10, 33}, + {33}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, + {false, true}); + +typedef RefineTest RefineTestF; +TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } + +INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); + +typedef RefineTest RefineTestF_uint8; +TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } +INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine_helper.cuh b/cpp/test/neighbors/refine_helper.cuh new file mode 100644 index 0000000000..3c69a8f5b7 --- /dev/null +++ b/cpp/test/neighbors/refine_helper.cuh @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "ann_utils.cuh" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::neighbors::detail { + +template +struct RefineInputs { + IdxT n_queries; + IdxT n_rows; + IdxT dim; + IdxT k; // after refinement + IdxT k0; // initial k before refinement (k0 >= k). + raft::distance::DistanceType metric; + bool host_data; +}; + +/** Helper class to allocate arrays and generate input data for refinement test and benchmark. */ +template +class RefineHelper { + public: + RefineHelper(const raft::handle_t& handle, RefineInputs params) + : handle_(handle), stream_(handle.get_stream()), p(params) + { + raft::random::Rng r(1234ULL); + + dataset = raft::make_device_matrix(handle_, p.n_rows, p.dim); + queries = raft::make_device_matrix(handle_, p.n_queries, p.dim); + if constexpr (std::is_same{}) { + r.uniform(dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0), stream_); + r.uniform(queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0), stream_); + } else { + r.uniformInt(dataset.data_handle(), dataset.size(), DataT(1), DataT(20), stream_); + r.uniformInt(queries.data_handle(), queries.size(), DataT(1), DataT(20), stream_); + } + + refined_distances = raft::make_device_matrix(handle_, p.n_queries, p.k); + refined_indices = raft::make_device_matrix(handle_, p.n_queries, p.k); + + // Generate candidate vectors + { + candidates = raft::make_device_matrix(handle_, p.n_queries, p.k0); + rmm::device_uvector distances_tmp(p.n_queries * p.k0, stream_); + raft::neighbors::naiveBfKnn(distances_tmp.data(), + candidates.data_handle(), + queries.data_handle(), + dataset.data_handle(), + p.n_queries, + p.n_rows, + p.dim, + p.k0, + p.metric, + stream_); + handle_.sync_stream(stream_); + } + + if (p.host_data) { + dataset_host = raft::make_host_matrix(p.n_rows, p.dim); + queries_host = raft::make_host_matrix(p.n_queries, p.dim); + candidates_host = raft::make_host_matrix(p.n_queries, p.k0); + + raft::copy(dataset_host.data_handle(), dataset.data_handle(), dataset.size(), stream_); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries.size(), stream_); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates.size(), stream_); + + refined_distances_host = raft::make_host_matrix(p.n_queries, p.k); + refined_indices_host = raft::make_host_matrix(p.n_queries, p.k); + handle_.sync_stream(stream_); + } + + // Generate ground thruth for testing. + { + rmm::device_uvector distances_dev(p.n_queries * p.k, stream_); + rmm::device_uvector indices_dev(p.n_queries * p.k, stream_); + raft::neighbors::naiveBfKnn(distances_dev.data(), + indices_dev.data(), + queries.data_handle(), + dataset.data_handle(), + p.n_queries, + p.n_rows, + p.dim, + p.k, + p.metric, + stream_); + true_refined_distances_host.resize(p.n_queries * p.k); + true_refined_indices_host.resize(p.n_queries * p.k); + raft::copy(true_refined_indices_host.data(), indices_dev.data(), indices_dev.size(), stream_); + raft::copy( + true_refined_distances_host.data(), distances_dev.data(), distances_dev.size(), stream_); + handle_.sync_stream(stream_); + } + } + + public: + RefineInputs p; + const raft::handle_t& handle_; + rmm::cuda_stream_view stream_; + + raft::device_matrix dataset; + raft::device_matrix queries; + raft::device_matrix candidates; // Neighbor candidate indices + raft::device_matrix refined_indices; + raft::device_matrix refined_distances; + + raft::host_matrix dataset_host; + raft::host_matrix queries_host; + raft::host_matrix candidates_host; + raft::host_matrix refined_indices_host; + raft::host_matrix refined_distances_host; + + std::vector true_refined_indices_host; + std::vector true_refined_distances_host; +}; +} // namespace raft::neighbors::detail \ No newline at end of file