From 81dad56d4cf278407d8a88cc5fbc023194b5f135 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 16 Oct 2023 00:49:43 +0200 Subject: [PATCH] Add sample filter conversion --- build.sh | 4 +- .../detail/ivf_flat_interleaved_scan-inl.cuh | 2 +- cpp/include/raft/neighbors/sample_filter.cuh | 30 +++++ cpp/test/neighbors/ann_ivf_flat.cuh | 108 ++++++++++++++++++ .../ann_ivf_flat/test_float_int64_t.cu | 1 + 5 files changed, 143 insertions(+), 2 deletions(-) diff --git a/build.sh b/build.sh index 6200e6a2fa..51e59cc259 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" @@ -324,6 +324,8 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index ad3d158e48..ab2844c97c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -740,7 +740,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const bool valid = vec_id < list_length; // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) { + if (valid && sample_filter(queries_offset + blockIdx.y, list_id, vec_id)) { loadAndComputeDist lc(dist, compute_dist); for (int pos = 0; pos < shm_assisted_dim; diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh index 9182d72da9..63384e50bd 100644 --- a/cpp/include/raft/neighbors/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -45,4 +45,34 @@ struct bitset_filter { return bitset_view_.test(sample_ix); } }; + +/** + * @brief Filter used to convert the cluster index and sample index + * of an IVF search into a sample index. This can be used as an + * intermediate filter. + * + * @tparam index_t Indexing type + * @tparam filter_t + */ +template +struct ivf_to_sample_filter { + index_t** const inds_ptrs_; + const filter_t next_filter_; + + ivf_to_sample_filter(index_t** const inds_ptrs, const filter_t next_filter) + : inds_ptrs_{inds_ptrs}, next_filter_{next_filter} + { + } + + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the current inverted list index + const uint32_t cluster_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + return next_filter_(query_ix, inds_ptrs_[cluster_ix][sample_ix]); + } +}; } // namespace raft::neighbors::filtering diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 7b1d32ca83..9e22faaecb 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -15,6 +15,8 @@ */ #pragma once +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter + #include "../test_utils.cuh" #include "ann_utils.cuh" #include @@ -26,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -57,6 +60,10 @@ namespace raft::neighbors::ivf_flat { +struct test_ivf_sample_filter { + static constexpr unsigned offset = 300; +}; + template struct AnnIvfFlatInputs { IdxT num_queries; @@ -406,6 +413,107 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } + void testFilter() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfflat(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_ivf_sample_filter::offset * ps.dim; + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + // unless something is really wrong with clustering, this could serve as a lower bound on + // recall + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + + auto distances_ivfflat_dev = raft::make_device_matrix(handle_, ps.num_queries, ps.k); + auto indices_ivfflat_dev = + raft::make_device_matrix(handle_, ps.num_queries, ps.k); + + { + ivf_flat::index_params index_params; + ivf_flat::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + index_params.metric_arg = 0; + + // Create IVF Flat index + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_flat::build(handle_, index_params, database_view); + + // Create Bitset filter + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + thrust::sequence(resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + + test_ivf_sample_filter::offset)); + resource::sync_stream(handle_); + + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + + // Search with the filter + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + ivf_flat::search_with_filtering( + handle_, + search_params, + index, + search_queries_view, + indices_ivfflat_dev.view(), + distances_ivfflat_dev.view(), + raft::neighbors::filtering::ivf_to_sample_filter( + index.inds_ptrs().data_handle(), + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view()))); + + update_host( + distances_ivfflat.data(), distances_ivfflat_dev.data_handle(), queries_size, stream_); + update_host( + indices_ivfflat.data(), indices_ivfflat_dev.data_handle(), queries_size, stream_); + resource::sync_stream(handle_); + } + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + } + void SetUp() override { database.resize(ps.num_db_vecs * ps.dim, stream_); diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index 3bfea283e5..d22c3837a3 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -25,6 +25,7 @@ TEST_P(AnnIVFFlatTestF, AnnIVFFlat) { this->testIVFFlat(); this->testPacker(); + this->testFilter(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF, ::testing::ValuesIn(inputs));