From 28eb0b3592bfda5aa22fcc5137d455e0cd505433 Mon Sep 17 00:00:00 2001 From: Micka Date: Mon, 6 Nov 2023 15:39:01 +0100 Subject: [PATCH] [FEA] Support vector deletion in ANN IVF (#1831) PR based on the new Bitset feature (#1803) to support vector deletion in ANN. Closes #1177. Closes #1620. This PR adds `ivf_to_sample_filter` that acts as an intermediate filter to use an IVF index with a bitset filter. Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - William Hicks (https://github.com/wphicks) URL: https://github.com/rapidsai/raft/pull/1831 --- build.sh | 4 +- cpp/bench/prims/CMakeLists.txt | 2 + cpp/bench/prims/neighbors/knn.cuh | 119 ++++++++++++- .../knn/ivf_flat_filter_float_int64_t.cu | 24 +++ .../knn/ivf_pq_filter_float_int64_t.cu | 24 +++ .../raft/neighbors/detail/ivf_flat_build.cuh | 1 + .../detail/ivf_flat_interleaved_scan-inl.cuh | 41 +++-- .../raft/neighbors/detail/ivf_pq_build.cuh | 1 + .../detail/ivf_pq_compute_similarity-ext.cuh | 27 ++- .../raft/neighbors/detail/ivf_pq_search.cuh | 6 +- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 6 +- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 6 +- cpp/include/raft/neighbors/sample_filter.cuh | 1 + .../raft/neighbors/sample_filter_types.hpp | 45 +++++ .../ivf_pq_compute_similarity_00_generate.py | 2 +- .../ivf_pq_compute_similarity_float_float.cu | 5 +- ...f_pq_compute_similarity_float_fp8_false.cu | 3 +- ...vf_pq_compute_similarity_float_fp8_true.cu | 3 +- .../ivf_pq_compute_similarity_float_half.cu | 5 +- ...vf_pq_compute_similarity_half_fp8_false.cu | 3 +- ...ivf_pq_compute_similarity_half_fp8_true.cu | 3 +- .../ivf_pq_compute_similarity_half_half.cu | 5 +- cpp/test/CMakeLists.txt | 4 +- cpp/test/neighbors/ann_ivf_flat.cuh | 104 +++++++++++ .../ann_ivf_flat/test_filter_float_int64_t.cu | 29 ++++ cpp/test/neighbors/ann_ivf_pq.cuh | 163 ++++++++++++++++++ .../ann_ivf_pq/test_filter_float_int64_t.cu | 26 +++ .../ann_ivf_pq/test_filter_int8_t_int64_t.cu | 27 +++ .../ann_ivf_pq/test_float_uint32_t.cu | 5 +- 29 files changed, 649 insertions(+), 45 deletions(-) create mode 100644 cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu create mode 100644 cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu diff --git a/build.sh b/build.sh index 6200e6a2fa..51e59cc259 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" @@ -324,6 +324,8 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 5da2cd916b..fe58453d0d 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -147,10 +147,12 @@ if(BUILD_PRIMS_BENCH) bench/prims/neighbors/knn/brute_force_float_int64_t.cu bench/prims/neighbors/knn/brute_force_float_uint32_t.cu bench/prims/neighbors/knn/cagra_float_uint32_t.cu + bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu + bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu bench/prims/neighbors/refine_float_int64_t.cu diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 31ac869b37..55865b577a 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -21,9 +21,12 @@ #include +#include #include #include +#include #include +#include #include #include @@ -31,6 +34,8 @@ #include #include +#include + #include namespace raft::bench::spatial { @@ -44,11 +49,14 @@ struct params { size_t n_queries; /** Number of nearest neighbours to find for every probe. */ size_t k; + /** Ratio of removed indices. */ + double removed_ratio; }; inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& { - os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k; + os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#" + << p.removed_ratio; return os; } @@ -221,6 +229,104 @@ struct brute_force_knn { } }; +template +struct ivf_flat_filter_knn { + using dist_t = float; + + std::optional> index; + raft::neighbors::ivf_flat::index_params index_params; + raft::neighbors::ivf_flat::search_params search_params; + raft::core::bitset removed_indices_bitset_; + params ps; + + ivf_flat_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) + : ps(ps), removed_indices_bitset_(handle, ps.n_samples) + { + index_params.n_lists = 4096; + index_params.metric = raft::distance::DistanceType::L2Expanded; + index.emplace(raft::neighbors::ivf_flat::build( + handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); + auto removed_indices = + raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + removed_indices_bitset_.set(handle, removed_indices.view()); + } + + void search(const raft::device_resources& handle, + const ValT* search_items, + dist_t* out_dists, + IdxT* out_idxs) + { + search_params.n_probes = 20; + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto neighbors_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto distance_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); + auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); + + if (ps.removed_ratio > 0) { + raft::neighbors::ivf_flat::search_with_filtering( + handle, search_params, *index, queries_view, neighbors_view, distance_view, filter); + } else { + raft::neighbors::ivf_flat::search( + handle, search_params, *index, queries_view, neighbors_view, distance_view); + } + } +}; + +template +struct ivf_pq_filter_knn { + using dist_t = float; + + std::optional> index; + raft::neighbors::ivf_pq::index_params index_params; + raft::neighbors::ivf_pq::search_params search_params; + raft::core::bitset removed_indices_bitset_; + params ps; + + ivf_pq_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) + : ps(ps), removed_indices_bitset_(handle, ps.n_samples) + { + index_params.n_lists = 4096; + index_params.metric = raft::distance::DistanceType::L2Expanded; + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); + auto removed_indices = + raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + removed_indices_bitset_.set(handle, removed_indices.view()); + } + + void search(const raft::device_resources& handle, + const ValT* search_items, + dist_t* out_dists, + IdxT* out_idxs) + { + search_params.n_probes = 20; + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto neighbors_view = + raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto distance_view = + raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); + auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); + + if (ps.removed_ratio > 0) { + raft::neighbors::ivf_pq::search_with_filtering( + handle, search_params, *index, queries_view, neighbors_view, distance_view, filter); + } else { + raft::neighbors::ivf_pq::search( + handle, search_params, *index, queries_view, neighbors_view, distance_view); + } + } +}; + template struct knn : public fixture { explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope) @@ -378,8 +484,15 @@ struct knn : public fixture { }; inline const std::vector kInputs{ - {2000000, 128, 1000, 32}, {10000000, 128, 1000, 32}, {10000, 8192, 1000, 32}}; - + {2000000, 128, 1000, 32, 0}, {10000000, 128, 1000, 32, 0}, {10000, 8192, 1000, 32, 0}}; + +const std::vector kInputsFilter = + raft::util::itertools::product({size_t(10000000)}, // n_samples + {size_t(128)}, // n_dim + {size_t(1000)}, // n_queries + {size_t(255)}, // k + {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio + ); inline const std::vector kAllStrategies{ TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; diff --git a/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu new file mode 100644 index 0000000000..bf5118ceae --- /dev/null +++ b/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, ivf_flat_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu new file mode 100644 index 0000000000..9534515cbb --- /dev/null +++ b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); + +} // namespace raft::bench::spatial diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index c1b6056c7d..a9a6ac025f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -310,6 +310,7 @@ inline auto build(raft::resources const& handle, static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "unsupported data type"); RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); + RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); index index(handle, params, dim); utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index ad3d158e48..f1f0ce10d6 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include // RAFT_CUDA_TRY #include @@ -737,10 +738,11 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // This is the vector a given lane/thread handles const uint32_t vec_id = group_id * WarpSize + lane_id; - const bool valid = vec_id < list_length; + const bool valid = + vec_id < list_length && sample_filter(queries_offset + blockIdx.y, list_id, vec_id); // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) { + if (valid) { loadAndComputeDist lc(dist, compute_dist); for (int pos = 0; pos < shm_assisted_dim; @@ -1096,22 +1098,25 @@ void ivfflat_interleaved_scan(const index& index, rmm::cuda_stream_view stream) { const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - queries_offset, - n_probes, - k, - sample_filter, - neighbors, - distances, - grid_dim_x, - stream); + + auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter( + index.inds_ptrs().data_handle(), sample_filter); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + queries_offset, + n_probes, + k, + filter_adapter, + neighbors, + distances, + grid_dim_x, + stream); } } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 20ef6a05e0..33ed51ad05 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -1521,6 +1521,7 @@ auto build(raft::resources const& handle, "Unsupported data type"); RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); + RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); auto stream = resource::get_cuda_stream(handle); diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 6afb7e4299..37b6efc1eb 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -180,25 +180,38 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::filtering::none_ivf_sample_filter); + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::filtering::none_ivf_sample_filter); + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::filtering::none_ivf_sample_filter); + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 016fd8c693..7f5b316d41 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -794,7 +794,9 @@ inline void search(raft::resources const& handle, rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - auto search_instance = ivfpq_search::fun(params, index.metric()); + auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter( + index.inds_ptrs().data_handle(), sample_filter); + auto search_instance = ivfpq_search::fun(params, index.metric()); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); @@ -850,7 +852,7 @@ inline void search(raft::resources const& handle, distances + uint64_t(k) * (offset_q + offset_b), utils::config::kDivisor / utils::config::kDivisor, params.preferred_shmem_carveout, - sample_filter); + filter_adapter); } } } diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 6641346a67..692fb08810 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -375,7 +375,8 @@ void extend(raft::resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or + * `(uint32_t query_ix, uint32 sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -504,7 +505,8 @@ void search(raft::resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or + * `(uint32_t query_ix, uint32 sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index 9f203d92fb..d14456d6f6 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -149,7 +149,8 @@ void extend(raft::resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or + * `(uint32_t query_ix, uint32 sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -375,7 +376,8 @@ void extend(raft::resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or + * `(uint32_t query_ix, uint32 sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh index 9182d72da9..5864590034 100644 --- a/cpp/include/raft/neighbors/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -45,4 +45,5 @@ struct bitset_filter { return bitset_view_.test(sample_ix); } }; + } // namespace raft::neighbors::filtering diff --git a/cpp/include/raft/neighbors/sample_filter_types.hpp b/cpp/include/raft/neighbors/sample_filter_types.hpp index 10c5e99372..25030f48c8 100644 --- a/cpp/include/raft/neighbors/sample_filter_types.hpp +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -49,6 +49,51 @@ struct none_cagra_sample_filter { } }; +template +struct takes_three_args : std::false_type {}; +template +struct takes_three_args< + filter_t, + std::void_t()(uint32_t{}, uint32_t{}, uint32_t{}))>> + : std::true_type {}; + +/** + * @brief Filter used to convert the cluster index and sample index + * of an IVF search into a sample index. This can be used as an + * intermediate filter. + * + * @tparam index_t Indexing type + * @tparam filter_t + */ +template +struct ivf_to_sample_filter { + const index_t* const* inds_ptrs_; + const filter_t next_filter_; + + ivf_to_sample_filter(const index_t* const* inds_ptrs, const filter_t next_filter) + : inds_ptrs_{inds_ptrs}, next_filter_{next_filter} + { + } + + /** If the original filter takes three arguments, then don't modify the arguments. + * If the original filter takes two arguments, then we are using `inds_ptr_` to obtain the sample + * index. + */ + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the current inverted list index + const uint32_t cluster_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + if constexpr (takes_three_args::value) { + return next_filter_(query_ix, cluster_ix, sample_ix); + } else { + return next_filter_(query_ix, inds_ptrs_[cluster_ix][sample_ix]); + } + } +}; /** * If the filtering depends on the index of a sample, then the following * filter template can be used: diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index 5132048d40..670ed57ed1 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -103,6 +103,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::none_ivf_sample_filter);\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::ivf_to_sample_filter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index bfc07b0321..7e17d6822a 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -71,7 +71,10 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::filtering::none_ivf_sample_filter); + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index 537868b590..c1b72dab33 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -73,7 +73,8 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 59b64b892d..fdff0860fc 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -73,7 +73,8 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index f9e899f8e9..7205544370 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -71,7 +71,10 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::filtering::none_ivf_sample_filter); + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index bf699d7af6..2ac6c3527b 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -73,7 +73,8 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 9689ec88e1..70f3ffdb0c 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -73,7 +73,8 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::none_ivf_sample_filter); + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index deed61dd3d..5cc1cb8038 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -71,7 +71,10 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::filtering::none_ivf_sample_filter); + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); #undef COMMA diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 9b9b882d1d..6c03da8d7f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -379,14 +379,16 @@ if(BUILD_TESTS) NAME NEIGHBORS_ANN_IVF_TEST PATH + test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu - test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_float_uint32_t.cu test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu + test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu + test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 7b1d32ca83..a9fd696f1f 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -57,6 +58,10 @@ namespace raft::neighbors::ivf_flat { +struct test_ivf_sample_filter { + static constexpr unsigned offset = 300; +}; + template struct AnnIvfFlatInputs { IdxT num_queries; @@ -406,6 +411,105 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } + void testFilter() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfflat(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_ivf_sample_filter::offset * ps.dim; + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + // unless something is really wrong with clustering, this could serve as a lower bound on + // recall + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + + auto distances_ivfflat_dev = raft::make_device_matrix(handle_, ps.num_queries, ps.k); + auto indices_ivfflat_dev = + raft::make_device_matrix(handle_, ps.num_queries, ps.k); + + { + ivf_flat::index_params index_params; + ivf_flat::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + index_params.metric_arg = 0; + + // Create IVF Flat index + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_flat::build(handle_, index_params, database_view); + + // Create Bitset filter + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + thrust::sequence(resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + + test_ivf_sample_filter::offset)); + resource::sync_stream(handle_); + + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + + // Search with the filter + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + ivf_flat::search_with_filtering( + handle_, + search_params, + index, + search_queries_view, + indices_ivfflat_dev.view(), + distances_ivfflat_dev.view(), + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + + update_host( + distances_ivfflat.data(), distances_ivfflat_dev.data_handle(), queries_size, stream_); + update_host( + indices_ivfflat.data(), indices_ivfflat_dev.data_handle(), queries_size, stream_); + resource::sync_stream(handle_); + } + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + } + void SetUp() override { database.resize(ps.num_db_vecs * ps.dim, stream_); diff --git a/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu new file mode 100644 index 0000000000..0e1036e566 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../ann_ivf_flat.cuh" + +namespace raft::neighbors::ivf_flat { + +typedef AnnIVFFlatTest AnnIVFFlatFilterTestF; +TEST_P(AnnIVFFlatFilterTestF, AnnIVFFlatFilter) { this->testFilter(); } + +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatFilterTestF, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 87baf31c2b..bdb83ecfdc 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -39,6 +40,7 @@ #include #include +#include #include #include @@ -48,6 +50,10 @@ namespace raft::neighbors::ivf_pq { +struct test_ivf_sample_filter { + static constexpr unsigned offset = 1500; +}; + struct ivf_pq_inputs { uint32_t num_db_vecs = 4096; uint32_t num_queries = 1024; @@ -499,6 +505,163 @@ class ivf_pq_test : public ::testing::TestWithParam { std::vector distances_ref; // NOLINT }; +template +class ivf_pq_filter_test : public ::testing::TestWithParam { + public: + ivf_pq_filter_test() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void gen_data() + { + database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); + search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void calc_ref() + { + size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data() + test_ivf_sample_filter::offset * ps.dim, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + ps.index_params.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + distances_ref.resize(queries_size); + update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); + indices_ref.resize(queries_size); + update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + auto build_only() + { + auto ipams = ps.index_params; + ipams.add_data_on_build = true; + + auto index_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + return ivf_pq::build(handle_, ipams, index_view); + } + + template + void run(BuildIndex build_index) + { + index index = build_index(); + + double compression_ratio = + static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivf_pq(queries_size); + std::vector distances_ivf_pq(queries_size); + + rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); + rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); + + auto query_view = + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = raft::make_device_matrix_view( + indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = raft::make_device_matrix_view( + distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + + // Create Bitset filter + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + thrust::sequence( + resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + test_ivf_sample_filter::offset)); + resource::sync_stream(handle_); + + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + ivf_pq::search_with_filtering( + handle_, + ps.search_params, + index, + query_view, + inds_view, + dists_view, + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + + update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); + update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + + // A very conservative lower bound on recall + double min_recall = + static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); + // Using a heuristic to lower the required recall due to code-packing errors + min_recall = + std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall); + // Use explicit per-test min recall value if provided. + min_recall = ps.min_recall.value_or(min_recall); + + ASSERT_TRUE(eval_neighbours(indices_ref, + indices_ivf_pq, + distances_ref, + distances_ivf_pq, + ps.num_queries, + ps.k, + 0.0001 * compression_ratio, + min_recall)) + << ps; + } + + void SetUp() override // NOLINT + { + gen_data(); + calc_ref(); + } + + void TearDown() override // NOLINT + { + cudaGetLastError(); + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + ivf_pq_inputs ps; // NOLINT + rmm::device_uvector database; // NOLINT + rmm::device_uvector search_queries; // NOLINT + std::vector indices_ref; // NOLINT + std::vector distances_ref; // NOLINT +}; + /* Test cases */ using test_cases_t = std::vector; diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu new file mode 100644 index 0000000000..17f72fb08a --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../ann_ivf_pq.cuh" + +namespace raft::neighbors::ivf_pq { + +using f32_f32_i64_filter = ivf_pq_filter_test; + +TEST_BUILD_SEARCH(f32_f32_i64_filter) +INSTANTIATE(f32_f32_i64_filter, defaults() + big_dims_moderate_lut()); +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu new file mode 100644 index 0000000000..537dbb4979 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../ann_ivf_pq.cuh" + +namespace raft::neighbors::ivf_pq { + +using f32_i08_i64_filter = ivf_pq_filter_test; + +TEST_BUILD_SEARCH(f32_i08_i64_filter) +INSTANTIATE(f32_i08_i64_filter, big_dims()); + +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index 3d362a5261..5405ddc4a3 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -25,10 +25,13 @@ namespace raft::neighbors::ivf_pq { -using f32_f32_u32 = ivf_pq_test; +using f32_f32_u32 = ivf_pq_test; +using f32_f32_u32_filter = ivf_pq_filter_test; TEST_BUILD_SEARCH(f32_f32_u32) TEST_BUILD_SERIALIZE_SEARCH(f32_f32_u32) INSTANTIATE(f32_f32_u32, defaults() + var_n_probes() + var_k() + special_cases()); +TEST_BUILD_SEARCH(f32_f32_u32_filter) +INSTANTIATE(f32_f32_u32_filter, defaults()); } // namespace raft::neighbors::ivf_pq