diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index bb405088bb..63f6c14686 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -18,8 +18,10 @@ #include #include +#include #include #include +#include #include @@ -40,6 +42,8 @@ struct params { int block_size; int search_width; int max_iterations; + /** Ratio of removed indices. */ + double removed_ratio; }; template @@ -49,7 +53,8 @@ struct CagraBench : public fixture { params_(ps), queries_(make_device_matrix(handle, ps.n_queries, ps.n_dims)), dataset_(make_device_matrix(handle, ps.n_samples, ps.n_dims)), - knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)) + knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)), + removed_indices_bitset_(handle, ps.n_samples) { // Generate random dataset and queriees raft::random::RngState state{42}; @@ -74,6 +79,13 @@ struct CagraBench : public fixture { auto metric = raft::distance::DistanceType::L2Expanded; + auto removed_indices = + raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + removed_indices_bitset_.set(handle, removed_indices.view()); index_.emplace(raft::neighbors::cagra::index( handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); } @@ -95,10 +107,18 @@ struct CagraBench : public fixture { distances.data_handle(), params_.n_queries, params_.k); auto queries_v = make_const_mdspan(queries_.view()); - loop_on_state(state, [&]() { - raft::neighbors::cagra::search( - this->handle, search_params, *this->index_, queries_v, ind_v, dist_v); - }); + if (params_.removed_ratio > 0) { + auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); + loop_on_state(state, [&]() { + raft::neighbors::cagra::search_with_filtering( + this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter); + }); + } else { + loop_on_state(state, [&]() { + raft::neighbors::cagra::search( + this->handle, search_params, *this->index_, queries_v, ind_v, dist_v); + }); + } double data_size = params_.n_samples * params_.n_dims * sizeof(T); double graph_size = params_.n_samples * params_.degree * sizeof(IdxT); @@ -120,6 +140,7 @@ struct CagraBench : public fixture { state.counters["block_size"] = params_.block_size; state.counters["search_width"] = params_.search_width; state.counters["iterations"] = iterations; + state.counters["removed_ratio"] = params_.removed_ratio; } private: @@ -128,6 +149,7 @@ struct CagraBench : public fixture { raft::device_matrix queries_; raft::device_matrix dataset_; raft::device_matrix knn_graph_; + raft::core::bitset removed_indices_bitset_; }; inline const std::vector generate_inputs() @@ -141,7 +163,8 @@ inline const std::vector generate_inputs() {64}, // itopk_size {0}, // block_size {1}, // search_width - {0} // max_iterations + {0}, // max_iterations + {0.0} // removed_ratio ); auto inputs2 = raft::util::itertools::product({2000000ull, 10000000ull}, // n_samples {128}, // dataset dim @@ -151,7 +174,22 @@ inline const std::vector generate_inputs() {64}, // itopk_size {64, 128, 256, 512, 1024}, // block_size {1}, // search_width - {0} // max_iterations + {0}, // max_iterations + {0.0} // removed_ratio + ); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = raft::util::itertools::product( + {2000000ull, 10000000ull}, // n_samples + {128}, // dataset dim + {1, 10, 10000}, // n_queries + {255}, // k + {64}, // knn graph degree + {300}, // itopk_size + {256}, // block_size + {2}, // search_width + {0}, // max_iterations + {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs; diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index f96dd34e05..f9682a973f 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -391,7 +391,25 @@ void search(raft::resources const& res, /** * @brief Search ANN using the constructed index with the given sample filter. * - * See the [cagra::build](#cagra::build) documentation for a usage example. + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * // use default search parameters + * cagra::search_params search_params; + * // create a bitset to filter the search + * auto removed_indices = raft::make_device_vector(res, n_removed_indices); + * raft::core::bitset removed_indices_bitset( + * res, removed_indices.view(), dataset.extent(0)); + * // search K nearest neighbours according to a bitset + * auto neighbors = raft::make_device_matrix(res, n_queries, k); + * auto distances = raft::make_device_matrix(res, n_queries, k); + * cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances, + * filtering::bitset_filter(removed_indices_bitset.view())); + * @endcode * * @tparam T data element type * @tparam IdxT type of the indices diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh new file mode 100644 index 0000000000..9182d72da9 --- /dev/null +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +namespace raft::neighbors::filtering { +/** + * @brief Filter an index with a bitset + * + * @tparam index_t Indexing type + */ +template +struct bitset_filter { + // View of the bitset to use as a filter + const raft::core::bitset_view bitset_view_; + + bitset_filter(const raft::core::bitset_view bitset_for_filtering) + : bitset_view_{bitset_for_filtering} + { + } + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the index of the current sample + const uint32_t sample_ix) const + { + return bitset_view_.test(sample_ix); + } +}; +} // namespace raft::neighbors::filtering diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index b750372244..e6c3873063 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -525,6 +526,119 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } } + void testCagraRemoved() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - test_cagra_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_cagra_sample_filter::offset), + queries_size, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + { + cagra::index_params index_params; + index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is + // not used for knn_graph building. + cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.hashmap_mode = cagra::hash_mode::HASH; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + } + + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + auto removed_indices = + raft::make_device_vector(handle_, test_cagra_sample_filter::offset); + thrust::sequence( + resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + resource::sync_stream(handle_); + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.n_rows); + cagra::search_with_filtering( + handle_, + search_params, + index, + search_queries_view, + indices_out_view, + dists_out_view, + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); + update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, + distances_Cagra, + ps.n_queries, + ps.k, + 0.001, + min_recall)); + EXPECT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + void SetUp() override { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index 01d7e1e1ea..944c2cbc89 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -27,7 +27,11 @@ typedef AnnCagraSortTest AnnCagraSortTestF_U32; TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } typedef AnnCagraFilterTest AnnCagraFilterTestF_U32; -TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) { this->testCagraFilter(); } +TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) +{ + this->testCagraFilter(); + this->testCagraRemoved(); +} INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index ee06d369fa..3d9dc76953 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -25,7 +25,11 @@ TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestI8_U32; TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } typedef AnnCagraFilterTest AnnCagraFilterTestI8_U32; -TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) { this->testCagraFilter(); } +TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) +{ + this->testCagraFilter(); + this->testCagraRemoved(); +} INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index 3243e73ccd..c5b1b1704b 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -27,7 +27,11 @@ typedef AnnCagraSortTest AnnCagraSortTestU8_ TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); } typedef AnnCagraFilterTest AnnCagraFilterTestU8_U32; -TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) { this->testCagraFilter(); } +TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) +{ + this->testCagraFilter(); + this->testCagraRemoved(); +} INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs));