diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index bb405088bb..63f6c14686 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -18,8 +18,10 @@ #include #include +#include #include #include +#include #include @@ -40,6 +42,8 @@ struct params { int block_size; int search_width; int max_iterations; + /** Ratio of removed indices. */ + double removed_ratio; }; template @@ -49,7 +53,8 @@ struct CagraBench : public fixture { params_(ps), queries_(make_device_matrix(handle, ps.n_queries, ps.n_dims)), dataset_(make_device_matrix(handle, ps.n_samples, ps.n_dims)), - knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)) + knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)), + removed_indices_bitset_(handle, ps.n_samples) { // Generate random dataset and queriees raft::random::RngState state{42}; @@ -74,6 +79,13 @@ struct CagraBench : public fixture { auto metric = raft::distance::DistanceType::L2Expanded; + auto removed_indices = + raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + removed_indices_bitset_.set(handle, removed_indices.view()); index_.emplace(raft::neighbors::cagra::index( handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); } @@ -95,10 +107,18 @@ struct CagraBench : public fixture { distances.data_handle(), params_.n_queries, params_.k); auto queries_v = make_const_mdspan(queries_.view()); - loop_on_state(state, [&]() { - raft::neighbors::cagra::search( - this->handle, search_params, *this->index_, queries_v, ind_v, dist_v); - }); + if (params_.removed_ratio > 0) { + auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); + loop_on_state(state, [&]() { + raft::neighbors::cagra::search_with_filtering( + this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter); + }); + } else { + loop_on_state(state, [&]() { + raft::neighbors::cagra::search( + this->handle, search_params, *this->index_, queries_v, ind_v, dist_v); + }); + } double data_size = params_.n_samples * params_.n_dims * sizeof(T); double graph_size = params_.n_samples * params_.degree * sizeof(IdxT); @@ -120,6 +140,7 @@ struct CagraBench : public fixture { state.counters["block_size"] = params_.block_size; state.counters["search_width"] = params_.search_width; state.counters["iterations"] = iterations; + state.counters["removed_ratio"] = params_.removed_ratio; } private: @@ -128,6 +149,7 @@ struct CagraBench : public fixture { raft::device_matrix queries_; raft::device_matrix dataset_; raft::device_matrix knn_graph_; + raft::core::bitset removed_indices_bitset_; }; inline const std::vector generate_inputs() @@ -141,7 +163,8 @@ inline const std::vector generate_inputs() {64}, // itopk_size {0}, // block_size {1}, // search_width - {0} // max_iterations + {0}, // max_iterations + {0.0} // removed_ratio ); auto inputs2 = raft::util::itertools::product({2000000ull, 10000000ull}, // n_samples {128}, // dataset dim @@ -151,7 +174,22 @@ inline const std::vector generate_inputs() {64}, // itopk_size {64, 128, 256, 512, 1024}, // block_size {1}, // search_width - {0} // max_iterations + {0}, // max_iterations + {0.0} // removed_ratio + ); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = raft::util::itertools::product( + {2000000ull, 10000000ull}, // n_samples + {128}, // dataset dim + {1, 10, 10000}, // n_queries + {255}, // k + {64}, // knn graph degree + {300}, // itopk_size + {256}, // block_size + {2}, // search_width + {0}, // max_iterations + {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs; diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index cc934b7a98..19dd6b8350 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -66,11 +66,11 @@ struct index : ann::index { /** Dataset norms */ [[nodiscard]] inline auto norms() const -> device_vector_view { - return make_const_mdspan(norms_.value().view()); + return norms_view_.value(); } /** Whether ot not this index has dataset norms */ - [[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); } + [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); } [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } @@ -102,10 +102,30 @@ struct index : ann::index { norms_(std::move(norms)), metric_arg_(metric_arg) { + if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); } update_dataset(res, dataset); resource::sync_stream(res); } + /** Construct a brute force index from dataset + * + * This class stores a non-owning reference to the dataset and norms here. + * Having precomputed norms gives us a performance advantage at query time. + */ + index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional> norms_view, + raft::distance::DistanceType metric, + T metric_arg = 0.0) + : ann::index(), + metric_(metric), + dataset_(make_device_matrix(res, 0, 0)), + dataset_view_(dataset_view), + norms_view_(norms_view), + metric_arg_(metric_arg) + { + } + private: /** * Replace the dataset with a new dataset. @@ -135,6 +155,7 @@ struct index : ann::index { raft::distance::DistanceType metric_; raft::device_matrix dataset_; std::optional> norms_; + std::optional> norms_view_; raft::device_matrix_view dataset_view_; T metric_arg_; }; diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index f96dd34e05..f9682a973f 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -391,7 +391,25 @@ void search(raft::resources const& res, /** * @brief Search ANN using the constructed index with the given sample filter. * - * See the [cagra::build](#cagra::build) documentation for a usage example. + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * // use default search parameters + * cagra::search_params search_params; + * // create a bitset to filter the search + * auto removed_indices = raft::make_device_vector(res, n_removed_indices); + * raft::core::bitset removed_indices_bitset( + * res, removed_indices.view(), dataset.extent(0)); + * // search K nearest neighbours according to a bitset + * auto neighbors = raft::make_device_matrix(res, n_queries, k); + * auto distances = raft::make_device_matrix(res, n_queries, k); + * cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances, + * filtering::bitset_filter(removed_indices_bitset.view())); + * @endcode * * @tparam T data element type * @tparam IdxT type of the indices diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 5dcfcb3929..9392bde440 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -478,13 +478,15 @@ __global__ void apply_filter_kernel(INDEX_T* const result_indices_ptr, const INDEX_T query_id_offset, SAMPLE_FILTER_T sample_filter) { - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= result_buffer_size * num_queries) { return; } const auto i = tid % result_buffer_size; const auto j = tid / result_buffer_size; const auto index = i + j * lds; - if (!sample_filter(query_id_offset + j, result_indices_ptr[index])) { + if (result_indices_ptr[index] != ~index_msb_1_mask && + !sample_filter(query_id_offset + j, result_indices_ptr[index])) { result_indices_ptr[index] = utils::get_max_value(); result_distances_ptr[index] = utils::get_max_value(); } @@ -788,12 +790,15 @@ struct search : search_plan_impl { auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size; auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size; - // Remove parent bit in search results - remove_parent_bit( - num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream); + if constexpr (!std::is_same::value) { + // Remove parent bit in search results + remove_parent_bit(num_queries, + result_buffer_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + stream); - if (!std::is_same::value) { apply_filter( result_indices.data() + (iter & 0x1) * itopk_size, result_distances.data() + (iter & 0x1) * itopk_size, @@ -821,6 +826,10 @@ struct search : search_plan_impl { true, topk_hint.data(), stream); + } else { + // Remove parent bit in search results + remove_parent_bit( + num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream); } // Copy results from working buffer to final buffer diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index a0f346ab51..147b8b753d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -291,6 +291,14 @@ struct search_plan_impl : public search_plan_impl_base { "`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. " + std::to_string(hashmap_max_fill_rate) + " has been given."; } + if constexpr (!std::is_same::value) { + if (hashmap_mode == hash_mode::SMALL) { + error_message += "`SMALL` hash is not available when filtering"; + } else { + hashmap_mode = hash_mode::HASH; + } + } if (algo == search_algo::MULTI_CTA) { if (hashmap_mode == hash_mode::SMALL) { error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\""; diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 3e4d0409bd..009ffd4684 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -1278,8 +1278,7 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out std::thread update_and_sample_thread(update_and_sample, it); - std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r"; - std::fflush(stdout); + RAFT_LOG_DEBUG("# GNND iteraton: %lu / %lu", it + 1, build_config_.max_iterations); // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it // contains some information for local_join. diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh new file mode 100644 index 0000000000..9182d72da9 --- /dev/null +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +namespace raft::neighbors::filtering { +/** + * @brief Filter an index with a bitset + * + * @tparam index_t Indexing type + */ +template +struct bitset_filter { + // View of the bitset to use as a filter + const raft::core::bitset_view bitset_view_; + + bitset_filter(const raft::core::bitset_view bitset_for_filtering) + : bitset_view_{bitset_for_filtering} + { + } + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the index of the current sample + const uint32_t sample_ix) const + { + return bitset_view_.test(sample_ix); + } +}; +} // namespace raft::neighbors::filtering diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index b750372244..e6c3873063 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -525,6 +526,119 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } } + void testCagraRemoved() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - test_cagra_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_cagra_sample_filter::offset), + queries_size, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + { + cagra::index_params index_params; + index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is + // not used for knn_graph building. + cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.hashmap_mode = cagra::hash_mode::HASH; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + } + + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + auto removed_indices = + raft::make_device_vector(handle_, test_cagra_sample_filter::offset); + thrust::sequence( + resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + resource::sync_stream(handle_); + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.n_rows); + cagra::search_with_filtering( + handle_, + search_params, + index, + search_queries_view, + indices_out_view, + dists_out_view, + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); + update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, + distances_Cagra, + ps.n_queries, + ps.k, + 0.001, + min_recall)); + EXPECT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + void SetUp() override { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index 01d7e1e1ea..944c2cbc89 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -27,7 +27,11 @@ typedef AnnCagraSortTest AnnCagraSortTestF_U32; TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } typedef AnnCagraFilterTest AnnCagraFilterTestF_U32; -TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) { this->testCagraFilter(); } +TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) +{ + this->testCagraFilter(); + this->testCagraRemoved(); +} INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index ee06d369fa..3d9dc76953 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -25,7 +25,11 @@ TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestI8_U32; TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } typedef AnnCagraFilterTest AnnCagraFilterTestI8_U32; -TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) { this->testCagraFilter(); } +TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) +{ + this->testCagraFilter(); + this->testCagraRemoved(); +} INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index 3243e73ccd..c5b1b1704b 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -27,7 +27,11 @@ typedef AnnCagraSortTest AnnCagraSortTestU8_ TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); } typedef AnnCagraFilterTest AnnCagraFilterTestU8_U32; -TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) { this->testCagraFilter(); } +TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) +{ + this->testCagraFilter(); + this->testCagraRemoved(); +} INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs));