From a27fd30afa2cacd4afe836ed3b3703fc106c7dc4 Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Sat, 20 May 2023 22:06:51 +0000 Subject: [PATCH 1/6] Add sample filtering for ivf_flat. Filtering code refactoring and cleanup --- .../detail/ivf_flat_interleaved_scan-ext.cuh | 44 +++-- .../detail/ivf_flat_interleaved_scan-inl.cuh | 79 ++++++--- .../neighbors/detail/ivf_flat_search-ext.cuh | 43 +++-- .../neighbors/detail/ivf_flat_search-inl.cuh | 90 ++++++---- .../detail/ivf_pq_compute_similarity-ext.cuh | 16 +- .../detail/ivf_pq_compute_similarity-inl.cuh | 18 +- .../raft/neighbors/detail/ivf_pq_search.cuh | 6 +- cpp/include/raft/neighbors/detail/refine.cuh | 4 + cpp/include/raft/neighbors/ivf_flat-ext.cuh | 21 +++ cpp/include/raft/neighbors/ivf_flat-inl.cuh | 165 +++++++++++++++--- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 54 +++++- .../neighbors/{detail => }/sample_filter.cuh | 4 +- ...at_interleaved_scan_float_float_int64_t.cu | 34 ++-- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 34 ++-- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 34 ++-- cpp/src/neighbors/detail/ivf_flat_search.cu | 36 ++-- .../ivf_pq_compute_similarity_00_generate.py | 2 +- .../ivf_pq_compute_similarity_float_float.cu | 2 +- ...f_pq_compute_similarity_float_fp8_false.cu | 2 +- ...vf_pq_compute_similarity_float_fp8_true.cu | 2 +- .../ivf_pq_compute_similarity_float_half.cu | 2 +- ...vf_pq_compute_similarity_half_fp8_false.cu | 2 +- ...ivf_pq_compute_similarity_half_fp8_true.cu | 2 +- .../ivf_pq_compute_similarity_half_half.cu | 2 +- 24 files changed, 498 insertions(+), 200 deletions(-) rename cpp/include/raft/neighbors/{detail => }/sample_filter.cuh (97%) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 46f72c4005..6a47ce823e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -18,6 +18,7 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index +#include // NoneSampleFilter #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view @@ -25,15 +26,17 @@ namespace raft::neighbors::ivf_flat::detail { -template +template void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, + const uint32_t queries_offset, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, const bool select_min, + SampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -43,23 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, SampleFilterT) \ + extern template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + SampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + float, float, int64_t, raft::neighbors::filtering::NoneSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 4eed2aa453..7269633070 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -646,6 +646,7 @@ struct loadAndComputeDist { * @param n_probes * @param k * @param dim + * @param sample_filter * @param[out] neighbors * @param[out] distances */ @@ -655,6 +656,7 @@ template __global__ void __launch_bounds__(kThreadsPerBlock) @@ -666,9 +668,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, + const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, const uint32_t dim, + SampleFilterT sample_filter, IdxT* neighbors, float* distances) { @@ -736,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool valid = vec_id < list_length; // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { + if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) { loadAndComputeDist lc(dist, compute_dist); for (int pos = 0; pos < shm_assisted_dim; @@ -803,6 +807,7 @@ template void launch_kernel(Lambda lambda, @@ -811,8 +816,10 @@ void launch_kernel(Lambda lambda, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, + const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, + SampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -820,8 +827,15 @@ void launch_kernel(Lambda lambda, { RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = - interleaved_scan_kernel; + constexpr auto kKernel = interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); @@ -860,9 +874,11 @@ void launch_kernel(Lambda lambda, index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), + queries_offset + query_offset, n_probes, k, index.dim(), + sample_filter, neighbors, distances); queries += grid_dim_y * index.dim(); @@ -931,6 +947,7 @@ template void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) { @@ -943,6 +960,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + SampleFilterT, euclidean_dist, raft::identity_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::L2SqrtExpanded: @@ -953,6 +971,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + SampleFilterT, euclidean_dist, raft::sqrt_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::InnerProduct: @@ -962,6 +981,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + SampleFilterT, inner_prod_dist, raft::identity_op>({}, {}, std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. @@ -976,6 +996,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg template (1, 16 / sizeof(T))> struct select_interleaved_scan_kernel { @@ -990,13 +1011,13 @@ struct select_interleaved_scan_kernel { { if constexpr (Capacity > 1) { if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); + return select_interleaved_scan_kernel:: + run(capacity, veclen, select_min, std::forward(args)...); } } if constexpr (Veclen > 1) { if (veclen % Veclen != 0) { - return select_interleaved_scan_kernel::run( + return select_interleaved_scan_kernel::run( capacity, 1, select_min, std::forward(args)...); } } @@ -1010,9 +1031,11 @@ struct select_interleaved_scan_kernel { veclen == Veclen, "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); if (select_min) { - launch_with_fixed_consts(std::forward(args)...); + launch_with_fixed_consts( + std::forward(args)...); } else { - launch_with_fixed_consts(std::forward(args)...); + launch_with_fixed_consts( + std::forward(args)...); } } }; @@ -1028,6 +1051,9 @@ struct select_interleaved_scan_kernel { * @param[in] queries device pointer to the query vectors [batch_size, dim] * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] * @param n_queries batch size + * @param[in] queries_offset + * An offset of the current query batch. It is used for feeding sample_filter with the + * correct query index. * @param metric type of the measured distance * @param n_probes number of nearest clusters to query * @param k number of nearest neighbors. @@ -1041,36 +1067,43 @@ struct select_interleaved_scan_kernel { * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) * @param stream + * @param sample_filter + * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * provide a green light for every sample. */ -template +template void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, + const uint32_t queries_offset, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, const bool select_min, + SampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - n_probes, - k, - neighbors, - distances, - grid_dim_x, - stream); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + queries_offset, + n_probes, + k, + sample_filter, + neighbors, + distances, + grid_dim_x, + stream); } } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index b97e64a259..98cc0552ce 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -18,13 +18,14 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index +#include // NoneSampleFilter #include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::ivf_flat::detail { -template +template void search(raft::resources const& handle, const search_params& params, const raft::neighbors::ivf_flat::index& index, @@ -33,26 +34,34 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) - -instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, SampleFilterT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + SampleFilterT sample_filter) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, + int64_t, + raft::neighbors::filtering::NoneSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, + int64_t, + raft::neighbors::filtering::NoneSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, + int64_t, + raft::neighbors::filtering::NoneSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 66ad9682d7..df5c6db615 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,6 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index +#include // NoneSampleFilter #include // utils::mapping #include // rmm::device_memory_resource @@ -33,17 +34,19 @@ namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -template +template void search_impl(raft::resources const& handle, const raft::neighbors::ivf_flat::index& index, const T* queries, uint32_t n_queries, + uint32_t queries_offset, uint32_t k, uint32_t n_probes, bool select_min, IdxT* neighbors, AccT* distances, - rmm::mr::device_memory_resource* search_mr) + rmm::mr::device_memory_resource* search_mr, + SampleFilterT sample_filter) { auto stream = resource::get_cuda_stream(handle); // The norm of query @@ -143,18 +146,21 @@ void search_impl(raft::resources const& handle, uint32_t grid_dim_x = 0; if (n_probes > 1) { // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT>(index, - nullptr, - nullptr, - n_queries, - index.metric(), - n_probes, - k, - select_min, - nullptr, - nullptr, - grid_dim_x, - stream); + ivfflat_interleaved_scan::value_t, IdxT, SampleFilterT>( + index, + nullptr, + nullptr, + n_queries, + queries_offset, + index.metric(), + n_probes, + k, + select_min, + sample_filter, + nullptr, + nullptr, + grid_dim_x, + stream); } else { grid_dim_x = 1; } @@ -164,18 +170,21 @@ void search_impl(raft::resources const& handle, indices_dev_ptr = neighbors; } - ivfflat_interleaved_scan::value_t, IdxT>(index, - queries, - coarse_indices_dev.data(), - n_queries, - index.metric(), - n_probes, - k, - select_min, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); + ivfflat_interleaved_scan::value_t, IdxT, SampleFilterT>( + index, + queries, + coarse_indices_dev.data(), + n_queries, + queries_offset, + index.metric(), + n_probes, + k, + select_min, + sample_filter, + indices_dev_ptr, + distances_dev_ptr, + grid_dim_x, + stream); RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); @@ -196,7 +205,9 @@ void search_impl(raft::resources const& handle, } /** See raft::neighbors::ivf_flat::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -205,7 +216,8 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); @@ -230,16 +242,18 @@ inline void search(raft::resources const& handle, for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); - search_impl(handle, - index, - queries + offset_q * index.dim(), - queries_batch, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors + offset_q * k, - distances + offset_q * k, - mr); + search_impl(handle, + index, + queries + offset_q * index.dim(), + queries_batch, + offset_q, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors + offset_q * k, + distances + offset_q * k, + mr, + sample_filter); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 62e46e3ae1..f93b6610c3 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -20,8 +20,8 @@ #include // RAFT_WEAK_FUNCTION #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit -#include // NoneSampleFilter #include // raft::neighbors::ivf_pq::codebook_gen +#include // NoneSampleFilter #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -182,25 +182,25 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + half, half, raft::neighbors::filtering::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, half, raft::neighbors::filtering::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, float, raft::neighbors::filtering::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 37174f54e1..4be9695d7e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -19,8 +19,8 @@ #include // raft::distance::DistanceType #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t -#include // NoneSampleFilter #include // codebook_gen +#include // NoneSampleFilter #include // RAFT_CUDA_TRY #include // raft::atomicMin #include // raft::Pow2 @@ -493,7 +493,9 @@ __global__ void compute_similarity_kernel(uint32_t n_rows, } // The signature of the kernel defined by a minimal set of template parameters -template +template using compute_similarity_kernel_t = decltype(&compute_similarity_kernel); @@ -502,7 +504,7 @@ template + typename SampleFilterT = raft::neighbors::filtering::NoneSampleFilter> struct compute_similarity_kernel_config { public: static auto get(uint32_t pq_bits, uint32_t k_max) @@ -552,7 +554,7 @@ template + typename SampleFilterT = raft::neighbors::filtering::NoneSampleFilter> auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t { @@ -595,7 +597,9 @@ struct selected { size_t device_lut_size; }; -template +template void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, @@ -660,7 +664,9 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d402a2436b..ff61b98e02 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include #include @@ -717,7 +717,9 @@ inline auto get_max_batch_size(uint32_t k, } /** See raft::spatial::knn::ivf_pq::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 64f9511ff9..4a0f9312e1 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -129,10 +130,13 @@ void refine_device(raft::resources const& handle, queries.data_handle(), fake_coarse_idx.data(), static_cast(n_queries), + 0, refinement_index.metric(), 1, k, raft::distance::is_min_close(metric), + // TODO: add the filtering support + raft::neighbors::filtering::NoneSampleFilter(), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index dff7b6b2ab..c29d8b00bb 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -74,6 +74,18 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* index) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, @@ -85,6 +97,15 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 739e012e08..ebbc3d21c8 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -357,6 +357,67 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_flat::search_params search_params; + * filtering::NoneSampleFilter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search_with_filtering(handle, search_params, index, queries1, N1, K, out_inds1, + * out_dists1, &mr, filter); ivf_flat::search_with_filtering(handle, search_params, index, queries2, + * N2, K, out_inds2, out_dists2, &mr, filter); ivf_flat::search_with_filtering(handle, + * search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + SampleFilterT sample_filter = SampleFilterT()) +{ + raft::neighbors::ivf_flat::detail::search( + handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); +} + +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); @@ -394,8 +455,16 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); + raft::neighbors::ivf_flat::detail::search(handle, + params, + index, + queries, + n_queries, + k, + neighbors, + distances, + mr, + raft::neighbors::filtering::NoneSampleFilter()); } /** @@ -403,6 +472,72 @@ void search(raft::resources const& handle, * @{ */ +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * filtering::NoneSampleFilter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search_with_filtering(handle, search_params, index, queries1, out_inds1, out_dists1, + * filter); ivf_flat::search_with_filtering(handle, search_params, index, queries2, out_inds2, + * out_dists2, filter); ivf_flat::search_with_filtering(handle, search_params, index, queries3, + * out_inds3, out_dists3, filter); + * ... + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + SampleFilterT sample_filter = SampleFilterT()) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + search_with_filtering(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(neighbors.extent(1)), + neighbors.data_handle(), + distances.data_handle(), + resource::get_workspace_resource(handle), + sample_filter); +} + /** * @brief Search ANN using the constructed index. * @@ -443,25 +578,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - nullptr); + search_with_filtering(handle, + params, + index, + queries, + neighbors, + distances, + raft::neighbors::filtering::NoneSampleFilter()); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index fbe2fcb30d..0937a215b0 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -223,8 +223,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { - search_with_filtering( - handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter()); + search_with_filtering(handle, + params, + idx, + queries, + neighbors, + distances, + raft::neighbors::filtering::NoneSampleFilter()); } /** @} */ // end group ivf_pq @@ -337,6 +342,51 @@ void extend(raft::resources const& handle, detail::extend(handle, idx, new_vectors, new_indices, n_rows); } +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_pq::search_params search_params; + * filtering::NoneSampleFilter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search_with_filtering(handle, search_params, index, queries1, N1, K, out_inds1, + * out_dists1, &mr, filter); ivf_pq::search_with_filtering(handle, search_params, index, queries2, + * N2, K, out_inds2, out_dists2, &mr, filter); ivf_pq::search_with_filtering(handle, search_params, + * index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + * @param[in] sample_filter a filter the greenlights samples for a given query + */ template void search_with_filtering(raft::resources const& handle, const search_params& params, diff --git a/cpp/include/raft/neighbors/detail/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh similarity index 97% rename from cpp/include/raft/neighbors/detail/sample_filter.cuh rename to cpp/include/raft/neighbors/sample_filter.cuh index f5c3d91afe..9ceb2f5fc2 100644 --- a/cpp/include/raft/neighbors/detail/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -19,7 +19,7 @@ #include #include -namespace raft::neighbors::ivf_pq::detail { +namespace raft::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ struct NoneSampleFilter { @@ -113,4 +113,4 @@ struct NoneSampleFilter { * } * }; */ -} // namespace raft::neighbors::ivf_pq::detail +} // namespace raft::neighbors::filtering diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 4dfa2a707c..08e24b8d53 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, SampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + SampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + float, float, int64_t, raft::neighbors::filtering::NoneSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 2d54248e4d..61eb62a3f4 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, SampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + SampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 75fe52f3c7..acb60387df 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, SampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + SampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 001281c8fc..0a3a0c48a7 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -15,21 +15,29 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, SampleFilterT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + SampleFilterT sample_filter) -instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(float, + int64_t, + raft::neighbors::filtering::NoneSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, + int64_t, + raft::neighbors::filtering::NoneSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, + int64_t, + raft::neighbors::filtering::NoneSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index ac547626bb..c8ce07b194 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -104,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::ivf_pq::detail::NoneSampleFilter);\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::NoneSampleFilter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 67b67df19f..0a2f230fef 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -72,7 +72,7 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, float, raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index 1c97a1c9ba..ded0e9cc2d 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 14e2d19fe7..007db45abd 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 7fd3a8d0b2..b441165c5a 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -72,7 +72,7 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, half, raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 01df4d87e3..2a6aadc92b 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 251515a552..6d0a4e0074 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::NoneSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index b29f4bca96..e495944ee8 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -72,7 +72,7 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + half, half, raft::neighbors::filtering::NoneSampleFilter); #undef COMMA From 0b700c0d176f4dc209a5abd0898ee8b031b78916 Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Sun, 21 May 2023 18:18:01 +0000 Subject: [PATCH 2/6] Add Ivf to names of filter classes and templates --- .../detail/ivf_flat_interleaved_scan-ext.cuh | 46 +++---- .../detail/ivf_flat_interleaved_scan-inl.cuh | 75 ++++++----- .../neighbors/detail/ivf_flat_search-ext.cuh | 36 +++--- .../neighbors/detail/ivf_flat_search-inl.cuh | 38 +++--- .../detail/ivf_pq_compute_similarity-ext.cuh | 121 +++++++++--------- .../detail/ivf_pq_compute_similarity-inl.cuh | 60 ++++----- .../raft/neighbors/detail/ivf_pq_search.cuh | 38 +++--- cpp/include/raft/neighbors/detail/refine.cuh | 2 +- cpp/include/raft/neighbors/ivf_flat-ext.cuh | 8 +- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 36 +++--- cpp/include/raft/neighbors/ivf_pq-ext.cuh | 8 +- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 22 ++-- cpp/include/raft/neighbors/sample_filter.cuh | 32 ++--- ...at_interleaved_scan_float_float_int64_t.cu | 36 +++--- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 36 +++--- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 36 +++--- cpp/src/neighbors/detail/ivf_flat_search.cu | 30 ++--- .../ivf_pq_compute_similarity_00_generate.py | 14 +- .../ivf_pq_compute_similarity_float_float.cu | 84 ++++++------ ...f_pq_compute_similarity_float_fp8_false.cu | 84 ++++++------ ...vf_pq_compute_similarity_float_fp8_true.cu | 84 ++++++------ .../ivf_pq_compute_similarity_float_half.cu | 84 ++++++------ ...vf_pq_compute_similarity_half_fp8_false.cu | 84 ++++++------ ...ivf_pq_compute_similarity_half_fp8_true.cu | 84 ++++++------ .../ivf_pq_compute_similarity_half_half.cu | 84 ++++++------ 25 files changed, 638 insertions(+), 624 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 6a47ce823e..0131257e51 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -18,7 +18,7 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index -#include // NoneSampleFilter +#include // NoneIvfSampleFilter #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view @@ -26,7 +26,7 @@ namespace raft::neighbors::ivf_flat::detail { -template +template void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, const T* queries, const uint32_t* coarse_query_results, @@ -36,7 +36,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t n_probes, const uint32_t k, const bool select_min, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -46,30 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, SampleFilterT) \ - extern template void \ - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - SampleFilterT sample_filter, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + extern template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::NoneSampleFilter); + float, float, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); + int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 7269633070..289aae9618 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -656,7 +656,7 @@ template __global__ void __launch_bounds__(kThreadsPerBlock) @@ -672,7 +672,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const uint32_t n_probes, const uint32_t k, const uint32_t dim, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances) { @@ -807,7 +807,7 @@ template void launch_kernel(Lambda lambda, @@ -819,7 +819,7 @@ void launch_kernel(Lambda lambda, const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -833,7 +833,7 @@ void launch_kernel(Lambda lambda, T, AccT, IdxT, - SampleFilterT, + IvfSampleFilterT, Lambda, PostLambda>; const int max_query_smem = 16384; @@ -947,7 +947,7 @@ template void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) { @@ -960,7 +960,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, - SampleFilterT, + IvfSampleFilterT, euclidean_dist, raft::identity_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::L2SqrtExpanded: @@ -971,7 +971,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, - SampleFilterT, + IvfSampleFilterT, euclidean_dist, raft::sqrt_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::InnerProduct: @@ -981,7 +981,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, - SampleFilterT, + IvfSampleFilterT, inner_prod_dist, raft::identity_op>({}, {}, std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. @@ -996,7 +996,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg template (1, 16 / sizeof(T))> struct select_interleaved_scan_kernel { @@ -1011,13 +1011,20 @@ struct select_interleaved_scan_kernel { { if constexpr (Capacity > 1) { if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel:: - run(capacity, veclen, select_min, std::forward(args)...); + return select_interleaved_scan_kernel::run(capacity, + veclen, + select_min, + std::forward(args)...); } } if constexpr (Veclen > 1) { if (veclen % Veclen != 0) { - return select_interleaved_scan_kernel::run( + return select_interleaved_scan_kernel::run( capacity, 1, select_min, std::forward(args)...); } } @@ -1031,10 +1038,10 @@ struct select_interleaved_scan_kernel { veclen == Veclen, "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); if (select_min) { - launch_with_fixed_consts( + launch_with_fixed_consts( std::forward(args)...); } else { - launch_with_fixed_consts( + launch_with_fixed_consts( std::forward(args)...); } } @@ -1068,10 +1075,10 @@ struct select_interleaved_scan_kernel { * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) * @param stream * @param sample_filter - * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * A filter that selects samples for a given query. Use an instance of NoneIvfSampleFilter to * provide a green light for every sample. */ -template +template void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, @@ -1081,29 +1088,29 @@ void ivfflat_interleaved_scan(const index& index, const uint32_t n_probes, const uint32_t k, const bool select_min, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - queries_offset, - n_probes, - k, - sample_filter, - neighbors, - distances, - grid_dim_x, - stream); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + queries_offset, + n_probes, + k, + sample_filter, + neighbors, + distances, + grid_dim_x, + stream); } } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index 98cc0552ce..d9083ca57d 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -18,14 +18,14 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index -#include // NoneSampleFilter +#include // NoneIvfSampleFilter #include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::ivf_flat::detail { -template +template void search(raft::resources const& handle, const search_params& params, const raft::neighbors::ivf_flat::index& index, @@ -35,33 +35,33 @@ void search(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, SampleFilterT) \ - extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr, \ - SampleFilterT sample_filter) +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index df5c6db615..ccbca0c13d 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,7 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index -#include // NoneSampleFilter +#include // NoneIvfSampleFilter #include // utils::mapping #include // rmm::device_memory_resource @@ -34,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -template +template void search_impl(raft::resources const& handle, const raft::neighbors::ivf_flat::index& index, const T* queries, @@ -46,7 +46,7 @@ void search_impl(raft::resources const& handle, IdxT* neighbors, AccT* distances, rmm::mr::device_memory_resource* search_mr, - SampleFilterT sample_filter) + IvfSampleFilterT sample_filter) { auto stream = resource::get_cuda_stream(handle); // The norm of query @@ -146,7 +146,7 @@ void search_impl(raft::resources const& handle, uint32_t grid_dim_x = 0; if (n_probes > 1) { // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT, SampleFilterT>( + ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( index, nullptr, nullptr, @@ -170,7 +170,7 @@ void search_impl(raft::resources const& handle, indices_dev_ptr = neighbors; } - ivfflat_interleaved_scan::value_t, IdxT, SampleFilterT>( + ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( index, queries, coarse_indices_dev.data(), @@ -207,7 +207,7 @@ void search_impl(raft::resources const& handle, /** See raft::neighbors::ivf_flat::search docs */ template + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -217,7 +217,7 @@ inline void search(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); @@ -242,18 +242,18 @@ inline void search(raft::resources const& handle, for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); - search_impl(handle, - index, - queries + offset_q * index.dim(), - queries_batch, - offset_q, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors + offset_q * k, - distances + offset_q * k, - mr, - sample_filter); + search_impl(handle, + index, + queries + offset_q * index.dim(), + queries_batch, + offset_q, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors + offset_q * k, + distances + offset_q * k, + mr, + sample_filter); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index f93b6610c3..4ed1f81c86 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -21,7 +21,7 @@ #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit #include // raft::neighbors::ivf_pq::codebook_gen -#include // NoneSampleFilter +#include // NoneIvfSampleFilter #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -37,7 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, @@ -100,7 +100,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) RAFT_EXPLICIT; @@ -119,7 +119,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -129,78 +129,79 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected RAFT_EXPLICIT; + uint32_t topk) + -> selected RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_pq::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - extern template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::filtering::NoneSampleFilter); + half, half, raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::filtering::NoneSampleFilter); + float, half, raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::filtering::NoneSampleFilter); + float, float, raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 4be9695d7e..38e2b62001 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -20,7 +20,7 @@ #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t #include // codebook_gen -#include // NoneSampleFilter +#include // NoneIvfSampleFilter #include // RAFT_CUDA_TRY #include // raft::atomicMin #include // raft::Pow2 @@ -229,7 +229,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * query_kths keep the current state of the filtering - atomically updated distances to the * k-th closest neighbors for each query [n_queries]. * @param sample_filter - * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * A filter that selects samples for a given query. Use an instance of NoneIvfSampleFilter to * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. @@ -246,7 +246,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, */ template + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); // The config struct lifts the runtime parameters to the template parameters template + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> struct compute_similarity_kernel_config { public: static auto get(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { return kernel_choose_bits(pq_bits, k_max); } private: static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { switch (pq_bits) { case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); @@ -529,7 +529,7 @@ struct compute_similarity_kernel_config { template static auto kernel_try_capacity(uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { if constexpr (Capacity > 0) { if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } @@ -539,7 +539,7 @@ struct compute_similarity_kernel_config { } return compute_similarity_kernel + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { return compute_similarity_kernel_config::get(pq_bits, k_max); + IvfSampleFilterT>::get(pq_bits, k_max); } /** Estimate the occupancy for the given kernel on the given device. */ -template +template struct occupancy_t { using shmem_unit = Pow2<128>; @@ -577,7 +577,7 @@ struct occupancy_t { inline occupancy_t() = default; inline occupancy_t(size_t smem, uint32_t n_threads, - compute_similarity_kernel_t kernel, + compute_similarity_kernel_t kernel, const cudaDeviceProp& dev_props) { RAFT_CUDA_TRY( @@ -588,9 +588,9 @@ struct occupancy_t { } }; -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; @@ -599,8 +599,8 @@ struct selected { template -void compute_similarity_run(selected s, + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, @@ -620,7 +620,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) @@ -666,7 +666,7 @@ void compute_similarity_run(selected s, */ template + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -676,7 +676,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected + uint32_t topk) -> selected { // Shared memory for storing the lookup table size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); @@ -748,9 +748,9 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further optimize occupancy and data locality for the L1 cache. */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; auto topk_or_zero = manage_local_topk ? topk : 0u; std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), @@ -759,8 +759,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // we may allow slightly lower than 100% occupancy; constexpr double kTargetOccupancy = 0.75; // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; + occupancy_t selected_perf{}; + selected selected_config; for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { if (smem_size_const > dev_props.sharedMemPerBlockOptin) { // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. @@ -796,7 +796,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, continue; } - occupancy_t cur(smem_size, n_threads, kernel, dev_props); + occupancy_t cur(smem_size, n_threads, kernel, dev_props); if (cur.blocks_per_sm <= 0) { // For some reason, we still cannot make this kernel run. Skip the candidate. continue; @@ -811,7 +811,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, if (n_threads_tmp < n_threads) { while (n_threads_tmp >= n_threads_min) { auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp( + occupancy_t tmp( smem_size_tmp, n_threads_tmp, kernel, dev_props); bool select_it = false; if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index ff61b98e02..1f74131bbc 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -415,7 +415,7 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, * 3. split the query batch into smaller chunks, so that the device workspace * is guaranteed to fit into GPU memory. */ -template +template void ivfpq_search_worker(raft::resources const& handle, const index& index, uint32_t max_samples, @@ -429,7 +429,7 @@ void ivfpq_search_worker(raft::resources const& handle, float* distances, // [n_queries, topK] float scaling_factor, double preferred_shmem_carveout, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, rmm::mr::device_memory_resource* mr) { auto stream = resource::get_cuda_stream(handle); @@ -531,17 +531,17 @@ void ivfpq_search_worker(raft::resources const& handle, } break; } - auto search_instance = - compute_similarity_select(resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + auto search_instance = compute_similarity_select( + resource::get_device_properties(handle), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -610,10 +610,10 @@ void ivfpq_search_worker(raft::resources const& handle, * This structure helps selecting a proper instance of the worker search function, * which contains a few template parameters. */ -template +template struct ivfpq_search { public: - using fun_t = decltype(&ivfpq_search_worker); + using fun_t = decltype(&ivfpq_search_worker); /** * Select an instance of the ivf-pq search function based on search tuning parameters, @@ -629,7 +629,7 @@ struct ivfpq_search { static auto filter_reasonable_instances(const search_params& params) -> fun_t { if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; + return ivfpq_search_worker; } else { RAFT_FAIL( "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " @@ -719,7 +719,7 @@ inline auto get_max_batch_size(uint32_t k, /** See raft::spatial::knn::ivf_pq::search docs */ template + typename IvfSampleFilterT = raft::neighbors::filtering::NoneIvfSampleFilter> inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -729,7 +729,7 @@ inline void search(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported element type."); @@ -789,7 +789,7 @@ inline void search(raft::resources const& handle, rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - auto search_instance = ivfpq_search::fun(params, index.metric()); + auto search_instance = ivfpq_search::fun(params, index.metric()); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 4a0f9312e1..6198e04fe4 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -136,7 +136,7 @@ void refine_device(raft::resources const& handle, k, raft::distance::is_min_close(metric), // TODO: add the filtering support - raft::neighbors::filtering::NoneSampleFilter(), + raft::neighbors::filtering::NoneIvfSampleFilter(), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index c29d8b00bb..848703c9b5 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -74,7 +74,7 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* index) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& index, @@ -84,7 +84,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template void search(raft::resources const& handle, @@ -97,14 +97,14 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template void search(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index ebbc3d21c8..6844879252 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -357,13 +357,15 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_flat::search_params search_params; - * filtering::NoneSampleFilter filter; + * filtering::NoneIvfSampleFilter filter; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations - * ivf_flat::search_with_filtering(handle, search_params, index, queries1, N1, K, out_inds1, - * out_dists1, &mr, filter); ivf_flat::search_with_filtering(handle, search_params, index, queries2, - * N2, K, out_inds2, out_dists2, &mr, filter); ivf_flat::search_with_filtering(handle, - * search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); * ... * @endcode * The exact size of the temporary buffer depends on multiple factors and is an implementation @@ -386,7 +388,7 @@ void extend(raft::resources const& handle, * enough memory pool here to avoid memory allocations within search). * @param[in] sample_filter a filter the greenlights samples for a given query */ -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& index, @@ -396,7 +398,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { raft::neighbors::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); @@ -464,7 +466,7 @@ void search(raft::resources const& handle, neighbors, distances, mr, - raft::neighbors::filtering::NoneSampleFilter()); + raft::neighbors::filtering::NoneIvfSampleFilter()); } /** @@ -485,13 +487,15 @@ void search(raft::resources const& handle, * ... * // use default search parameters * ivf_flat::search_params search_params; - * filtering::NoneSampleFilter filter; + * filtering::NoneIvfSampleFilter filter; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations - * ivf_flat::search_with_filtering(handle, search_params, index, queries1, out_inds1, out_dists1, - * filter); ivf_flat::search_with_filtering(handle, search_params, index, queries2, out_inds2, - * out_dists2, filter); ivf_flat::search_with_filtering(handle, search_params, index, queries3, - * out_inds3, out_dists3, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries1, out_inds1, out_dists1, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries2, out_inds2, out_dists2, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries3, out_inds3, out_dists3, filter); * ... * @endcode * @@ -507,14 +511,14 @@ void search(raft::resources const& handle, * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] * @param[in] sample_filter a filter the greenlights samples for a given query */ -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -584,7 +588,7 @@ void search(raft::resources const& handle, queries, neighbors, distances, - raft::neighbors::filtering::NoneSampleFilter()); + raft::neighbors::filtering::NoneIvfSampleFilter()); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index 5b7391569b..1595f55d8c 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -45,14 +45,14 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* idx) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter) RAFT_EXPLICIT; template void search(raft::resources const& handle, @@ -83,7 +83,7 @@ void extend(raft::resources const& handle, const IdxT* new_indices, IdxT n_rows) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const raft::neighbors::ivf_pq::search_params& params, const index& idx, @@ -93,7 +93,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template void search(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index 0937a215b0..527d6792e8 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -158,14 +158,14 @@ void extend(raft::resources const& handle, * k] * @param[in] sample_filter a filter the greenlights samples for a given query. */ -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -229,7 +229,7 @@ void search(raft::resources const& handle, queries, neighbors, distances, - raft::neighbors::filtering::NoneSampleFilter()); + raft::neighbors::filtering::NoneIvfSampleFilter()); } /** @} */ // end group ivf_pq @@ -358,13 +358,15 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_pq::search_params search_params; - * filtering::NoneSampleFilter filter; + * filtering::NoneIvfSampleFilter filter; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations - * ivf_pq::search_with_filtering(handle, search_params, index, queries1, N1, K, out_inds1, - * out_dists1, &mr, filter); ivf_pq::search_with_filtering(handle, search_params, index, queries2, - * N2, K, out_inds2, out_dists2, &mr, filter); ivf_pq::search_with_filtering(handle, search_params, - * index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); * ... * @endcode * The exact size of the temporary buffer depends on multiple factors and is an implementation @@ -387,7 +389,7 @@ void extend(raft::resources const& handle, * enough memory pool here to avoid memory allocations within search). * @param[in] sample_filter a filter the greenlights samples for a given query */ -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, @@ -397,7 +399,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { detail::search( handle, params, idx, queries, n_queries, k, neighbors, distances, mr, sample_filter); diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh index 9ceb2f5fc2..866fae2f8f 100644 --- a/cpp/include/raft/neighbors/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -22,7 +22,7 @@ namespace raft::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ -struct NoneSampleFilter { +struct NoneIvfSampleFilter { inline __device__ __host__ bool operator()( // query index const uint32_t query_ix, @@ -40,18 +40,18 @@ struct NoneSampleFilter { * filter template can be used: * * template - * struct IndexSampleFilter { + * struct IndexIvfSampleFilter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * - * IndexSampleFilter() {} - * IndexSampleFilter(const index_type* const* _inds_ptr) + * IndexIvfSampleFilter() {} + * IndexIvfSampleFilter(const index_type* const* _inds_ptr) * : inds_ptr{_inds_ptr} {} - * IndexSampleFilter(const IndexSampleFilter&) = default; - * IndexSampleFilter(IndexSampleFilter&&) = default; - * IndexSampleFilter& operator=(const IndexSampleFilter&) = default; - * IndexSampleFilter& operator=(IndexSampleFilter&&) = default; + * IndexIvfSampleFilter(const IndexIvfSampleFilter&) = default; + * IndexIvfSampleFilter(IndexIvfSampleFilter&&) = default; + * IndexIvfSampleFilter& operator=(const IndexIvfSampleFilter&) = default; + * IndexIvfSampleFilter& operator=(IndexIvfSampleFilter&&) = default; * * inline __device__ __host__ bool operator()( * const uint32_t query_ix, @@ -65,7 +65,7 @@ struct NoneSampleFilter { * }; * * Initialize it as: - * using filter_type = IndexSampleFilter; + * using filter_type = IndexIvfSampleFilter; * filter_type filter(raft_ivfpq_index.inds_ptrs().data_handle()); * * Use it as: @@ -78,25 +78,25 @@ struct NoneSampleFilter { * to a contiguous bit mask vector. * * template - * struct BitMaskSampleFilter { + * struct BitMaskIvfSampleFilter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * const uint64_t* const bit_mask_ptr = nullptr; * const int64_t bit_mask_stride_64 = 0; * - * BitMaskSampleFilter() {} - * BitMaskSampleFilter( + * BitMaskIvfSampleFilter() {} + * BitMaskIvfSampleFilter( * const index_type* const* _inds_ptr, * const uint64_t* const _bit_mask_ptr, * const int64_t _bit_mask_stride_64) * : inds_ptr{_inds_ptr}, * bit_mask_ptr{_bit_mask_ptr}, * bit_mask_stride_64{_bit_mask_stride_64} {} - * BitMaskSampleFilter(const BitMaskSampleFilter&) = default; - * BitMaskSampleFilter(BitMaskSampleFilter&&) = default; - * BitMaskSampleFilter& operator=(const BitMaskSampleFilter&) = default; - * BitMaskSampleFilter& operator=(BitMaskSampleFilter&&) = default; + * BitMaskIvfSampleFilter(const BitMaskIvfSampleFilter&) = default; + * BitMaskIvfSampleFilter(BitMaskIvfSampleFilter&&) = default; + * BitMaskIvfSampleFilter& operator=(const BitMaskIvfSampleFilter&) = default; + * BitMaskIvfSampleFilter& operator=(BitMaskIvfSampleFilter&&) = default; * * inline __device__ __host__ bool operator()( * const uint32_t query_ix, diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 08e24b8d53..2c980ce032 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -17,26 +17,26 @@ #include #include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, SampleFilterT) \ - template void \ - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - SampleFilterT sample_filter, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::NoneSampleFilter); + float, float, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 61eb62a3f4..ca87cc6824 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -17,26 +17,26 @@ #include #include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, SampleFilterT) \ - template void \ - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - SampleFilterT sample_filter, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); + int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index acb60387df..7f9f824fbe 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -17,26 +17,26 @@ #include #include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, SampleFilterT) \ - template void \ - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - SampleFilterT sample_filter, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneSampleFilter); + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 0a3a0c48a7..3e8ee2cbc0 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -17,27 +17,27 @@ #include #include -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, SampleFilterT) \ - template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr, \ - SampleFilterT sample_filter) +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index c8ce07b194..61d1a87b53 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -41,8 +41,8 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, SampleFilterT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, IvfSampleFilterT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ const cudaDeviceProp& dev_props, \\ bool manage_local_topk, \\ int locality_hint, \\ @@ -52,10 +52,10 @@ uint32_t precomp_data_count, \\ uint32_t n_queries, \\ uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ \\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ rmm::cuda_stream_view stream, \\ uint32_t n_rows, \\ uint32_t dim, \\ @@ -75,7 +75,7 @@ const float* queries, \\ const uint32_t* index_list, \\ float* query_kths, \\ - SampleFilterT sample_filter, \\ + IvfSampleFilterT sample_filter, \\ LutT* lut_scores, \\ OutT* _out_scores, \\ uint32_t* _out_indices); @@ -104,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::NoneSampleFilter);\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::NoneIvfSampleFilter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 0a2f230fef..b4916b2582 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::filtering::NoneSampleFilter); + float, float, raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index ded0e9cc2d..c239d71a47 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 007db45abd..365d4a066e 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index b441165c5a..14c3496000 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::filtering::NoneSampleFilter); + float, half, raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 2a6aadc92b..9e3c5656bd 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 6d0a4e0074..8b902ecbbd 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneSampleFilter); + raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index e495944ee8..81741976ae 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::filtering::NoneSampleFilter); + half, half, raft::neighbors::filtering::NoneIvfSampleFilter); #undef COMMA From 1a3d0b5f2961a514843717ad73fece240ad7e353 Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Tue, 23 May 2023 11:47:38 +0000 Subject: [PATCH 3/6] Convert struct names to snake case --- .../detail/ivf_flat_interleaved_scan-ext.cuh | 8 ++--- .../detail/ivf_flat_interleaved_scan-inl.cuh | 2 +- .../neighbors/detail/ivf_flat_search-ext.cuh | 17 ++++------ .../neighbors/detail/ivf_flat_search-inl.cuh | 4 +-- .../detail/ivf_pq_compute_similarity-ext.cuh | 16 +++++----- .../detail/ivf_pq_compute_similarity-inl.cuh | 14 ++++---- .../raft/neighbors/detail/ivf_pq_search.cuh | 2 +- cpp/include/raft/neighbors/detail/refine.cuh | 2 +- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 8 ++--- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 4 +-- cpp/include/raft/neighbors/sample_filter.cuh | 32 +++++++++---------- ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- cpp/src/neighbors/detail/ivf_flat_search.cu | 15 ++++----- .../ivf_pq_compute_similarity_00_generate.py | 2 +- .../ivf_pq_compute_similarity_float_float.cu | 2 +- ...f_pq_compute_similarity_float_fp8_false.cu | 2 +- ...vf_pq_compute_similarity_float_fp8_true.cu | 2 +- .../ivf_pq_compute_similarity_float_half.cu | 2 +- ...vf_pq_compute_similarity_half_fp8_false.cu | 2 +- ...ivf_pq_compute_similarity_half_fp8_true.cu | 2 +- .../ivf_pq_compute_similarity_half_half.cu | 2 +- 23 files changed, 70 insertions(+), 76 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 0131257e51..c3047155aa 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -18,7 +18,7 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index -#include // NoneIvfSampleFilter +#include // none_ivf_sample_filter #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view @@ -66,10 +66,10 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); + float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); + int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 289aae9618..18f1906dc5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -1075,7 +1075,7 @@ struct select_interleaved_scan_kernel { * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) * @param stream * @param sample_filter - * A filter that selects samples for a given query. Use an instance of NoneIvfSampleFilter to + * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to * provide a green light for every sample. */ template diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index d9083ca57d..2d0eee55af 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -18,7 +18,7 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index -#include // NoneIvfSampleFilter +#include // none_ivf_sample_filter #include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -54,14 +54,11 @@ void search(raft::resources const& handle, rmm::mr::device_memory_resource* mr, \ IvfSampleFilterT sample_filter) -instantiate_raft_neighbors_ivf_flat_detail_search(float, - int64_t, - raft::neighbors::filtering::NoneIvfSampleFilter); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, - int64_t, - raft::neighbors::filtering::NoneIvfSampleFilter); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, - int64_t, - raft::neighbors::filtering::NoneIvfSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_search( + float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index ccbca0c13d..7ce64928fa 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,7 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index -#include // NoneIvfSampleFilter +#include // none_ivf_sample_filter #include // utils::mapping #include // rmm::device_memory_resource @@ -207,7 +207,7 @@ void search_impl(raft::resources const& handle, /** See raft::neighbors::ivf_flat::search docs */ template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> inline void search(raft::resources const& handle, const search_params& params, const index& index, diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 4ed1f81c86..555c6febb7 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -21,7 +21,7 @@ #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit #include // raft::neighbors::ivf_pq::codebook_gen -#include // NoneIvfSampleFilter +#include // none_ivf_sample_filter #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -183,25 +183,25 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::filtering::NoneIvfSampleFilter); + half, half, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::filtering::NoneIvfSampleFilter); + float, half, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::filtering::NoneIvfSampleFilter); + float, float, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 38e2b62001..9e03e8066a 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -20,7 +20,7 @@ #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t #include // codebook_gen -#include // NoneIvfSampleFilter +#include // none_ivf_sample_filter #include // RAFT_CUDA_TRY #include // raft::atomicMin #include // raft::Pow2 @@ -229,7 +229,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * query_kths keep the current state of the filtering - atomically updated distances to the * k-th closest neighbors for each query [n_queries]. * @param sample_filter - * A filter that selects samples for a given query. Use an instance of NoneIvfSampleFilter to + * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. @@ -495,7 +495,7 @@ __global__ void compute_similarity_kernel(uint32_t n_rows, // The signature of the kernel defined by a minimal set of template parameters template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> using compute_similarity_kernel_t = decltype(&compute_similarity_kernel); @@ -504,7 +504,7 @@ template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> struct compute_similarity_kernel_config { public: static auto get(uint32_t pq_bits, uint32_t k_max) @@ -554,7 +554,7 @@ template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t { @@ -599,7 +599,7 @@ struct selected { template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, @@ -666,7 +666,7 @@ void compute_similarity_run(selected s, */ template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 1f74131bbc..345e58eb12 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -719,7 +719,7 @@ inline auto get_max_batch_size(uint32_t k, /** See raft::spatial::knn::ivf_pq::search docs */ template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> inline void search(raft::resources const& handle, const search_params& params, const index& index, diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 6198e04fe4..b55c05a76f 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -136,7 +136,7 @@ void refine_device(raft::resources const& handle, k, raft::distance::is_min_close(metric), // TODO: add the filtering support - raft::neighbors::filtering::NoneIvfSampleFilter(), + raft::neighbors::filtering::none_ivf_sample_filter(), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 6844879252..a18ee065bf 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -357,7 +357,7 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_flat::search_params search_params; - * filtering::NoneIvfSampleFilter filter; + * filtering::none_ivf_sample_filter filter; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations * ivf_flat::search_with_filtering( @@ -466,7 +466,7 @@ void search(raft::resources const& handle, neighbors, distances, mr, - raft::neighbors::filtering::NoneIvfSampleFilter()); + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @@ -487,7 +487,7 @@ void search(raft::resources const& handle, * ... * // use default search parameters * ivf_flat::search_params search_params; - * filtering::NoneIvfSampleFilter filter; + * filtering::none_ivf_sample_filter filter; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations * ivf_flat::search_with_filtering( @@ -588,7 +588,7 @@ void search(raft::resources const& handle, queries, neighbors, distances, - raft::neighbors::filtering::NoneIvfSampleFilter()); + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index 527d6792e8..ad9d95f790 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -229,7 +229,7 @@ void search(raft::resources const& handle, queries, neighbors, distances, - raft::neighbors::filtering::NoneIvfSampleFilter()); + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @} */ // end group ivf_pq @@ -358,7 +358,7 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_pq::search_params search_params; - * filtering::NoneIvfSampleFilter filter; + * filtering::none_ivf_sample_filter filter; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations * ivf_pq::search_with_filtering( diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh index 866fae2f8f..4c89a5ee51 100644 --- a/cpp/include/raft/neighbors/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -22,7 +22,7 @@ namespace raft::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ -struct NoneIvfSampleFilter { +struct none_ivf_sample_filter { inline __device__ __host__ bool operator()( // query index const uint32_t query_ix, @@ -40,18 +40,18 @@ struct NoneIvfSampleFilter { * filter template can be used: * * template - * struct IndexIvfSampleFilter { + * struct index_ivf_sample_filter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * - * IndexIvfSampleFilter() {} - * IndexIvfSampleFilter(const index_type* const* _inds_ptr) + * index_ivf_sample_filter() {} + * index_ivf_sample_filter(const index_type* const* _inds_ptr) * : inds_ptr{_inds_ptr} {} - * IndexIvfSampleFilter(const IndexIvfSampleFilter&) = default; - * IndexIvfSampleFilter(IndexIvfSampleFilter&&) = default; - * IndexIvfSampleFilter& operator=(const IndexIvfSampleFilter&) = default; - * IndexIvfSampleFilter& operator=(IndexIvfSampleFilter&&) = default; + * index_ivf_sample_filter(const index_ivf_sample_filter&) = default; + * index_ivf_sample_filter(index_ivf_sample_filter&&) = default; + * index_ivf_sample_filter& operator=(const index_ivf_sample_filter&) = default; + * index_ivf_sample_filter& operator=(index_ivf_sample_filter&&) = default; * * inline __device__ __host__ bool operator()( * const uint32_t query_ix, @@ -65,7 +65,7 @@ struct NoneIvfSampleFilter { * }; * * Initialize it as: - * using filter_type = IndexIvfSampleFilter; + * using filter_type = index_ivf_sample_filter; * filter_type filter(raft_ivfpq_index.inds_ptrs().data_handle()); * * Use it as: @@ -78,25 +78,25 @@ struct NoneIvfSampleFilter { * to a contiguous bit mask vector. * * template - * struct BitMaskIvfSampleFilter { + * struct bitmask_ivf_sample_filter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * const uint64_t* const bit_mask_ptr = nullptr; * const int64_t bit_mask_stride_64 = 0; * - * BitMaskIvfSampleFilter() {} - * BitMaskIvfSampleFilter( + * bitmask_ivf_sample_filter() {} + * bitmask_ivf_sample_filter( * const index_type* const* _inds_ptr, * const uint64_t* const _bit_mask_ptr, * const int64_t _bit_mask_stride_64) * : inds_ptr{_inds_ptr}, * bit_mask_ptr{_bit_mask_ptr}, * bit_mask_stride_64{_bit_mask_stride_64} {} - * BitMaskIvfSampleFilter(const BitMaskIvfSampleFilter&) = default; - * BitMaskIvfSampleFilter(BitMaskIvfSampleFilter&&) = default; - * BitMaskIvfSampleFilter& operator=(const BitMaskIvfSampleFilter&) = default; - * BitMaskIvfSampleFilter& operator=(BitMaskIvfSampleFilter&&) = default; + * bitmask_ivf_sample_filter(const bitmask_ivf_sample_filter&) = default; + * bitmask_ivf_sample_filter(bitmask_ivf_sample_filter&&) = default; + * bitmask_ivf_sample_filter& operator=(const bitmask_ivf_sample_filter&) = default; + * bitmask_ivf_sample_filter& operator=(bitmask_ivf_sample_filter&&) = default; * * inline __device__ __host__ bool operator()( * const uint32_t query_ix, diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 2c980ce032..84ba3dbcb0 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -37,6 +37,6 @@ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); + float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index ca87cc6824..81c019fbdc 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -37,6 +37,6 @@ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); + int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 7f9f824fbe..744e603e9e 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -37,6 +37,6 @@ rmm::cuda_stream_view stream) instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::NoneIvfSampleFilter); + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 3e8ee2cbc0..0efc07eb6b 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -30,14 +30,11 @@ rmm::mr::device_memory_resource* mr, \ IvfSampleFilterT sample_filter) -instantiate_raft_neighbors_ivf_flat_detail_search(float, - int64_t, - raft::neighbors::filtering::NoneIvfSampleFilter); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, - int64_t, - raft::neighbors::filtering::NoneIvfSampleFilter); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, - int64_t, - raft::neighbors::filtering::NoneIvfSampleFilter); +instantiate_raft_neighbors_ivf_flat_detail_search( + float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index 61d1a87b53..19c3070fd2 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -104,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::NoneIvfSampleFilter);\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::none_ivf_sample_filter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index b4916b2582..46642b5595 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -72,7 +72,7 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::filtering::NoneIvfSampleFilter); + float, float, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index c239d71a47..03d9fb9171 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 365d4a066e..221be5b4fd 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 14c3496000..b665a37040 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -72,7 +72,7 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::filtering::NoneIvfSampleFilter); + float, half, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 9e3c5656bd..1acdab4c2e 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 8b902ecbbd..a8ad62c51b 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -74,7 +74,7 @@ instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::NoneIvfSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index 81741976ae..91a69b0e54 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -72,7 +72,7 @@ #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::filtering::NoneIvfSampleFilter); + half, half, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA From 2e68806d72ab2d603413593e3e9a70537e17b2ad Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 25 May 2023 18:28:05 -0400 Subject: [PATCH 4/6] Code cleanup for sample_filter_types.hpp --- .../{sample_filter.cuh => sample_filter_types.hpp} | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) rename cpp/include/raft/neighbors/{sample_filter.cuh => sample_filter_types.hpp} (95%) diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter_types.hpp similarity index 95% rename from cpp/include/raft/neighbors/sample_filter.cuh rename to cpp/include/raft/neighbors/sample_filter_types.hpp index 4c89a5ee51..5a301e9d2f 100644 --- a/cpp/include/raft/neighbors/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -19,11 +19,13 @@ #include #include +#include + namespace raft::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ struct none_ivf_sample_filter { - inline __device__ __host__ bool operator()( + inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, // the current inverted list index @@ -53,7 +55,7 @@ struct none_ivf_sample_filter { * index_ivf_sample_filter& operator=(const index_ivf_sample_filter&) = default; * index_ivf_sample_filter& operator=(index_ivf_sample_filter&&) = default; * - * inline __device__ __host__ bool operator()( + * inline _RAFT_HOST_DEVICE bool operator()( * const uint32_t query_ix, * const uint32_t cluster_ix, * const uint32_t sample_ix) const { @@ -98,7 +100,7 @@ struct none_ivf_sample_filter { * bitmask_ivf_sample_filter& operator=(const bitmask_ivf_sample_filter&) = default; * bitmask_ivf_sample_filter& operator=(bitmask_ivf_sample_filter&&) = default; * - * inline __device__ __host__ bool operator()( + * inline _RAFT_HOST_DEVICE bool operator()( * const uint32_t query_ix, * const uint32_t cluster_ix, * const uint32_t sample_ix) const { From 57aff05d8edb56ce0de8c762a35af07220e5153e Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Fri, 26 May 2023 21:58:31 +0000 Subject: [PATCH 5/6] Fix bad includes of sample_filter.cuh --- .../neighbors/detail/ivf_flat_interleaved_scan-ext.cuh | 10 +++++----- .../raft/neighbors/detail/ivf_flat_search-ext.cuh | 8 ++++---- .../raft/neighbors/detail/ivf_flat_search-inl.cuh | 2 +- .../neighbors/detail/ivf_pq_compute_similarity-ext.cuh | 2 +- .../neighbors/detail/ivf_pq_compute_similarity-inl.cuh | 2 +- cpp/include/raft/neighbors/detail/ivf_pq_search.cuh | 2 +- cpp/include/raft/neighbors/detail/refine.cuh | 2 +- .../ivf_flat_interleaved_scan_float_float_int64_t.cu | 2 +- ...ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...f_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- cpp/src/neighbors/detail/ivf_flat_search.cu | 2 +- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index c3047155aa..47f3e8888c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -16,11 +16,11 @@ #pragma once -#include // uintX_t -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index 2d0eee55af..976d15a61c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -16,10 +16,10 @@ #pragma once -#include // uintX_t -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 7ce64928fa..366e9bfcd5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,7 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter +#include // none_ivf_sample_filter #include // utils::mapping #include // rmm::device_memory_resource diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 555c6febb7..0ae2e23b63 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -21,7 +21,7 @@ #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit #include // raft::neighbors::ivf_pq::codebook_gen -#include // none_ivf_sample_filter +#include // none_ivf_sample_filter #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 9e03e8066a..2fefa900c3 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -20,7 +20,7 @@ #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t #include // codebook_gen -#include // none_ivf_sample_filter +#include // none_ivf_sample_filter #include // RAFT_CUDA_TRY #include // raft::atomicMin #include // raft::Pow2 diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 345e58eb12..8257f5ed35 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b55c05a76f..48bffb6a20 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 84ba3dbcb0..a1d6cca7d5 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ T, AccT, IdxT, IvfSampleFilterT) \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 81c019fbdc..514301562d 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ T, AccT, IdxT, IvfSampleFilterT) \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 744e603e9e..32698a8e80 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ T, AccT, IdxT, IvfSampleFilterT) \ diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 0efc07eb6b..9d39607750 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ template void raft::neighbors::ivf_flat::detail::search( \ From 9ca675b031d5c6c1c35cf45380dd6b8af5f932bc Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Wed, 31 May 2023 15:00:02 -0400 Subject: [PATCH 6/6] Removed an unneeded comment --- cpp/include/raft/neighbors/detail/refine.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 48bffb6a20..251f725361 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -135,7 +135,6 @@ void refine_device(raft::resources const& handle, 1, k, raft::distance::is_min_close(metric), - // TODO: add the filtering support raft::neighbors::filtering::none_ivf_sample_filter(), indices.data_handle(), distances.data_handle(),