diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 41e9fda701..62e46e3ae1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -20,6 +20,7 @@ #include // RAFT_WEAK_FUNCTION #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit +#include // NoneSampleFilter #include // raft::neighbors::ivf_pq::codebook_gen #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -36,6 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, uint32_t n_probes, uint32_t pq_dim, uint32_t n_queries, + uint32_t queries_offset, distance::DistanceType metric, codebook_gen codebook_kind, uint32_t topk, @@ -95,6 +100,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, + SampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) RAFT_EXPLICIT; @@ -113,7 +119,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. 
*/ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -123,62 +129,78 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected RAFT_EXPLICIT; + uint32_t topk) -> selected RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_pq::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - extern template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + 
uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); 
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index bc899c7ca7..37174f54e1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -19,6 +19,7 @@ #include // raft::distance::DistanceType #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t +#include // NoneSampleFilter #include // codebook_gen #include // RAFT_CUDA_TRY #include // raft::atomicMin @@ -200,6 +201,9 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * @param pq_dim * The dimensionality of an encoded vector after compression by PQ. * @param n_queries the number of queries. + * @param queries_offset + * An offset of the current query batch. It is used for feeding sample_filter with the + * correct query index. * @param metric the distance type. * @param codebook_kind Defines the way PQ codebooks have been trained. * @param topk the `k` in the select top-k. 
@@ -221,6 +225,12 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * @param index_list * An optional device pointer to the enforced order of search [n_queries, n_probes]. * One can pass reordered indices here to try to improve data reading locality. + * @param query_kths + * query_kths keep the current state of the filtering - atomically updated distances to the + * k-th closest neighbors for each query [n_queries]. + * @param sample_filter + * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. * Ignored when `EnableSMemLut == true`. @@ -236,6 +246,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, */ template ( pq_dim, reinterpret_cast(pq_thread_data), @@ -479,22 +493,27 @@ __global__ void compute_similarity_kernel(uint32_t n_rows, } // The signature of the kernel defined by a minimal set of template parameters -template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); // The config struct lifts the runtime parameters to the template parameters -template +template struct compute_similarity_kernel_config { public: - static auto get(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t + static auto get(uint32_t pq_bits, uint32_t k_max) + -> compute_similarity_kernel_t { return kernel_choose_bits(pq_bits, k_max); } private: static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { switch (pq_bits) { case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); @@ -507,7 +526,8 @@ struct compute_similarity_kernel_config { } template - static auto kernel_try_capacity(uint32_t k_max) -> compute_similarity_kernel_t + static auto kernel_try_capacity(uint32_t k_max) + -> compute_similarity_kernel_t { if constexpr 
(Capacity > 0) { if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } @@ -515,23 +535,36 @@ struct compute_similarity_kernel_config { if constexpr (Capacity > 1) { if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } } - return compute_similarity_kernel; + return compute_similarity_kernel; } }; // A standalone accessor function was necessary to make sure template // instantiation work correctly. This accessor function is not used anymore and // may be removed. -template +template auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { - return compute_similarity_kernel_config::get(pq_bits, - k_max); + return compute_similarity_kernel_config::get(pq_bits, k_max); } /** Estimate the occupancy for the given kernel on the given device. */ -template +template struct occupancy_t { using shmem_unit = Pow2<128>; @@ -542,7 +575,7 @@ struct occupancy_t { inline occupancy_t() = default; inline occupancy_t(size_t smem, uint32_t n_threads, - compute_similarity_kernel_t kernel, + compute_similarity_kernel_t kernel, const cudaDeviceProp& dev_props) { RAFT_CUDA_TRY( @@ -553,23 +586,24 @@ struct occupancy_t { } }; -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, uint32_t n_probes, uint32_t pq_dim, uint32_t n_queries, + uint32_t queries_offset, distance::DistanceType metric, codebook_gen codebook_kind, uint32_t topk, @@ -582,6 +616,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, + SampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) @@ -591,6 +626,7 @@ void 
compute_similarity_run(selected s, n_probes, pq_dim, n_queries, + queries_offset, metric, codebook_kind, topk, @@ -603,6 +639,7 @@ void compute_similarity_run(selected s, queries, index_list, query_kths, + sample_filter, lut_scores, _out_scores, _out_indices); @@ -623,7 +660,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -633,7 +670,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected + uint32_t topk) -> selected { // Shared memory for storing the lookup table size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); @@ -705,9 +742,9 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further optimize occupancy and data locality for the L1 cache. */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; auto topk_or_zero = manage_local_topk ? 
topk : 0u; std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), @@ -716,8 +753,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // we may allow slightly lower than 100% occupancy; constexpr double kTargetOccupancy = 0.75; // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; + occupancy_t selected_perf{}; + selected selected_config; for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { if (smem_size_const > dev_props.sharedMemPerBlockOptin) { // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. @@ -753,7 +790,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, continue; } - occupancy_t cur(smem_size, n_threads, kernel, dev_props); + occupancy_t cur(smem_size, n_threads, kernel, dev_props); if (cur.blocks_per_sm <= 0) { // For some reason, we still cannot make this kernel run. Skip the candidate. continue; @@ -768,7 +805,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, if (n_threads_tmp < n_threads) { while (n_threads_tmp >= n_threads_min) { auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp(smem_size_tmp, n_threads_tmp, kernel, dev_props); + occupancy_t tmp( + smem_size_tmp, n_threads_tmp, kernel, dev_props); bool select_it = false; if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { // Normally, the smaller the block the better for L1 cache hit rate. 
diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 149ea52b6a..d402a2436b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -414,19 +415,21 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, * 3. split the query batch into smaller chunks, so that the device workspace * is guaranteed to fit into GPU memory. */ -template +template void ivfpq_search_worker(raft::resources const& handle, const index& index, uint32_t max_samples, uint32_t n_probes, uint32_t topK, uint32_t n_queries, + uint32_t queries_offset, // needed for filtering const uint32_t* clusters_to_probe, // [n_queries, n_probes] const float* query, // [n_queries, rot_dim] IdxT* neighbors, // [n_queries, topK] float* distances, // [n_queries, topK] float scaling_factor, double preferred_shmem_carveout, + SampleFilterT sample_filter, rmm::mr::device_memory_resource* mr) { auto stream = resource::get_cuda_stream(handle); @@ -529,16 +532,16 @@ void ivfpq_search_worker(raft::resources const& handle, } auto search_instance = - compute_similarity_select(resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + compute_similarity_select(resource::get_device_properties(handle), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -558,6 +561,7 @@ void ivfpq_search_worker(raft::resources const& handle, n_probes, index.pq_dim(), n_queries, + queries_offset, index.metric(), index.codebook_kind(), topK, @@ -570,6 
+574,7 @@ void ivfpq_search_worker(raft::resources const& handle, query, index_list_sorted, query_kths, + sample_filter, device_lut.data(), distances_buf.data(), neighbors_ptr); @@ -605,10 +610,10 @@ void ivfpq_search_worker(raft::resources const& handle, * This structure helps selecting a proper instance of the worker search function, * which contains a few template parameters. */ -template +template struct ivfpq_search { public: - using fun_t = decltype(&ivfpq_search_worker); + using fun_t = decltype(&ivfpq_search_worker); /** * Select an instance of the ivf-pq search function based on search tuning parameters, @@ -624,7 +629,7 @@ struct ivfpq_search { static auto filter_reasonable_instances(const search_params& params) -> fun_t { if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; + return ivfpq_search_worker; } else { RAFT_FAIL( "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " @@ -712,7 +717,7 @@ inline auto get_max_batch_size(uint32_t k, } /** See raft::spatial::knn::ivf_pq::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -721,7 +726,8 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported element type."); @@ -781,7 +787,7 @@ inline void search(raft::resources const& handle, rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - auto search_instance = ivfpq_search::fun(params, index.metric()); + auto search_instance = ivfpq_search::fun(params, index.metric()); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = 
min(max_queries, n_queries - offset_q); @@ -830,12 +836,14 @@ inline void search(raft::resources const& handle, n_probes, k, batch_size, + offset_q + offset_b, clusters_to_probe.data() + uint64_t(n_probes) * offset_b, rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, neighbors + uint64_t(k) * (offset_q + offset_b), distances + uint64_t(k) * (offset_q + offset_b), utils::config::kDivisor / utils::config::kDivisor, params.preferred_shmem_carveout, + sample_filter, mr); } } diff --git a/cpp/include/raft/neighbors/detail/sample_filter.cuh b/cpp/include/raft/neighbors/detail/sample_filter.cuh new file mode 100644 index 0000000000..f5c3d91afe --- /dev/null +++ b/cpp/include/raft/neighbors/detail/sample_filter.cuh @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::ivf_pq::detail { + +/* A filter that filters nothing. This is the default behavior. 
*/ +struct NoneSampleFilter { + inline __device__ __host__ bool operator()( + // query index + const uint32_t query_ix, + // the current inverted list index + const uint32_t cluster_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + return true; + } +}; + +/** + * If the filtering depends on the index of a sample, then the following + * filter template can be used: + * + * template + * struct IndexSampleFilter { + * using index_type = IdxT; + * + * const index_type* const* inds_ptr = nullptr; + * + * IndexSampleFilter() {} + * IndexSampleFilter(const index_type* const* _inds_ptr) + * : inds_ptr{_inds_ptr} {} + * IndexSampleFilter(const IndexSampleFilter&) = default; + * IndexSampleFilter(IndexSampleFilter&&) = default; + * IndexSampleFilter& operator=(const IndexSampleFilter&) = default; + * IndexSampleFilter& operator=(IndexSampleFilter&&) = default; + * + * inline __device__ __host__ bool operator()( + * const uint32_t query_ix, + * const uint32_t cluster_ix, + * const uint32_t sample_ix) const { + * index_type database_idx = inds_ptr[cluster_ix][sample_ix]; + * + * // return true or false, depending on the database_idx + * return true; + * } + * }; + * + * Initialize it as: + * using filter_type = IndexSampleFilter; + * filter_type filter(raft_ivfpq_index.inds_ptrs().data_handle()); + * + * Use it as: + * raft::neighbors::ivf_pq::search_with_filtering( + * ...regular parameters here..., + * filter + * ); + * + * Another example would be the following filter that greenlights samples according + * to a contiguous bit mask vector. 
+ * + * template + * struct BitMaskSampleFilter { + * using index_type = IdxT; + * + * const index_type* const* inds_ptr = nullptr; + * const uint64_t* const bit_mask_ptr = nullptr; + * const int64_t bit_mask_stride_64 = 0; + * + * BitMaskSampleFilter() {} + * BitMaskSampleFilter( + * const index_type* const* _inds_ptr, + * const uint64_t* const _bit_mask_ptr, + * const int64_t _bit_mask_stride_64) + * : inds_ptr{_inds_ptr}, + * bit_mask_ptr{_bit_mask_ptr}, + * bit_mask_stride_64{_bit_mask_stride_64} {} + * BitMaskSampleFilter(const BitMaskSampleFilter&) = default; + * BitMaskSampleFilter(BitMaskSampleFilter&&) = default; + * BitMaskSampleFilter& operator=(const BitMaskSampleFilter&) = default; + * BitMaskSampleFilter& operator=(BitMaskSampleFilter&&) = default; + * + * inline __device__ __host__ bool operator()( + * const uint32_t query_ix, + * const uint32_t cluster_ix, + * const uint32_t sample_ix) const { + * const index_type database_idx = inds_ptr[cluster_ix][sample_ix]; + * const uint64_t bit_mask_element = + * bit_mask_ptr[query_ix * bit_mask_stride_64 + database_idx / 64]; + * const uint64_t masked_bool = + * bit_mask_element & (1ULL << (uint64_t)(database_idx % 64)); + * const bool is_bit_set = (masked_bool != 0); + * + * return is_bit_set; + * } + * }; + */ +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index 42dc776c97..f203709b1b 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -45,6 +45,15 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* idx) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + SampleFilterT sample_filter) RAFT_EXPLICIT; + template void 
search(raft::resources const& handle, const search_params& params, @@ -74,6 +83,18 @@ void extend(raft::resources const& handle, const IdxT* new_indices, IdxT n_rows) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const raft::neighbors::ivf_pq::search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const raft::neighbors::ivf_pq::search_params& params, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index 83e7931c78..e2e60f0cd3 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -133,7 +133,7 @@ void extend(raft::resources const& handle, } /** - * @brief Search ANN using the constructed index. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. * @@ -156,14 +156,16 @@ void extend(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter a filter that greenlights samples for a given query. 
*/ -template -void search(raft::resources const& handle, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + SampleFilterT sample_filter = SampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -176,15 +178,53 @@ void search(raft::resources const& handle, "Number of query dimensions should equal number of dimensions in the index."); std::uint32_t k = neighbors.extent(1); - return detail::search(handle, - params, - idx, - queries.data_handle(), - static_cast(queries.extent(0)), - k, - neighbors.data_handle(), - distances.data_handle(), - resource::get_workspace_resource(handle)); + detail::search(handle, + params, + idx, + queries.data_handle(), + static_cast(queries.extent(0)), + k, + neighbors.data_handle(), + distances.data_handle(), + resource::get_workspace_resource(handle), + sample_filter); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. 
However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + search_with_filtering( + handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter()); } /** @} */ // end group ivf_pq @@ -297,6 +337,22 @@ void extend(raft::resources const& handle, detail::extend(handle, idx, new_vectors, new_indices, n_rows); } +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) +{ + detail::search( + handle, params, idx, queries, n_queries, k, neighbors, distances, mr, sample_filter); +} + /** * @brief Search ANN using the constructed index. 
* @@ -350,7 +406,7 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); + detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } } // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index a740d01bd2..ac547626bb 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -41,8 +41,8 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, SampleFilterT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ const cudaDeviceProp& dev_props, \\ bool manage_local_topk, \\ int locality_hint, \\ @@ -52,16 +52,17 @@ uint32_t precomp_data_count, \\ uint32_t n_queries, \\ uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ \\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ rmm::cuda_stream_view stream, \\ uint32_t n_rows, \\ uint32_t dim, \\ uint32_t n_probes, \\ uint32_t pq_dim, \\ uint32_t n_queries, \\ + uint32_t queries_offset, \\ raft::distance::DistanceType metric, \\ raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\ uint32_t topk, \\ @@ -74,6 +75,7 @@ const float* queries, \\ const uint32_t* index_list, \\ 
float* query_kths, \\ + SampleFilterT sample_filter, \\ LutT* lut_scores, \\ OutT* _out_scores, \\ uint32_t* _out_indices); @@ -102,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT});\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::ivf_pq::detail::NoneSampleFilter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 956b7010d5..67b67df19f 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -27,46 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* 
queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index fba72ad1dd..1c97a1c9ba 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ 
b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ 
+ uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 030f429315..14e2d19fe7 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t 
pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA 
true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 31a4d7d503..7fd3a8d0b2 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -27,46 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + 
uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index c623c80446..01df4d87e3 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t 
topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT 
sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index f2aaca20db..251515a552 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -27,47 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define 
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index 4420b2534b..b29f4bca96 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -27,46 +27,52 @@ #include #include -#define 
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, SampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + 
raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + SampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); #undef COMMA