From cdf107bbd415e756f87a3eee605bc35cdf81e113 Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Fri, 19 May 2023 22:37:48 +0000 Subject: [PATCH] Introduce sample filtering to IVFPQ index search (#1513) A prototype that introduces a per-sample filtering for IVFPQ search. Please feel free to use it as a foundation for the future change, if appropriate, because the code is functional, but is not super clean-and-neat. The diff introduces a template parameter called `SampleFilterT`. An instance is expected * to be `SampleFilterT()` constructible (which was mostly needed to define a default behavior in the form of `SampleFilterT sample_filter=SampleFilterT()`, see below) * to provide a `inline __device__ bool operator(...)` that returns `true` is a given sample is valid for being used against a given query in IVFPQ search The default filter (that I set as a default one in certain facilities in the form of `typename SampleFilterT = NoneSampleFilter` in order not to modify way to many files) allows all samples to be used: ``` struct NoneSampleFilter { inline __device__ __host__ bool operator()( // query index const uint32_t query_ix, // the current inverted list index const uint32_t cluster_ix, // the index of the current sample inside the current inverted list const uint32_t sample_ix ) const { return true; } }; ``` Here `__host__` is needed for a CPU-based testing only. Also, I've provided an implementation of `BitMaskSampleFilter` that allows to filter samples based on a bit mask, as an example. The implementation was tested in the semi-production environment. All the filter-related files were added to `cpp/include/raft/neighbors/detail/sample_filter.cuh`. I did not change the default `ivf_pq_search()` method remains unchanged, but one more `ivf_pq_search_with_filtering()` method with an additional template argument `SampleFilterT` and one more input parameter was introduced. ``` template void search_with_filtering(raft::device_resources const& handle, const raft::neighbors::ivf_pq::search_params& params, const index& idx, const T* queries, uint32_t n_queries, uint32_t k, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, SampleFilterT sample_filter = SampleFilterT()); ``` All the current instantiations use `NoneSampleFilter` only. I've used `SampleFilterT sample_filter` parameters passing instead of `const SampleFilterT sample_filter` in the function calls in order to be able to add some debugging facilities to a filter and with the hope that the compiler is smart enough to understand the de-facto constness if needed. The filter does not take a computed distance score into account by design, thus the current implementation cannot have a distance threshold. This can be easily changed, if appropriate. It is still questionable to me whether this filtering needs to be injected right inside the search kernel instead of doing post-processing, please let me know if you have any thoughts on the topic. I'm happy to address the comments. Thanks. Authors: - Alexander Guzhva (https://github.com/alexanderguzhva) - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1513 --- .../detail/ivf_pq_compute_similarity-ext.cuh | 124 +++++++++++------- .../detail/ivf_pq_compute_similarity-inl.cuh | 92 +++++++++---- .../raft/neighbors/detail/ivf_pq_search.cuh | 42 +++--- .../raft/neighbors/detail/sample_filter.cuh | 116 ++++++++++++++++ cpp/include/raft/neighbors/ivf_pq-ext.cuh | 21 +++ cpp/include/raft/neighbors/ivf_pq-inl.cuh | 92 ++++++++++--- .../ivf_pq_compute_similarity_00_generate.py | 14 +- .../ivf_pq_compute_similarity_float_float.cu | 80 +++++------ ...f_pq_compute_similarity_float_fp8_false.cu | 81 ++++++------ ...vf_pq_compute_similarity_float_fp8_true.cu | 81 ++++++------ .../ivf_pq_compute_similarity_float_half.cu | 80 +++++------ ...vf_pq_compute_similarity_half_fp8_false.cu | 81 ++++++------ ...ivf_pq_compute_similarity_half_fp8_true.cu | 81 ++++++------ .../ivf_pq_compute_similarity_half_half.cu | 80 +++++------ 14 files changed, 687 insertions(+), 378 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/sample_filter.cuh diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 41e9fda701..62e46e3ae1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -20,6 +20,7 @@ #include // RAFT_WEAK_FUNCTION #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit +#include // NoneSampleFilter #include // raft::neighbors::ivf_pq::codebook_gen #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -36,6 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, uint32_t n_probes, uint32_t pq_dim, uint32_t n_queries, + uint32_t queries_offset, distance::DistanceType metric, codebook_gen codebook_kind, uint32_t topk, @@ -95,6 +100,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, + SampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) RAFT_EXPLICIT; @@ -113,7 +119,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -123,62 +129,78 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected RAFT_EXPLICIT; + uint32_t topk) -> selected RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_pq::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - extern template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index bc899c7ca7..37174f54e1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -19,6 +19,7 @@ #include // raft::distance::DistanceType #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t +#include // NoneSampleFilter #include // codebook_gen #include // RAFT_CUDA_TRY #include // raft::atomicMin @@ -200,6 +201,9 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * @param pq_dim * The dimensionality of an encoded vector after compression by PQ. * @param n_queries the number of queries. + * @param queries_offset + * An offset of the current query batch. It is used for feeding sample_filter with the + * correct query index. * @param metric the distance type. * @param codebook_kind Defines the way PQ codebooks have been trained. * @param topk the `k` in the select top-k. @@ -221,6 +225,12 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * @param index_list * An optional device pointer to the enforced order of search [n_queries, n_probes]. * One can pass reordered indices here to try to improve data reading locality. + * @param query_kth + * query_kths keep the current state of the filtering - atomically updated distances to the + * k-th closest neighbors for each query [n_queries]. + * @param sample_filter + * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. * Ignored when `EnableSMemLut == true`. @@ -236,6 +246,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, */ template ( pq_dim, reinterpret_cast(pq_thread_data), @@ -479,22 +493,27 @@ __global__ void compute_similarity_kernel(uint32_t n_rows, } // The signature of the kernel defined by a minimal set of template parameters -template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); // The config struct lifts the runtime parameters to the template parameters -template +template struct compute_similarity_kernel_config { public: - static auto get(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t + static auto get(uint32_t pq_bits, uint32_t k_max) + -> compute_similarity_kernel_t { return kernel_choose_bits(pq_bits, k_max); } private: static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { switch (pq_bits) { case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); @@ -507,7 +526,8 @@ struct compute_similarity_kernel_config { } template - static auto kernel_try_capacity(uint32_t k_max) -> compute_similarity_kernel_t + static auto kernel_try_capacity(uint32_t k_max) + -> compute_similarity_kernel_t { if constexpr (Capacity > 0) { if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } @@ -515,23 +535,36 @@ struct compute_similarity_kernel_config { if constexpr (Capacity > 1) { if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } } - return compute_similarity_kernel; + return compute_similarity_kernel; } }; // A standalone accessor function was necessary to make sure template // instantiation work correctly. This accessor function is not used anymore and // may be removed. -template +template auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { - return compute_similarity_kernel_config::get(pq_bits, - k_max); + return compute_similarity_kernel_config::get(pq_bits, k_max); } /** Estimate the occupancy for the given kernel on the given device. */ -template +template struct occupancy_t { using shmem_unit = Pow2<128>; @@ -542,7 +575,7 @@ struct occupancy_t { inline occupancy_t() = default; inline occupancy_t(size_t smem, uint32_t n_threads, - compute_similarity_kernel_t kernel, + compute_similarity_kernel_t kernel, const cudaDeviceProp& dev_props) { RAFT_CUDA_TRY( @@ -553,23 +586,24 @@ struct occupancy_t { } }; -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, uint32_t n_probes, uint32_t pq_dim, uint32_t n_queries, + uint32_t queries_offset, distance::DistanceType metric, codebook_gen codebook_kind, uint32_t topk, @@ -582,6 +616,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, + SampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) @@ -591,6 +626,7 @@ void compute_similarity_run(selected s, n_probes, pq_dim, n_queries, + queries_offset, metric, codebook_kind, topk, @@ -603,6 +639,7 @@ void compute_similarity_run(selected s, queries, index_list, query_kths, + sample_filter, lut_scores, _out_scores, _out_indices); @@ -623,7 +660,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -633,7 +670,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected + uint32_t topk) -> selected { // Shared memory for storing the lookup table size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); @@ -705,9 +742,9 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further optimize occupancy and data locality for the L1 cache. */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; auto topk_or_zero = manage_local_topk ? topk : 0u; std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), @@ -716,8 +753,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // we may allow slightly lower than 100% occupancy; constexpr double kTargetOccupancy = 0.75; // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; + occupancy_t selected_perf{}; + selected selected_config; for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { if (smem_size_const > dev_props.sharedMemPerBlockOptin) { // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. @@ -753,7 +790,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, continue; } - occupancy_t cur(smem_size, n_threads, kernel, dev_props); + occupancy_t cur(smem_size, n_threads, kernel, dev_props); if (cur.blocks_per_sm <= 0) { // For some reason, we still cannot make this kernel run. Skip the candidate. continue; @@ -768,7 +805,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, if (n_threads_tmp < n_threads) { while (n_threads_tmp >= n_threads_min) { auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp(smem_size_tmp, n_threads_tmp, kernel, dev_props); + occupancy_t tmp( + smem_size_tmp, n_threads_tmp, kernel, dev_props); bool select_it = false; if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { // Normally, the smaller the block the better for L1 cache hit rate. diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 149ea52b6a..d402a2436b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -414,19 +415,21 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, * 3. split the query batch into smaller chunks, so that the device workspace * is guaranteed to fit into GPU memory. */ -template +template void ivfpq_search_worker(raft::resources const& handle, const index& index, uint32_t max_samples, uint32_t n_probes, uint32_t topK, uint32_t n_queries, + uint32_t queries_offset, // needed for filtering const uint32_t* clusters_to_probe, // [n_queries, n_probes] const float* query, // [n_queries, rot_dim] IdxT* neighbors, // [n_queries, topK] float* distances, // [n_queries, topK] float scaling_factor, double preferred_shmem_carveout, + SampleFilterT sample_filter, rmm::mr::device_memory_resource* mr) { auto stream = resource::get_cuda_stream(handle); @@ -529,16 +532,16 @@ void ivfpq_search_worker(raft::resources const& handle, } auto search_instance = - compute_similarity_select(resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + compute_similarity_select(resource::get_device_properties(handle), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -558,6 +561,7 @@ void ivfpq_search_worker(raft::resources const& handle, n_probes, index.pq_dim(), n_queries, + queries_offset, index.metric(), index.codebook_kind(), topK, @@ -570,6 +574,7 @@ void ivfpq_search_worker(raft::resources const& handle, query, index_list_sorted, query_kths, + sample_filter, device_lut.data(), distances_buf.data(), neighbors_ptr); @@ -605,10 +610,10 @@ void ivfpq_search_worker(raft::resources const& handle, * This structure helps selecting a proper instance of the worker search function, * which contains a few template parameters. */ -template +template struct ivfpq_search { public: - using fun_t = decltype(&ivfpq_search_worker); + using fun_t = decltype(&ivfpq_search_worker); /** * Select an instance of the ivf-pq search function based on search tuning parameters, @@ -624,7 +629,7 @@ struct ivfpq_search { static auto filter_reasonable_instances(const search_params& params) -> fun_t { if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; + return ivfpq_search_worker; } else { RAFT_FAIL( "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " @@ -712,7 +717,7 @@ inline auto get_max_batch_size(uint32_t k, } /** See raft::spatial::knn::ivf_pq::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -721,7 +726,8 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported element type."); @@ -781,7 +787,7 @@ inline void search(raft::resources const& handle, rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - auto search_instance = ivfpq_search::fun(params, index.metric()); + auto search_instance = ivfpq_search::fun(params, index.metric()); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); @@ -830,12 +836,14 @@ inline void search(raft::resources const& handle, n_probes, k, batch_size, + offset_q + offset_b, clusters_to_probe.data() + uint64_t(n_probes) * offset_b, rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, neighbors + uint64_t(k) * (offset_q + offset_b), distances + uint64_t(k) * (offset_q + offset_b), utils::config::kDivisor / utils::config::kDivisor, params.preferred_shmem_carveout, + sample_filter, mr); } } diff --git a/cpp/include/raft/neighbors/detail/sample_filter.cuh b/cpp/include/raft/neighbors/detail/sample_filter.cuh new file mode 100644 index 0000000000..f5c3d91afe --- /dev/null +++ b/cpp/include/raft/neighbors/detail/sample_filter.cuh @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::ivf_pq::detail { + +/* A filter that filters nothing. This is the default behavior. */ +struct NoneSampleFilter { + inline __device__ __host__ bool operator()( + // query index + const uint32_t query_ix, + // the current inverted list index + const uint32_t cluster_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + return true; + } +}; + +/** + * If the filtering depends on the index of a sample, then the following + * filter template can be used: + * + * template + * struct IndexSampleFilter { + * using index_type = IdxT; + * + * const index_type* const* inds_ptr = nullptr; + * + * IndexSampleFilter() {} + * IndexSampleFilter(const index_type* const* _inds_ptr) + * : inds_ptr{_inds_ptr} {} + * IndexSampleFilter(const IndexSampleFilter&) = default; + * IndexSampleFilter(IndexSampleFilter&&) = default; + * IndexSampleFilter& operator=(const IndexSampleFilter&) = default; + * IndexSampleFilter& operator=(IndexSampleFilter&&) = default; + * + * inline __device__ __host__ bool operator()( + * const uint32_t query_ix, + * const uint32_t cluster_ix, + * const uint32_t sample_ix) const { + * index_type database_idx = inds_ptr[cluster_ix][sample_ix]; + * + * // return true or false, depending on the database_idx + * return true; + * } + * }; + * + * Initialize it as: + * using filter_type = IndexSampleFilter; + * filter_type filter(raft_ivfpq_index.inds_ptrs().data_handle()); + * + * Use it as: + * raft::neighbors::ivf_pq::search_with_filtering( + * ...regular parameters here..., + * filter + * ); + * + * Another example would be the following filter that greenlights samples according + * to a contiguous bit mask vector. + * + * template + * struct BitMaskSampleFilter { + * using index_type = IdxT; + * + * const index_type* const* inds_ptr = nullptr; + * const uint64_t* const bit_mask_ptr = nullptr; + * const int64_t bit_mask_stride_64 = 0; + * + * BitMaskSampleFilter() {} + * BitMaskSampleFilter( + * const index_type* const* _inds_ptr, + * const uint64_t* const _bit_mask_ptr, + * const int64_t _bit_mask_stride_64) + * : inds_ptr{_inds_ptr}, + * bit_mask_ptr{_bit_mask_ptr}, + * bit_mask_stride_64{_bit_mask_stride_64} {} + * BitMaskSampleFilter(const BitMaskSampleFilter&) = default; + * BitMaskSampleFilter(BitMaskSampleFilter&&) = default; + * BitMaskSampleFilter& operator=(const BitMaskSampleFilter&) = default; + * BitMaskSampleFilter& operator=(BitMaskSampleFilter&&) = default; + * + * inline __device__ __host__ bool operator()( + * const uint32_t query_ix, + * const uint32_t cluster_ix, + * const uint32_t sample_ix) const { + * const index_type database_idx = inds_ptr[cluster_ix][sample_ix]; + * const uint64_t bit_mask_element = + * bit_mask_ptr[query_ix * bit_mask_stride_64 + database_idx / 64]; + * const uint64_t masked_bool = + * bit_mask_element & (1ULL << (uint64_t)(database_idx % 64)); + * const bool is_bit_set = (masked_bool != 0); + * + * return is_bit_set; + * } + * }; + */ +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index 42dc776c97..f203709b1b 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -45,6 +45,15 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* idx) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + SampleFilterT sample_filter) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, @@ -74,6 +83,18 @@ void extend(raft::resources const& handle, const IdxT* new_indices, IdxT n_rows) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const raft::neighbors::ivf_pq::search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const raft::neighbors::ivf_pq::search_params& params, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index 83e7931c78..e2e60f0cd3 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -133,7 +133,7 @@ void extend(raft::resources const& handle, } /** - * @brief Search ANN using the constructed index. + * @brief Search ANN using the constructed index using the given filter. * * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. * @@ -156,14 +156,16 @@ void extend(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter a filter the greenlights samples for a given query. */ -template -void search(raft::resources const& handle, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + SampleFilterT sample_filter = SampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -176,15 +178,53 @@ void search(raft::resources const& handle, "Number of query dimensions should equal number of dimensions in the index."); std::uint32_t k = neighbors.extent(1); - return detail::search(handle, - params, - idx, - queries.data_handle(), - static_cast(queries.extent(0)), - k, - neighbors.data_handle(), - distances.data_handle(), - resource::get_workspace_resource(handle)); + detail::search(handle, + params, + idx, + queries.data_handle(), + static_cast(queries.extent(0)), + k, + neighbors.data_handle(), + distances.data_handle(), + resource::get_workspace_resource(handle), + sample_filter); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + search_with_filtering( + handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter()); } /** @} */ // end group ivf_pq @@ -297,6 +337,22 @@ void extend(raft::resources const& handle, detail::extend(handle, idx, new_vectors, new_indices, n_rows); } +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) +{ + detail::search( + handle, params, idx, queries, n_queries, k, neighbors, distances, mr, sample_filter); +} + /** * @brief Search ANN using the constructed index. * @@ -350,7 +406,7 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); + detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } } // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index a740d01bd2..ac547626bb 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -41,8 +41,8 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, SampleFilterT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ const cudaDeviceProp& dev_props, \\ bool manage_local_topk, \\ int locality_hint, \\ @@ -52,16 +52,17 @@ uint32_t precomp_data_count, \\ uint32_t n_queries, \\ uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ \\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ rmm::cuda_stream_view stream, \\ uint32_t n_rows, \\ uint32_t dim, \\ uint32_t n_probes, \\ uint32_t pq_dim, \\ uint32_t n_queries, \\ + uint32_t queries_offset, \\ raft::distance::DistanceType metric, \\ raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\ uint32_t topk, \\ @@ -74,6 +75,7 @@ const float* queries, \\ const uint32_t* index_list, \\ float* query_kths, \\ + SampleFilterT sample_filter, \\ LutT* lut_scores, \\ OutT* _out_scores, \\ uint32_t* _out_indices); @@ -102,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT});\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::ivf_pq::detail::NoneSampleFilter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 956b7010d5..67b67df19f 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -27,46 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index fba72ad1dd..1c97a1c9ba 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 030f429315..14e2d19fe7 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 31a4d7d503..7fd3a8d0b2 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -27,46 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index c623c80446..01df4d87e3 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index f2aaca20db..251515a552 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index 4420b2534b..b29f4bca96 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -27,46 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA