diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 46f72c4005..47f3e8888c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -16,24 +16,27 @@ #pragma once -#include // uintX_t -#include // raft::neighbors::ivf_flat::index -#include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::ivf_flat::detail { -template +template void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, + const uint32_t queries_offset, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, const bool select_min, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -43,23 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + extern template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 4eed2aa453..18f1906dc5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -646,6 +646,7 @@ struct loadAndComputeDist { * @param n_probes * @param k * @param dim + * @param sample_filter * @param[out] neighbors * @param[out] distances */ @@ -655,6 +656,7 @@ template __global__ void __launch_bounds__(kThreadsPerBlock) @@ -666,9 +668,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, + const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, const uint32_t dim, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances) { @@ -736,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool valid = vec_id < list_length; // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { + if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) { loadAndComputeDist lc(dist, compute_dist); for (int pos = 0; pos < shm_assisted_dim; @@ -803,6 +807,7 @@ template void launch_kernel(Lambda lambda, @@ -811,8 +816,10 @@ void launch_kernel(Lambda lambda, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, + const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -820,8 +827,15 @@ void launch_kernel(Lambda lambda, { RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = - interleaved_scan_kernel; + constexpr auto kKernel = interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); @@ -860,9 +874,11 @@ void launch_kernel(Lambda lambda, index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), + queries_offset + query_offset, n_probes, k, index.dim(), + sample_filter, neighbors, distances); queries += grid_dim_y * index.dim(); @@ -931,6 +947,7 @@ template void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) { @@ -943,6 +960,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + IvfSampleFilterT, euclidean_dist, raft::identity_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::L2SqrtExpanded: @@ -953,6 +971,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + IvfSampleFilterT, euclidean_dist, raft::sqrt_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::InnerProduct: @@ -962,6 +981,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + IvfSampleFilterT, inner_prod_dist, raft::identity_op>({}, {}, std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. @@ -976,6 +996,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg template (1, 16 / sizeof(T))> struct select_interleaved_scan_kernel { @@ -990,13 +1011,20 @@ struct select_interleaved_scan_kernel { { if constexpr (Capacity > 1) { if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); + return select_interleaved_scan_kernel::run(capacity, + veclen, + select_min, + std::forward(args)...); } } if constexpr (Veclen > 1) { if (veclen % Veclen != 0) { - return select_interleaved_scan_kernel::run( + return select_interleaved_scan_kernel::run( capacity, 1, select_min, std::forward(args)...); } } @@ -1010,9 +1038,11 @@ struct select_interleaved_scan_kernel { veclen == Veclen, "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); if (select_min) { - launch_with_fixed_consts(std::forward(args)...); + launch_with_fixed_consts( + std::forward(args)...); } else { - launch_with_fixed_consts(std::forward(args)...); + launch_with_fixed_consts( + std::forward(args)...); } } }; @@ -1028,6 +1058,9 @@ struct select_interleaved_scan_kernel { * @param[in] queries device pointer to the query vectors [batch_size, dim] * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] * @param n_queries batch size + * @param[in] queries_offset + * An offset of the current query batch. It is used for feeding sample_filter with the + * correct query index. * @param metric type of the measured distance * @param n_probes number of nearest clusters to query * @param k number of nearest neighbors. @@ -1041,36 +1074,43 @@ struct select_interleaved_scan_kernel { * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) * @param stream + * @param sample_filter + * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to + * provide a green light for every sample. */ -template +template void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, + const uint32_t queries_offset, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, const bool select_min, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - n_probes, - k, - neighbors, - distances, - grid_dim_x, - stream); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + queries_offset, + n_probes, + k, + sample_filter, + neighbors, + distances, + grid_dim_x, + stream); } } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index b97e64a259..976d15a61c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -16,15 +16,16 @@ #pragma once -#include // uintX_t -#include // raft::neighbors::ivf_flat::index -#include // RAFT_EXPLICIT +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::ivf_flat::detail { -template +template void search(raft::resources const& handle, const search_params& params, const raft::neighbors::ivf_flat::index& index, @@ -33,26 +34,31 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) - -instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + IvfSampleFilterT sample_filter) + +instantiate_raft_neighbors_ivf_flat_detail_search( + float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 66ad9682d7..366e9bfcd5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,6 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter #include // utils::mapping #include // rmm::device_memory_resource @@ -33,17 +34,19 @@ namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -template +template void search_impl(raft::resources const& handle, const raft::neighbors::ivf_flat::index& index, const T* queries, uint32_t n_queries, + uint32_t queries_offset, uint32_t k, uint32_t n_probes, bool select_min, IdxT* neighbors, AccT* distances, - rmm::mr::device_memory_resource* search_mr) + rmm::mr::device_memory_resource* search_mr, + IvfSampleFilterT sample_filter) { auto stream = resource::get_cuda_stream(handle); // The norm of query @@ -143,18 +146,21 @@ void search_impl(raft::resources const& handle, uint32_t grid_dim_x = 0; if (n_probes > 1) { // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT>(index, - nullptr, - nullptr, - n_queries, - index.metric(), - n_probes, - k, - select_min, - nullptr, - nullptr, - grid_dim_x, - stream); + ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( + index, + nullptr, + nullptr, + n_queries, + queries_offset, + index.metric(), + n_probes, + k, + select_min, + sample_filter, + nullptr, + nullptr, + grid_dim_x, + stream); } else { grid_dim_x = 1; } @@ -164,18 +170,21 @@ void search_impl(raft::resources const& handle, indices_dev_ptr = neighbors; } - ivfflat_interleaved_scan::value_t, IdxT>(index, - queries, - coarse_indices_dev.data(), - n_queries, - index.metric(), - n_probes, - k, - select_min, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); + ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( + index, + queries, + coarse_indices_dev.data(), + n_queries, + queries_offset, + index.metric(), + n_probes, + k, + select_min, + sample_filter, + indices_dev_ptr, + distances_dev_ptr, + grid_dim_x, + stream); RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); @@ -196,7 +205,9 @@ void search_impl(raft::resources const& handle, } /** See raft::neighbors::ivf_flat::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -205,7 +216,8 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); @@ -230,16 +242,18 @@ inline void search(raft::resources const& handle, for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); - search_impl(handle, - index, - queries + offset_q * index.dim(), - queries_batch, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors + offset_q * k, - distances + offset_q * k, - mr); + search_impl(handle, + index, + queries + offset_q * index.dim(), + queries_batch, + offset_q, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors + offset_q * k, + distances + offset_q * k, + mr, + sample_filter); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 62e46e3ae1..0ae2e23b63 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -20,8 +20,8 @@ #include // RAFT_WEAK_FUNCTION #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit -#include // NoneSampleFilter #include // raft::neighbors::ivf_pq::codebook_gen +#include // none_ivf_sample_filter #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -37,7 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, @@ -100,7 +100,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) RAFT_EXPLICIT; @@ -119,7 +119,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -129,78 +129,79 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected RAFT_EXPLICIT; + uint32_t topk) + -> selected RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_pq::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - extern template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + half, half, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, half, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, float, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 37174f54e1..2fefa900c3 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -19,8 +19,8 @@ #include // raft::distance::DistanceType #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t -#include // NoneSampleFilter #include // codebook_gen +#include // none_ivf_sample_filter #include // RAFT_CUDA_TRY #include // raft::atomicMin #include // raft::Pow2 @@ -229,7 +229,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * query_kths keep the current state of the filtering - atomically updated distances to the * k-th closest neighbors for each query [n_queries]. * @param sample_filter - * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. @@ -246,7 +246,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, */ template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); // The config struct lifts the runtime parameters to the template parameters template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> struct compute_similarity_kernel_config { public: static auto get(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { return kernel_choose_bits(pq_bits, k_max); } private: static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { switch (pq_bits) { case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); @@ -527,7 +529,7 @@ struct compute_similarity_kernel_config { template static auto kernel_try_capacity(uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { if constexpr (Capacity > 0) { if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } @@ -537,7 +539,7 @@ struct compute_similarity_kernel_config { } return compute_similarity_kernel + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { return compute_similarity_kernel_config::get(pq_bits, k_max); + IvfSampleFilterT>::get(pq_bits, k_max); } /** Estimate the occupancy for the given kernel on the given device. */ -template +template struct occupancy_t { using shmem_unit = Pow2<128>; @@ -575,7 +577,7 @@ struct occupancy_t { inline occupancy_t() = default; inline occupancy_t(size_t smem, uint32_t n_threads, - compute_similarity_kernel_t kernel, + compute_similarity_kernel_t kernel, const cudaDeviceProp& dev_props) { RAFT_CUDA_TRY( @@ -586,17 +588,19 @@ struct occupancy_t { } }; -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, @@ -616,7 +620,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) @@ -660,7 +664,9 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -670,7 +676,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected + uint32_t topk) -> selected { // Shared memory for storing the lookup table size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); @@ -742,9 +748,9 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further optimize occupancy and data locality for the L1 cache. */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; auto topk_or_zero = manage_local_topk ? topk : 0u; std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), @@ -753,8 +759,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // we may allow slightly lower than 100% occupancy; constexpr double kTargetOccupancy = 0.75; // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; + occupancy_t selected_perf{}; + selected selected_config; for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { if (smem_size_const > dev_props.sharedMemPerBlockOptin) { // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. @@ -790,7 +796,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, continue; } - occupancy_t cur(smem_size, n_threads, kernel, dev_props); + occupancy_t cur(smem_size, n_threads, kernel, dev_props); if (cur.blocks_per_sm <= 0) { // For some reason, we still cannot make this kernel run. Skip the candidate. continue; @@ -805,7 +811,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, if (n_threads_tmp < n_threads) { while (n_threads_tmp >= n_threads_min) { auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp( + occupancy_t tmp( smem_size_tmp, n_threads_tmp, kernel, dev_props); bool select_it = false; if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d402a2436b..8257f5ed35 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include #include @@ -415,7 +415,7 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, * 3. split the query batch into smaller chunks, so that the device workspace * is guaranteed to fit into GPU memory. */ -template +template void ivfpq_search_worker(raft::resources const& handle, const index& index, uint32_t max_samples, @@ -429,7 +429,7 @@ void ivfpq_search_worker(raft::resources const& handle, float* distances, // [n_queries, topK] float scaling_factor, double preferred_shmem_carveout, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, rmm::mr::device_memory_resource* mr) { auto stream = resource::get_cuda_stream(handle); @@ -531,17 +531,17 @@ void ivfpq_search_worker(raft::resources const& handle, } break; } - auto search_instance = - compute_similarity_select(resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + auto search_instance = compute_similarity_select( + resource::get_device_properties(handle), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -610,10 +610,10 @@ void ivfpq_search_worker(raft::resources const& handle, * This structure helps selecting a proper instance of the worker search function, * which contains a few template parameters. */ -template +template struct ivfpq_search { public: - using fun_t = decltype(&ivfpq_search_worker); + using fun_t = decltype(&ivfpq_search_worker); /** * Select an instance of the ivf-pq search function based on search tuning parameters, @@ -629,7 +629,7 @@ struct ivfpq_search { static auto filter_reasonable_instances(const search_params& params) -> fun_t { if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; + return ivfpq_search_worker; } else { RAFT_FAIL( "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " @@ -717,7 +717,9 @@ inline auto get_max_batch_size(uint32_t k, } /** See raft::spatial::knn::ivf_pq::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -727,7 +729,7 @@ inline void search(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported element type."); @@ -787,7 +789,7 @@ inline void search(raft::resources const& handle, rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - auto search_instance = ivfpq_search::fun(params, index.metric()); + auto search_instance = ivfpq_search::fun(params, index.metric()); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 64f9511ff9..251f725361 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -129,10 +130,12 @@ void refine_device(raft::resources const& handle, queries.data_handle(), fake_coarse_idx.data(), static_cast(n_queries), + 0, refinement_index.metric(), 1, k, raft::distance::is_min_close(metric), + raft::neighbors::filtering::none_ivf_sample_filter(), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index dff7b6b2ab..848703c9b5 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -74,6 +74,18 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* index) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, @@ -85,6 +97,15 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 739e012e08..a18ee065bf 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -357,6 +357,69 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_flat::search_params search_params; + * filtering::none_ivf_sample_filter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) +{ + raft::neighbors::ivf_flat::detail::search( + handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); +} + +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); @@ -394,8 +457,16 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); + raft::neighbors::ivf_flat::detail::search(handle, + params, + index, + queries, + n_queries, + k, + neighbors, + distances, + mr, + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @@ -403,6 +474,74 @@ void search(raft::resources const& handle, * @{ */ +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * filtering::none_ivf_sample_filter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries1, out_inds1, out_dists1, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries2, out_inds2, out_dists2, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries3, out_inds3, out_dists3, filter); + * ... + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + search_with_filtering(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(neighbors.extent(1)), + neighbors.data_handle(), + distances.data_handle(), + resource::get_workspace_resource(handle), + sample_filter); +} + /** * @brief Search ANN using the constructed index. * @@ -443,25 +582,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - nullptr); + search_with_filtering(handle, + params, + index, + queries, + neighbors, + distances, + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index 5b7391569b..1595f55d8c 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -45,14 +45,14 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* idx) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter) RAFT_EXPLICIT; template void search(raft::resources const& handle, @@ -83,7 +83,7 @@ void extend(raft::resources const& handle, const IdxT* new_indices, IdxT n_rows) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const raft::neighbors::ivf_pq::search_params& params, const index& idx, @@ -93,7 +93,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template void search(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index fbe2fcb30d..ad9d95f790 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -158,14 +158,14 @@ void extend(raft::resources const& handle, * k] * @param[in] sample_filter a filter the greenlights samples for a given query. */ -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -223,8 +223,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { - search_with_filtering( - handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter()); + search_with_filtering(handle, + params, + idx, + queries, + neighbors, + distances, + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @} */ // end group ivf_pq @@ -337,7 +342,54 @@ void extend(raft::resources const& handle, detail::extend(handle, idx, new_vectors, new_indices, n_rows); } -template +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_pq::search_params search_params; + * filtering::none_ivf_sample_filter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, @@ -347,7 +399,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { detail::search( handle, params, idx, queries, n_queries, k, neighbors, distances, mr, sample_filter); diff --git a/cpp/include/raft/neighbors/detail/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter_types.hpp similarity index 70% rename from cpp/include/raft/neighbors/detail/sample_filter.cuh rename to cpp/include/raft/neighbors/sample_filter_types.hpp index f5c3d91afe..5a301e9d2f 100644 --- a/cpp/include/raft/neighbors/detail/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -19,11 +19,13 @@ #include #include -namespace raft::neighbors::ivf_pq::detail { +#include + +namespace raft::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ -struct NoneSampleFilter { - inline __device__ __host__ bool operator()( +struct none_ivf_sample_filter { + inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, // the current inverted list index @@ -40,20 +42,20 @@ struct NoneSampleFilter { * filter template can be used: * * template - * struct IndexSampleFilter { + * struct index_ivf_sample_filter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * - * IndexSampleFilter() {} - * IndexSampleFilter(const index_type* const* _inds_ptr) + * index_ivf_sample_filter() {} + * index_ivf_sample_filter(const index_type* const* _inds_ptr) * : inds_ptr{_inds_ptr} {} - * IndexSampleFilter(const IndexSampleFilter&) = default; - * IndexSampleFilter(IndexSampleFilter&&) = default; - * IndexSampleFilter& operator=(const IndexSampleFilter&) = default; - * IndexSampleFilter& operator=(IndexSampleFilter&&) = default; + * index_ivf_sample_filter(const index_ivf_sample_filter&) = default; + * index_ivf_sample_filter(index_ivf_sample_filter&&) = default; + * index_ivf_sample_filter& operator=(const index_ivf_sample_filter&) = default; + * index_ivf_sample_filter& operator=(index_ivf_sample_filter&&) = default; * - * inline __device__ __host__ bool operator()( + * inline _RAFT_HOST_DEVICE bool operator()( * const uint32_t query_ix, * const uint32_t cluster_ix, * const uint32_t sample_ix) const { @@ -65,7 +67,7 @@ struct NoneSampleFilter { * }; * * Initialize it as: - * using filter_type = IndexSampleFilter; + * using filter_type = index_ivf_sample_filter; * filter_type filter(raft_ivfpq_index.inds_ptrs().data_handle()); * * Use it as: @@ -78,27 +80,27 @@ struct NoneSampleFilter { * to a contiguous bit mask vector. * * template - * struct BitMaskSampleFilter { + * struct bitmask_ivf_sample_filter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * const uint64_t* const bit_mask_ptr = nullptr; * const int64_t bit_mask_stride_64 = 0; * - * BitMaskSampleFilter() {} - * BitMaskSampleFilter( + * bitmask_ivf_sample_filter() {} + * bitmask_ivf_sample_filter( * const index_type* const* _inds_ptr, * const uint64_t* const _bit_mask_ptr, * const int64_t _bit_mask_stride_64) * : inds_ptr{_inds_ptr}, * bit_mask_ptr{_bit_mask_ptr}, * bit_mask_stride_64{_bit_mask_stride_64} {} - * BitMaskSampleFilter(const BitMaskSampleFilter&) = default; - * BitMaskSampleFilter(BitMaskSampleFilter&&) = default; - * BitMaskSampleFilter& operator=(const BitMaskSampleFilter&) = default; - * BitMaskSampleFilter& operator=(BitMaskSampleFilter&&) = default; + * bitmask_ivf_sample_filter(const bitmask_ivf_sample_filter&) = default; + * bitmask_ivf_sample_filter(bitmask_ivf_sample_filter&&) = default; + * bitmask_ivf_sample_filter& operator=(const bitmask_ivf_sample_filter&) = default; + * bitmask_ivf_sample_filter& operator=(bitmask_ivf_sample_filter&&) = default; * - * inline __device__ __host__ bool operator()( + * inline _RAFT_HOST_DEVICE bool operator()( * const uint32_t query_ix, * const uint32_t cluster_ix, * const uint32_t sample_ix) const { @@ -113,4 +115,4 @@ struct NoneSampleFilter { * } * }; */ -} // namespace raft::neighbors::ivf_pq::detail +} // namespace raft::neighbors::filtering diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 4dfa2a707c..a1d6cca7d5 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 2d54248e4d..514301562d 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 75fe52f3c7..32698a8e80 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 001281c8fc..9d39607750 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -15,21 +15,26 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + IvfSampleFilterT sample_filter) -instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search( + float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index ac547626bb..19c3070fd2 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -41,8 +41,8 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, SampleFilterT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, IvfSampleFilterT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ const cudaDeviceProp& dev_props, \\ bool manage_local_topk, \\ int locality_hint, \\ @@ -52,10 +52,10 @@ uint32_t precomp_data_count, \\ uint32_t n_queries, \\ uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ \\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ rmm::cuda_stream_view stream, \\ uint32_t n_rows, \\ uint32_t dim, \\ @@ -75,7 +75,7 @@ const float* queries, \\ const uint32_t* index_list, \\ float* query_kths, \\ - SampleFilterT sample_filter, \\ + IvfSampleFilterT sample_filter, \\ LutT* lut_scores, \\ OutT* _out_scores, \\ uint32_t* _out_indices); @@ -104,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::ivf_pq::detail::NoneSampleFilter);\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::none_ivf_sample_filter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 67b67df19f..46642b5595 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, float, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index 1c97a1c9ba..03d9fb9171 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 14e2d19fe7..221be5b4fd 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 7fd3a8d0b2..b665a37040 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, half, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 01df4d87e3..1acdab4c2e 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 251515a552..a8ad62c51b 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index b29f4bca96..91a69b0e54 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + half, half, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA