diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 1d729d728b..4e6b6ceb40 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -124,6 +124,7 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_pq_float_uint32_t.cu bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu + bench/neighbors/refine.cu bench/neighbors/selection.cu bench/main.cpp OPTIONAL diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu new file mode 100644 index 0000000000..a038905ace --- /dev/null +++ b/cpp/bench/neighbors/refine.cu @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + +#if defined RAFT_NN_COMPILED +#include +#endif + +#include +#include +#include + +#include "../../test/neighbors/refine_helper.cuh" + +#include +#include + +using namespace raft::neighbors::detail; + +namespace raft::bench::neighbors { + +template +inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::ostream& +{ /* Streams n_rows#dim#n_queries#k0#k#(host|device); run_benchmark uses this text as the benchmark label. */ + os << p.n_rows << "#" << p.dim << "#" << p.n_queries << "#" << p.k0 << "#" << p.k << "#" + << (p.host_data ? 
"host" : "device"); + return os; +} + +RefineInputs p; /* NOTE(review): this namespace-scope object is not referenced anywhere in the visible code - confirm it is needed */ + +template +class RefineAnn : public fixture { /* Benchmark fixture: times raft::neighbors::refine on the inputs held by RefineHelper; data.p.host_data selects the host or the device overload. */ + public: + RefineAnn(RefineInputs p) : data(handle_, p) {} + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << data.p; + state.SetLabel(label_stream.str()); + + auto old_mr = rmm::mr::get_current_device_resource(); + rmm::mr::pool_memory_resource pool_mr(old_mr); + rmm::mr::set_current_device_resource(&pool_mr); /* a pool built on the previous resource becomes current; the previous resource is restored at the end of this function */ + + if (data.p.host_data) { + loop_on_state(state, [this]() { + raft::neighbors::refine(handle_, + data.dataset_host.view(), + data.queries_host.view(), + data.candidates_host.view(), + data.refined_indices_host.view(), + data.refined_distances_host.view(), + data.p.metric); + }); + } else { + loop_on_state(state, [&]() { + raft::neighbors::refine(handle_, + data.dataset.view(), + data.queries.view(), + data.candidates.view(), + data.refined_indices.view(), + data.refined_distances.view(), + data.p.metric); + }); + } + rmm::mr::set_current_device_resource(old_mr); + } + + private: + raft::handle_t handle_; + RefineHelper data; +}; + +std::vector> getInputs() +{ /* One case per {host, device} x n_queries {1000, 10000} x dim {128, 512} x (k, k0) in {(32, 128), (10, 40)}; n_rows is fixed at 2000000 and the metric at L2Expanded. 
*/ + std::vector> out; + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; + for (bool host_data : {true, false}) { + for (int64_t n_queries : {1000, 10000}) { + for (int64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + } + } + } + return out; +} + +using refine_float_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); + +using refine_uint8_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +} // namespace raft::bench::neighbors diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh new file mode 100644 index 0000000000..c838af85d6 --- /dev/null +++ 
b/cpp/include/raft/neighbors/detail/refine.cuh @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::neighbors::detail { + +/** Checks whether the input data extents are compatible. */ +template +void check_input(extents_t dataset, + extents_t queries, + extents_t candidates, + extents_t indices, + extents_t distances, + distance::DistanceType metric) +{ + auto n_queries = queries.extent(0); + auto k = distances.extent(1); + + RAFT_EXPECTS(k <= raft::spatial::knn::detail::topk::kMaxCapacity, + "k must be lest than topk::kMaxCapacity (%d).", + raft::spatial::knn::detail::topk::kMaxCapacity); + + RAFT_EXPECTS(indices.extent(0) == n_queries && distances.extent(0) == n_queries && + candidates.extent(0) == n_queries, + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + + RAFT_EXPECTS(indices.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), + "Number of columns must be equal for dataset and queries"); + + RAFT_EXPECTS(candidates.extent(1) >= k, + "Number of neighbor candidates must not be smaller than k (%d vs %d)", + static_cast(candidates.extent(1)), + static_cast(k)); +} + +/** + * See 
raft::neighbors::refine for docs. + */ +template +void refine_device(raft::handle_t const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + matrix_idx n_candidates = neighbor_candidates.extent(1); + matrix_idx n_queries = queries.extent(0); + matrix_idx dim = dataset.extent(1); + uint32_t k = static_cast(indices.extent(1)); + + common::nvtx::range fun_scope( + "neighbors::refine(%zu, %u)", size_t(n_queries), uint32_t(n_candidates)); + + check_input(dataset.extents(), + queries.extents(), + neighbor_candidates.extents(), + indices.extents(), + distances.extents(), + metric); + + // The refinement search can be mapped to an IVF flat search: + // - We consider that the candidate vectors form a cluster, separately for each query. + // - In other words, the n_queries * n_candidates vectors form n_queries clusters, each with + // n_candidates elements. + // - We consider that the coarse level search is already performed and assigned a single cluster + // to search for each query (the cluster formed from the corresponding candidates). + // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. 
+ rmm::device_uvector fake_coarse_idx(n_queries, handle.get_stream()); + + thrust::sequence( + handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries); + + raft::neighbors::ivf_flat::index refinement_index( + handle, metric, n_queries, false, dim); + + raft::spatial::knn::ivf_flat::detail::fill_refinement_index(handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + n_queries, + n_candidates); + + uint32_t grid_dim_x = 1; + raft::spatial::knn::ivf_flat::detail::ivfflat_interleaved_scan< + data_t, + typename raft::spatial::knn::detail::utils::config::value_t, + idx_t>(refinement_index, + queries.data_handle(), + fake_coarse_idx.data(), + static_cast(n_queries), + refinement_index.metric(), + 1, + k, + raft::spatial::knn::ivf_flat::detail::is_min_close(metric), + indices.data_handle(), + distances.data_handle(), + grid_dim_x, + handle.get_stream()); +} + +/** Helper structure for naive CPU implementation of refine. */ +typedef struct { + uint64_t id; + float distance; +} struct_for_refinement; + +inline int _postprocessing_qsort_compare(const void* v1, const void* v2) /* qsort comparator over struct_for_refinement entries; declared inline because this function is defined in a header and would otherwise be defined once per including translation unit */ +{ + // sort in ascending order + if (((struct_for_refinement*)v1)->distance > ((struct_for_refinement*)v2)->distance) { + return 1; + } else if (((struct_for_refinement*)v1)->distance < ((struct_for_refinement*)v2)->distance) { + return -1; + } else { + return 0; + } +} + +/** + * Naive CPU implementation of refine operation + * + * All pointers are expected to be accessible on the host. 
+ */ +template +void refine_host(raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ /* For every query: computes the exact distance to each of its candidates, sorts the candidates by that value, and writes the first indices.extent(1) ids (and distances, unless the distances data handle is null). */ + check_input(dataset.extents(), + queries.extents(), + neighbor_candidates.extents(), + indices.extents(), + distances.extents(), + metric); + + switch (metric) { + case raft::distance::DistanceType::L2Expanded: break; + case raft::distance::DistanceType::InnerProduct: break; + default: throw raft::logic_error("Unsopported metric"); /* NOTE(review): this also rejects L2Unexpanded, the default value of the metric parameter - confirm whether that is intended */ + } + + size_t numDataset = dataset.extent(0); + size_t numQueries = queries.extent(0); + size_t dimDataset = dataset.extent(1); + const data_t* dataset_ptr = dataset.data_handle(); + const data_t* queries_ptr = queries.data_handle(); + const idx_t* neighbors = neighbor_candidates.data_handle(); + idx_t topK = neighbor_candidates.extent(1); + idx_t refinedTopK = indices.extent(1); + idx_t* refinedNeighbors = indices.data_handle(); + distance_t* refinedDistances = distances.data_handle(); + + common::nvtx::range fun_scope( + "neighbors::refine_host(%zu, %u)", size_t(numQueries), uint32_t(topK)); + +#pragma omp parallel + { + struct_for_refinement* sfr = + (struct_for_refinement*)malloc(sizeof(struct_for_refinement) * topK); + for (size_t i = omp_get_thread_num(); i < numQueries; i += omp_get_num_threads()) { + // compute distance with original dataset vectors + const data_t* cur_query = queries_ptr + ((uint64_t)dimDataset * i); + for (size_t j = 0; j < (size_t)topK; j++) { + idx_t id = neighbors[j + (topK * i)]; + 
const data_t* cur_dataset = dataset_ptr + ((uint64_t)dimDataset * id); + float distance = 0.0; + for (size_t k = 0; k < (size_t)dimDataset; k++) { + float val_q = (float)(cur_query[k]); + float val_d = (float)(cur_dataset[k]); + if (metric == raft::distance::DistanceType::InnerProduct) { + distance += -val_q * val_d; // Negate because we sort 
in ascending order. + } else { + distance += (val_q - val_d) * (val_q - val_d); + } + } + sfr[j].id = id; + sfr[j].distance = distance; + } + + qsort(sfr, topK, sizeof(struct_for_refinement), _postprocessing_qsort_compare); + + for (size_t j = 0; j < (size_t)refinedTopK; j++) { + refinedNeighbors[j + (refinedTopK * i)] = sfr[j].id; + if (refinedDistances == NULL) continue; + if (metric == raft::distance::DistanceType::InnerProduct) { + refinedDistances[j + (refinedTopK * i)] = -sfr[j].distance; /* undo the negation applied before sorting */ + } else { + refinedDistances[j + (refinedTopK * i)] = sfr[j].distance; /* the squared L2 sum was accumulated without negation, so it is reported as computed */ + } + } + } + free(sfr); + } +} + +} // namespace raft::neighbors::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/refine.cuh b/cpp/include/raft/neighbors/refine.cuh new file mode 100644 index 0000000000..7b6708f18c --- /dev/null +++ b/cpp/include/raft/neighbors/refine.cuh @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors { + +/** + * @brief Refine nearest neighbor search. + * + * Refinement is an operation that follows an approximate NN search. The approximate search has + * already selected n_candidates neighbor candidates for each query. 
We narrow it down to k + * neighbors. 
For each query, we calculate the exact distance between the query and its + * n_candidates neighbor candidate, and select the k nearest ones. + * + * The k nearest neighbors and distances are returned. + * + * Example usage + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search m = 4 * k nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, N, 4 * k, neighbor_candidates, + * out_dists_tmp); + * // refine it to the k nearest ones + * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, + * index.metric()); + * @endcode + * + * + * @param[in] handle the raft handle + * @param[in] dataset device matrix that stores the dataset [n_rows, dims] + * @param[in] queries device matrix of the queries [n_queries, dims] + * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where + * n_candidates >= k + * @param[out] indices device matrix that stores the refined indices [n_queries, k] + * @param[out] distances device matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::handle_t const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_device(handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +/** Same as above, but all input and output data is in host memory. 
+ * @param[in] handle the raft handle (not forwarded to the host implementation) + * @param[in] dataset host matrix that stores the dataset [n_rows, dims] + * @param[in] queries host matrix of the queries [n_queries, dims] + * @param[in] neighbor_candidates host matrix with indices of candidate vectors [n_queries, + * n_candidates], where n_candidates >= k + * @param[out] indices host matrix that stores the refined indices [n_queries, k] + * @param[out] distances host matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::handle_t const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_host(dataset, queries, neighbor_candidates, indices, distances, metric); +} +} // namespace raft::neighbors diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index e9af97b547..14c4dd85f1 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -50,6 +51,8 @@ using namespace raft::spatial::knn::detail; // NOLINT * @tparam T element type. * @tparam IdxT type of the indices in the source source_vecs * @tparam LabelT label type + * @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise + * we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1. 
* * @param[in] labels device pointer to the cluster ids for each row [n_rows] * @param[in] list_offsets device pointer to the cluster offsets in the output (index) [n_lists] @@ -64,7 +67,7 @@ using namespace raft::spatial::knn::detail; // NOLINT * @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`. * */ -template +template __global__ void build_index_kernel(const LabelT* labels, const IdxT* list_offsets, const T* source_vecs, @@ -95,8 +98,11 @@ __global__ void build_index_kernel(const LabelT* labels, list_data += (list_offset + group_offset) * dim; // Point to the source vector - source_vecs += i * dim; - + if constexpr (gather_src) { + source_vecs += source_ixs[i] * dim; + } else { + source_vecs += i * dim; + } // Interleave dimensions of the source vector while recording it. // NB: such `veclen` is selected, that `dim % veclen == 0` for (uint32_t l = 0; l < dim; l += veclen) { @@ -299,4 +305,76 @@ inline auto build( } } +/** + * Build an index that can be used in refinement operation. + * + * See raft::neighbors::refine for details on the refinement operation. + * + * The returned index cannot be used for a regular ivf_flat::search. The index misses information + * about coarse clusters. Instead, the neighbor candidates are assumed to form clusters, one for + * each query. The candidate vectors are gathered into the index dataset, that can be later used + * in ivfflat_interleaved_scan. + * + * @param[in] handle the raft handle + * @param[inout] refinement_index + * @param[in] dataset device pointer to dataset vectors, size [n_rows, dim]. Note that n_rows is + * not known to this function, but each candidate_idx has to be smaller than n_rows. 
+ * @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates] + * @param[in] n_candidates of neighbor_candidates + */ +template +inline void fill_refinement_index(const handle_t& handle, + index* refinement_index, + const T* dataset, + const IdxT* candidate_idx, + IdxT n_queries, + uint32_t n_candidates) +{ + using LabelT = uint32_t; + + auto stream = handle.get_stream(); + uint32_t n_lists = n_queries; + common::nvtx::range fun_scope( + "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); + + rmm::device_uvector new_labels(n_queries * n_candidates, stream); + linalg::writeOnlyUnaryOp( + new_labels.data(), + n_queries * n_candidates, + [n_candidates] __device__(LabelT * out, uint32_t i) { *out = i / n_candidates; }, + stream); + + auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); + auto list_offsets_ptr = refinement_index->list_offsets().data_handle(); + // We do not fill centers and center norms, since we will not run coarse search. 
+ + // Calculate new offsets + uint32_t n_roundup = Pow2::roundUp(n_candidates); + linalg::writeOnlyUnaryOp( + refinement_index->list_offsets().data_handle(), + refinement_index->list_offsets().size(), + [n_roundup] __device__(IdxT * out, uint32_t i) { *out = i * n_roundup; }, + stream); + + IdxT index_size = n_roundup * n_lists; + refinement_index->allocate( + handle, index_size, refinement_index->metric() == raft::distance::DistanceType::L2Expanded); + + RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); + + const dim3 block_dim(256); + const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); + build_index_kernel + <<>>(new_labels.data(), + list_offsets_ptr, + dataset, + candidate_idx, + refinement_index->data().data_handle(), + refinement_index->indices().data_handle(), + list_sizes_ptr, + n_queries * n_candidates, + refinement_index->dim(), + refinement_index->veclen()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} } // namespace raft::spatial::knn::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index f6b9e62008..94f4dc96c6 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1212,6 +1212,26 @@ void search_impl(const handle_t& handle, } } +/** + * Whether minimal distance corresponds to similar elements (using the given metric). + */ +inline bool is_min_close(distance::DistanceType metric) +{ + bool select_min; + switch (metric) { + case raft::distance::DistanceType::InnerProduct: + case raft::distance::DistanceType::CosineExpanded: + case raft::distance::DistanceType::CorrelationExpanded: + // Similarity metrics have the opposite meaning, i.e. 
nearest neighbors are those with larger + // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 + // {perform_k_selection}) + select_min = false; + break; + default: select_min = true; + } + return select_min; +} + /** See raft::spatial::knn::ivf_flat::search docs */ template inline void search(const handle_t& handle, @@ -1231,27 +1251,22 @@ inline void search(const handle_t& handle, "n_probes (number of clusters to probe in the search) must be positive."); auto n_probes = std::min(params.n_probes, index.n_lists()); - bool select_min; - switch (index.metric()) { - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: - // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger - // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 - // {perform_k_selection}) - select_min = false; - break; - default: select_min = true; - } - auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); if (pool_guard) { RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } - return search_impl( - handle, index, queries, n_queries, k, n_probes, select_min, neighbors, distances, mr); + return search_impl(handle, + index, + queries, + n_queries, + k, + n_probes, + is_min_close(index.metric()), + neighbors, + distances, + mr); } } // namespace raft::spatial::knn::ivf_flat::detail diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3e8f944a5b..dae0f6f6b1 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -245,6 +245,7 @@ if(BUILD_TESTS) test/neighbors/ball_cover.cu test/neighbors/epsilon_neighborhood.cu test/neighbors/faiss_mr.cu + test/neighbors/refine.cu test/neighbors/selection.cu OPTIONAL DIST diff --git a/cpp/test/neighbors/ann_utils.cuh 
b/cpp/test/neighbors/ann_utils.cuh index 3e3735719f..07ef410d36 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -25,6 +25,9 @@ #include #include +#include "../test_utils.h" +#include + namespace raft::neighbors { struct print_dtype { diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu new file mode 100644 index 0000000000..e1700e44b3 --- /dev/null +++ b/cpp/test/neighbors/refine.cu @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include "ann_utils.cuh" + +#include "refine_helper.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#if defined RAFT_NN_COMPILED +#include +#endif + +#include + +namespace raft::neighbors { + +template +class RefineTest : public ::testing::TestWithParam> { + public: + RefineTest() + : stream_(handle_.get_stream()), + data(handle_, ::testing::TestWithParam>::GetParam()) + { + } + + protected: + public: // NOTE(review): leftover marker (originally "tamas remove"); the empty protected: section above looks unintended + void testRefine() + { + std::vector indices(data.p.n_queries * data.p.k); + std::vector distances(data.p.n_queries * data.p.k); + + if (data.p.host_data) { + raft::neighbors::refine(handle_, + data.dataset_host.view(), + data.queries_host.view(), + data.candidates_host.view(), + data.refined_indices_host.view(), + data.refined_distances_host.view(), + data.p.metric); + raft::copy(indices.data(), + data.refined_indices_host.data_handle(), + data.refined_indices_host.size(), + stream_); + raft::copy(distances.data(), + data.refined_distances_host.data_handle(), + data.refined_distances_host.size(), + stream_); + + } else { + raft::neighbors::refine(handle_, + data.dataset.view(), + data.queries.view(), + data.candidates.view(), + data.refined_indices.view(), + data.refined_distances.view(), + data.p.metric); + update_host(distances.data(), + data.refined_distances.data_handle(), + data.refined_distances.size(), + stream_); + update_host( + indices.data(), data.refined_indices.data_handle(), data.refined_indices.size(), stream_); + } + handle_.sync_stream(stream_); + + double min_recall = 1; /* passed to eval_neighbours together with the 0.001 tolerance argument; the reference results come from RefineHelper */ + + ASSERT_TRUE(raft::neighbors::eval_neighbours(data.true_refined_indices_host, + indices, + data.true_refined_distances_host, + distances, + data.p.n_queries, + 
data.p.k, + 0.001, + min_recall)); + } + + public: + raft::handle_t handle_; + rmm::cuda_stream_view stream_; + detail::RefineHelper data; +}; + +const std::vector> inputs = + raft::util::itertools::product>( + 
{137}, + {1000}, + {16}, + {1, 10, 33}, + {33}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, + {false, true}); /* presumably positional in RefineInputs field order: n_queries, n_rows, dim, k, k0, metric, host_data - confirm against itertools::product */ + +typedef RefineTest RefineTestF; +TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } + +INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); + +typedef RefineTest RefineTestF_uint8; +TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } +INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine_helper.cuh b/cpp/test/neighbors/refine_helper.cuh new file mode 100644 index 0000000000..3c69a8f5b7 --- /dev/null +++ b/cpp/test/neighbors/refine_helper.cuh @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "ann_utils.cuh" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::neighbors::detail { + +template +struct RefineInputs { + IdxT n_queries; + IdxT n_rows; + IdxT dim; + IdxT k; // number of neighbors kept after refinement + IdxT k0; // initial k before refinement (k0 >= k). 
+ raft::distance::DistanceType metric; + bool host_data; +}; + +/** Helper class to allocate arrays and generate input data for refinement test and benchmark. 
*/ +template +class RefineHelper { + public: + RefineHelper(const raft::handle_t& handle, RefineInputs params) + : handle_(handle), stream_(handle.get_stream()), p(params) + { + raft::random::Rng r(1234ULL); + + dataset = raft::make_device_matrix(handle_, p.n_rows, p.dim); + queries = raft::make_device_matrix(handle_, p.n_queries, p.dim); + if constexpr (std::is_same{}) { + r.uniform(dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0), stream_); + r.uniform(queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0), stream_); + } else { + r.uniformInt(dataset.data_handle(), dataset.size(), DataT(1), DataT(20), stream_); + r.uniformInt(queries.data_handle(), queries.size(), DataT(1), DataT(20), stream_); + } + + refined_distances = raft::make_device_matrix(handle_, p.n_queries, p.k); + refined_indices = raft::make_device_matrix(handle_, p.n_queries, p.k); + + // Generate candidate indices: naiveBfKnn fills `candidates` with k0 entries per query + { + candidates = raft::make_device_matrix(handle_, p.n_queries, p.k0); + rmm::device_uvector distances_tmp(p.n_queries * p.k0, stream_); + raft::neighbors::naiveBfKnn(distances_tmp.data(), + candidates.data_handle(), + queries.data_handle(), + dataset.data_handle(), + p.n_queries, + p.n_rows, + p.dim, + p.k0, + p.metric, + stream_); + handle_.sync_stream(stream_); + } + + if (p.host_data) { + dataset_host = raft::make_host_matrix(p.n_rows, p.dim); + queries_host = raft::make_host_matrix(p.n_queries, p.dim); + candidates_host = raft::make_host_matrix(p.n_queries, p.k0); + + raft::copy(dataset_host.data_handle(), dataset.data_handle(), dataset.size(), stream_); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries.size(), stream_); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates.size(), stream_); + + refined_distances_host = raft::make_host_matrix(p.n_queries, p.k); + refined_indices_host = raft::make_host_matrix(p.n_queries, p.k); + 
handle_.sync_stream(stream_); + } + + // Generate ground truth for testing. 
+ { + rmm::device_uvector distances_dev(p.n_queries * p.k, stream_); + rmm::device_uvector indices_dev(p.n_queries * p.k, stream_); + raft::neighbors::naiveBfKnn(distances_dev.data(), + indices_dev.data(), + queries.data_handle(), + dataset.data_handle(), + p.n_queries, + p.n_rows, + p.dim, + p.k, + p.metric, + stream_); + true_refined_distances_host.resize(p.n_queries * p.k); + true_refined_indices_host.resize(p.n_queries * p.k); + raft::copy(true_refined_indices_host.data(), indices_dev.data(), indices_dev.size(), stream_); + raft::copy( + true_refined_distances_host.data(), distances_dev.data(), distances_dev.size(), stream_); + handle_.sync_stream(stream_); + } + } + + public: + RefineInputs p; /* NOTE(review): declared before handle_ and stream_, so it is initialized first even though the constructor's initializer list names it last */ + const raft::handle_t& handle_; + rmm::cuda_stream_view stream_; + + raft::device_matrix dataset; + raft::device_matrix queries; + raft::device_matrix candidates; // Neighbor candidate indices [n_queries, k0] + raft::device_matrix refined_indices; + raft::device_matrix refined_distances; + + raft::host_matrix dataset_host; /* the *_host matrices are only allocated and filled by the constructor when p.host_data is set */ + raft::host_matrix queries_host; + raft::host_matrix candidates_host; + raft::host_matrix refined_indices_host; + raft::host_matrix refined_distances_host; + + std::vector true_refined_indices_host; + std::vector true_refined_distances_host; +}; +} // namespace raft::neighbors::detail \ No newline at end of file