From 9bf7b4bdf120638dd6525d60f3f60f95a393318c Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Mon, 5 Jun 2023 16:18:52 +0000 Subject: [PATCH 1/3] Add sample filtering for ivf_flat. Filtering code refactoring and cleanup (#1541) The PR does the following: * Introduces `ivf_flat::search_with_filtering()` call in the same way the filtering was introduced to ivf_pq in #1513 * Moves `sample_filter.cuh` from `raft/neighbor/detail` to `raft/neighbor` * Moves `NoneSampleFilter` from `raft::neighbor::ivf_pq::detail` namespace to `raft::neighbor::filtering` namespace * Renames `NoneSampleFilter` to `NoneIvfSampleFilter` and template argument `SampleFilterT` to `IvfSampleFilterT` * Adds a missing `resource::get_workspace_resource(handle)` in `ivf_flat-inl.cuh` in a `search_with_filtering()` call (which was copied from `search()` call with the same problem) * Adds more comments in `ivf_pq-inl.h` * Some code cleanup in `ivf_pq-inl.h` Authors: - Alexander Guzhva (https://github.com/alexanderguzhva) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1541 --- .../detail/ivf_flat_interleaved_scan-ext.cuh | 52 +++--- .../detail/ivf_flat_interleaved_scan-inl.cuh | 86 ++++++--- .../neighbors/detail/ivf_flat_search-ext.cuh | 46 ++--- .../neighbors/detail/ivf_flat_search-inl.cuh | 90 ++++++---- .../detail/ivf_pq_compute_similarity-ext.cuh | 121 ++++++------- .../detail/ivf_pq_compute_similarity-inl.cuh | 66 +++---- .../raft/neighbors/detail/ivf_pq_search.cuh | 42 ++--- cpp/include/raft/neighbors/detail/refine.cuh | 3 + cpp/include/raft/neighbors/ivf_flat-ext.cuh | 21 +++ cpp/include/raft/neighbors/ivf_flat-inl.cuh | 169 +++++++++++++++--- cpp/include/raft/neighbors/ivf_pq-ext.cuh | 8 +- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 64 ++++++- ...ple_filter.cuh => sample_filter_types.hpp} | 44 ++--- ...at_interleaved_scan_float_float_int64_t.cu | 34 ++-- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 34 ++-- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 34 ++-- cpp/src/neighbors/detail/ivf_flat_search.cu | 33 ++-- .../ivf_pq_compute_similarity_00_generate.py | 14 +- .../ivf_pq_compute_similarity_float_float.cu | 84 ++++----- ...f_pq_compute_similarity_float_fp8_false.cu | 84 ++++----- ...vf_pq_compute_similarity_float_fp8_true.cu | 84 ++++----- .../ivf_pq_compute_similarity_float_half.cu | 84 ++++----- ...vf_pq_compute_similarity_half_fp8_false.cu | 84 ++++----- ...ivf_pq_compute_similarity_half_fp8_true.cu | 84 ++++----- .../ivf_pq_compute_similarity_half_half.cu | 84 ++++----- 25 files changed, 928 insertions(+), 621 deletions(-) rename cpp/include/raft/neighbors/{detail/sample_filter.cuh => sample_filter_types.hpp} (70%) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 46f72c4005..47f3e8888c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -16,24 +16,27 @@ #pragma once -#include // uintX_t -#include // raft::neighbors::ivf_flat::index -#include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::ivf_flat::detail { -template +template void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, + const uint32_t queries_offset, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, const bool select_min, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -43,23 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + extern template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 4eed2aa453..18f1906dc5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -646,6 +646,7 @@ struct loadAndComputeDist { * @param n_probes * @param k * @param dim + * @param sample_filter * @param[out] neighbors * @param[out] distances */ @@ -655,6 +656,7 @@ template __global__ void __launch_bounds__(kThreadsPerBlock) @@ -666,9 +668,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, + const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, const uint32_t dim, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances) { @@ -736,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const bool valid = vec_id < list_length; // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { + if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) { loadAndComputeDist lc(dist, compute_dist); for (int pos = 0; pos < shm_assisted_dim; @@ -803,6 +807,7 @@ template void launch_kernel(Lambda lambda, @@ -811,8 +816,10 @@ void launch_kernel(Lambda lambda, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, + const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, @@ -820,8 +827,15 @@ void launch_kernel(Lambda lambda, { RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = - interleaved_scan_kernel; + constexpr auto kKernel = interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); @@ -860,9 +874,11 @@ void launch_kernel(Lambda lambda, index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), + queries_offset + query_offset, n_probes, k, index.dim(), + sample_filter, neighbors, distances); queries += grid_dim_y * index.dim(); @@ -931,6 +947,7 @@ template void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) { @@ -943,6 +960,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + IvfSampleFilterT, euclidean_dist, raft::identity_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::L2SqrtExpanded: @@ -953,6 +971,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + IvfSampleFilterT, euclidean_dist, raft::sqrt_op>({}, {}, std::forward(args)...); case raft::distance::DistanceType::InnerProduct: @@ -962,6 +981,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg T, AccT, IdxT, + IvfSampleFilterT, inner_prod_dist, raft::identity_op>({}, {}, std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. @@ -976,6 +996,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg template (1, 16 / sizeof(T))> struct select_interleaved_scan_kernel { @@ -990,13 +1011,20 @@ struct select_interleaved_scan_kernel { { if constexpr (Capacity > 1) { if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); + return select_interleaved_scan_kernel::run(capacity, + veclen, + select_min, + std::forward(args)...); } } if constexpr (Veclen > 1) { if (veclen % Veclen != 0) { - return select_interleaved_scan_kernel::run( + return select_interleaved_scan_kernel::run( capacity, 1, select_min, std::forward(args)...); } } @@ -1010,9 +1038,11 @@ struct select_interleaved_scan_kernel { veclen == Veclen, "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); if (select_min) { - launch_with_fixed_consts(std::forward(args)...); + launch_with_fixed_consts( + std::forward(args)...); } else { - launch_with_fixed_consts(std::forward(args)...); + launch_with_fixed_consts( + std::forward(args)...); } } }; @@ -1028,6 +1058,9 @@ struct select_interleaved_scan_kernel { * @param[in] queries device pointer to the query vectors [batch_size, dim] * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] * @param n_queries batch size + * @param[in] queries_offset + * An offset of the current query batch. It is used for feeding sample_filter with the + * correct query index. * @param metric type of the measured distance * @param n_probes number of nearest clusters to query * @param k number of nearest neighbors. @@ -1041,36 +1074,43 @@ struct select_interleaved_scan_kernel { * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) * @param stream + * @param sample_filter + * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to + * provide a green light for every sample. */ -template +template void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, + const uint32_t queries_offset, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, const bool select_min, + IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - n_probes, - k, - neighbors, - distances, - grid_dim_x, - stream); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + queries_offset, + n_probes, + k, + sample_filter, + neighbors, + distances, + grid_dim_x, + stream); } } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index b97e64a259..976d15a61c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -16,15 +16,16 @@ #pragma once -#include // uintX_t -#include // raft::neighbors::ivf_flat::index -#include // RAFT_EXPLICIT +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::ivf_flat::detail { -template +template void search(raft::resources const& handle, const search_params& params, const raft::neighbors::ivf_flat::index& index, @@ -33,26 +34,31 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) - -instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + IvfSampleFilterT sample_filter) + +instantiate_raft_neighbors_ivf_flat_detail_search( + float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 66ad9682d7..366e9bfcd5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -26,6 +26,7 @@ #include // matrix::detail::select_k #include // interleaved_scan #include // raft::neighbors::ivf_flat::index +#include // none_ivf_sample_filter #include // utils::mapping #include // rmm::device_memory_resource @@ -33,17 +34,19 @@ namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -template +template void search_impl(raft::resources const& handle, const raft::neighbors::ivf_flat::index& index, const T* queries, uint32_t n_queries, + uint32_t queries_offset, uint32_t k, uint32_t n_probes, bool select_min, IdxT* neighbors, AccT* distances, - rmm::mr::device_memory_resource* search_mr) + rmm::mr::device_memory_resource* search_mr, + IvfSampleFilterT sample_filter) { auto stream = resource::get_cuda_stream(handle); // The norm of query @@ -143,18 +146,21 @@ void search_impl(raft::resources const& handle, uint32_t grid_dim_x = 0; if (n_probes > 1) { // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT>(index, - nullptr, - nullptr, - n_queries, - index.metric(), - n_probes, - k, - select_min, - nullptr, - nullptr, - grid_dim_x, - stream); + ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( + index, + nullptr, + nullptr, + n_queries, + queries_offset, + index.metric(), + n_probes, + k, + select_min, + sample_filter, + nullptr, + nullptr, + grid_dim_x, + stream); } else { grid_dim_x = 1; } @@ -164,18 +170,21 @@ void search_impl(raft::resources const& handle, indices_dev_ptr = neighbors; } - ivfflat_interleaved_scan::value_t, IdxT>(index, - queries, - coarse_indices_dev.data(), - n_queries, - index.metric(), - n_probes, - k, - select_min, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); + ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( + index, + queries, + coarse_indices_dev.data(), + n_queries, + queries_offset, + index.metric(), + n_probes, + k, + select_min, + sample_filter, + indices_dev_ptr, + distances_dev_ptr, + grid_dim_x, + stream); RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); @@ -196,7 +205,9 @@ void search_impl(raft::resources const& handle, } /** See raft::neighbors::ivf_flat::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -205,7 +216,8 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); @@ -230,16 +242,18 @@ inline void search(raft::resources const& handle, for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); - search_impl(handle, - index, - queries + offset_q * index.dim(), - queries_batch, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors + offset_q * k, - distances + offset_q * k, - mr); + search_impl(handle, + index, + queries + offset_q * index.dim(), + queries_batch, + offset_q, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors + offset_q * k, + distances + offset_q * k, + mr, + sample_filter); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 62e46e3ae1..0ae2e23b63 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -20,8 +20,8 @@ #include // RAFT_WEAK_FUNCTION #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit -#include // NoneSampleFilter #include // raft::neighbors::ivf_pq::codebook_gen +#include // none_ivf_sample_filter #include // RAFT_EXPLICIT #include // rmm::cuda_stream_view @@ -37,7 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, @@ -100,7 +100,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) RAFT_EXPLICIT; @@ -119,7 +119,7 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -129,78 +129,79 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected RAFT_EXPLICIT; + uint32_t topk) + -> selected RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_pq::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - extern template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + half, half, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, half, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, float, raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index 37174f54e1..2fefa900c3 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -19,8 +19,8 @@ #include // raft::distance::DistanceType #include // matrix::detail::select::warpsort::warp_sort_distributed #include // dummy_block_sort_t -#include // NoneSampleFilter #include // codebook_gen +#include // none_ivf_sample_filter #include // RAFT_CUDA_TRY #include // raft::atomicMin #include // raft::Pow2 @@ -229,7 +229,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * query_kths keep the current state of the filtering - atomically updated distances to the * k-th closest neighbors for each query [n_queries]. * @param sample_filter - * A filter that selects samples for a given query. Use an instance of NoneSampleFilter to + * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. @@ -246,7 +246,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, */ template +template using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); + decltype(&compute_similarity_kernel); // The config struct lifts the runtime parameters to the template parameters template + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> struct compute_similarity_kernel_config { public: static auto get(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { return kernel_choose_bits(pq_bits, k_max); } private: static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { switch (pq_bits) { case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); @@ -527,7 +529,7 @@ struct compute_similarity_kernel_config { template static auto kernel_try_capacity(uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { if constexpr (Capacity > 0) { if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } @@ -537,7 +539,7 @@ struct compute_similarity_kernel_config { } return compute_similarity_kernel + typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter> auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t + -> compute_similarity_kernel_t { return compute_similarity_kernel_config::get(pq_bits, k_max); + IvfSampleFilterT>::get(pq_bits, k_max); } /** Estimate the occupancy for the given kernel on the given device. */ -template +template struct occupancy_t { using shmem_unit = Pow2<128>; @@ -575,7 +577,7 @@ struct occupancy_t { inline occupancy_t() = default; inline occupancy_t(size_t smem, uint32_t n_threads, - compute_similarity_kernel_t kernel, + compute_similarity_kernel_t kernel, const cudaDeviceProp& dev_props) { RAFT_CUDA_TRY( @@ -586,17 +588,19 @@ struct occupancy_t { } }; -template +template struct selected { - compute_similarity_kernel_t kernel; + compute_similarity_kernel_t kernel; dim3 grid_dim; dim3 block_dim; size_t smem_size; size_t device_lut_size; }; -template -void compute_similarity_run(selected s, +template +void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t n_rows, uint32_t dim, @@ -616,7 +620,7 @@ void compute_similarity_run(selected s, const float* queries, const uint32_t* index_list, float* query_kths, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, LutT* lut_scores, OutT* _out_scores, uint32_t* _out_indices) @@ -660,7 +664,9 @@ void compute_similarity_run(selected s, * beyond this limit do not consider increasing the number of active blocks per SM * would improve locality anymore. */ -template +template auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, @@ -670,7 +676,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, uint32_t precomp_data_count, uint32_t n_queries, uint32_t n_probes, - uint32_t topk) -> selected + uint32_t topk) -> selected { // Shared memory for storing the lookup table size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); @@ -742,9 +748,9 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further optimize occupancy and data locality for the L1 cache. */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; auto topk_or_zero = manage_local_topk ? topk : 0u; std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), @@ -753,8 +759,8 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, // we may allow slightly lower than 100% occupancy; constexpr double kTargetOccupancy = 0.75; // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; + occupancy_t selected_perf{}; + selected selected_config; for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { if (smem_size_const > dev_props.sharedMemPerBlockOptin) { // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. @@ -790,7 +796,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, continue; } - occupancy_t cur(smem_size, n_threads, kernel, dev_props); + occupancy_t cur(smem_size, n_threads, kernel, dev_props); if (cur.blocks_per_sm <= 0) { // For some reason, we still cannot make this kernel run. Skip the candidate. continue; @@ -805,7 +811,7 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props, if (n_threads_tmp < n_threads) { while (n_threads_tmp >= n_threads_min) { auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp( + occupancy_t tmp( smem_size_tmp, n_threads_tmp, kernel, dev_props); bool select_it = false; if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d402a2436b..8257f5ed35 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include #include @@ -415,7 +415,7 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, * 3. split the query batch into smaller chunks, so that the device workspace * is guaranteed to fit into GPU memory. */ -template +template void ivfpq_search_worker(raft::resources const& handle, const index& index, uint32_t max_samples, @@ -429,7 +429,7 @@ void ivfpq_search_worker(raft::resources const& handle, float* distances, // [n_queries, topK] float scaling_factor, double preferred_shmem_carveout, - SampleFilterT sample_filter, + IvfSampleFilterT sample_filter, rmm::mr::device_memory_resource* mr) { auto stream = resource::get_cuda_stream(handle); @@ -531,17 +531,17 @@ void ivfpq_search_worker(raft::resources const& handle, } break; } - auto search_instance = - compute_similarity_select(resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + auto search_instance = compute_similarity_select( + resource::get_device_properties(handle), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -610,10 +610,10 @@ void ivfpq_search_worker(raft::resources const& handle, * This structure helps selecting a proper instance of the worker search function, * which contains a few template parameters. */ -template +template struct ivfpq_search { public: - using fun_t = decltype(&ivfpq_search_worker); + using fun_t = decltype(&ivfpq_search_worker); /** * Select an instance of the ivf-pq search function based on search tuning parameters, @@ -629,7 +629,7 @@ struct ivfpq_search { static auto filter_reasonable_instances(const search_params& params) -> fun_t { if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; + return ivfpq_search_worker; } else { RAFT_FAIL( "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " @@ -717,7 +717,9 @@ inline auto get_max_batch_size(uint32_t k, } /** See raft::spatial::knn::ivf_pq::search docs */ -template +template inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -727,7 +729,7 @@ inline void search(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported element type."); @@ -787,7 +789,7 @@ inline void search(raft::resources const& handle, rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - auto search_instance = ivfpq_search::fun(params, index.metric()); + auto search_instance = ivfpq_search::fun(params, index.metric()); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 64f9511ff9..251f725361 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -129,10 +130,12 @@ void refine_device(raft::resources const& handle, queries.data_handle(), fake_coarse_idx.data(), static_cast(n_queries), + 0, refinement_index.metric(), 1, k, raft::distance::is_min_close(metric), + raft::neighbors::filtering::none_ivf_sample_filter(), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index dff7b6b2ab..848703c9b5 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -74,6 +74,18 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* index) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, @@ -85,6 +97,15 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + template void search(raft::resources const& handle, const search_params& params, diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 739e012e08..a18ee065bf 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -357,6 +357,69 @@ void extend(raft::resources const& handle, * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_flat::search_params search_params; + * filtering::none_ivf_sample_filter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) +{ + raft::neighbors::ivf_flat::detail::search( + handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); +} + +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); @@ -394,8 +457,16 @@ void search(raft::resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); + raft::neighbors::ivf_flat::detail::search(handle, + params, + index, + queries, + n_queries, + k, + neighbors, + distances, + mr, + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @@ -403,6 +474,74 @@ void search(raft::resources const& handle, * @{ */ +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * filtering::none_ivf_sample_filter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries1, out_inds1, out_dists1, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries2, out_inds2, out_dists2, filter); + * ivf_flat::search_with_filtering( + * handle, search_params, index, queries3, out_inds3, out_dists3, filter); + * ... + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + search_with_filtering(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(neighbors.extent(1)), + neighbors.data_handle(), + distances.data_handle(), + resource::get_workspace_resource(handle), + sample_filter); +} + /** * @brief Search ANN using the constructed index. * @@ -443,25 +582,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - nullptr); + search_with_filtering(handle, + params, + index, + queries, + neighbors, + distances, + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index 5b7391569b..1595f55d8c 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -45,14 +45,14 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* idx) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter) RAFT_EXPLICIT; template void search(raft::resources const& handle, @@ -83,7 +83,7 @@ void extend(raft::resources const& handle, const IdxT* new_indices, IdxT n_rows) RAFT_EXPLICIT; -template +template void search_with_filtering(raft::resources const& handle, const raft::neighbors::ivf_pq::search_params& params, const index& idx, @@ -93,7 +93,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) RAFT_EXPLICIT; + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template void search(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index fbe2fcb30d..ad9d95f790 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -158,14 +158,14 @@ void extend(raft::resources const& handle, * k] * @param[in] sample_filter a filter the greenlights samples for a given query. */ -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -223,8 +223,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { - search_with_filtering( - handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter()); + search_with_filtering(handle, + params, + idx, + queries, + neighbors, + distances, + raft::neighbors::filtering::none_ivf_sample_filter()); } /** @} */ // end group ivf_pq @@ -337,7 +342,54 @@ void extend(raft::resources const& handle, detail::extend(handle, idx, new_vectors, new_indices, n_rows); } -template +/** + * @brief Search ANN using the constructed index using the given filter. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_pq::search_params search_params; + * filtering::none_ivf_sample_filter filter; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); + * ivf_pq::search_with_filtering( + * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + * @param[in] sample_filter a filter the greenlights samples for a given query + */ +template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, @@ -347,7 +399,7 @@ void search_with_filtering(raft::resources const& handle, IdxT* neighbors, float* distances, rmm::mr::device_memory_resource* mr = nullptr, - SampleFilterT sample_filter = SampleFilterT()) + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { detail::search( handle, params, idx, queries, n_queries, k, neighbors, distances, mr, sample_filter); diff --git a/cpp/include/raft/neighbors/detail/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter_types.hpp similarity index 70% rename from cpp/include/raft/neighbors/detail/sample_filter.cuh rename to cpp/include/raft/neighbors/sample_filter_types.hpp index f5c3d91afe..5a301e9d2f 100644 --- a/cpp/include/raft/neighbors/detail/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -19,11 +19,13 @@ #include #include -namespace raft::neighbors::ivf_pq::detail { +#include + +namespace raft::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ -struct NoneSampleFilter { - inline __device__ __host__ bool operator()( +struct none_ivf_sample_filter { + inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, // the current inverted list index @@ -40,20 +42,20 @@ struct NoneSampleFilter { * filter template can be used: * * template - * struct IndexSampleFilter { + * struct index_ivf_sample_filter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * - * IndexSampleFilter() {} - * IndexSampleFilter(const index_type* const* _inds_ptr) + * index_ivf_sample_filter() {} + * index_ivf_sample_filter(const index_type* const* _inds_ptr) * : inds_ptr{_inds_ptr} {} - * IndexSampleFilter(const IndexSampleFilter&) = default; - * IndexSampleFilter(IndexSampleFilter&&) = default; - * IndexSampleFilter& operator=(const IndexSampleFilter&) = default; - * IndexSampleFilter& operator=(IndexSampleFilter&&) = default; + * index_ivf_sample_filter(const index_ivf_sample_filter&) = default; + * index_ivf_sample_filter(index_ivf_sample_filter&&) = default; + * index_ivf_sample_filter& operator=(const index_ivf_sample_filter&) = default; + * index_ivf_sample_filter& operator=(index_ivf_sample_filter&&) = default; * - * inline __device__ __host__ bool operator()( + * inline _RAFT_HOST_DEVICE bool operator()( * const uint32_t query_ix, * const uint32_t cluster_ix, * const uint32_t sample_ix) const { @@ -65,7 +67,7 @@ struct NoneSampleFilter { * }; * * Initialize it as: - * using filter_type = IndexSampleFilter; + * using filter_type = index_ivf_sample_filter; * filter_type filter(raft_ivfpq_index.inds_ptrs().data_handle()); * * Use it as: @@ -78,27 +80,27 @@ struct NoneSampleFilter { * to a contiguous bit mask vector. * * template - * struct BitMaskSampleFilter { + * struct bitmask_ivf_sample_filter { * using index_type = IdxT; * * const index_type* const* inds_ptr = nullptr; * const uint64_t* const bit_mask_ptr = nullptr; * const int64_t bit_mask_stride_64 = 0; * - * BitMaskSampleFilter() {} - * BitMaskSampleFilter( + * bitmask_ivf_sample_filter() {} + * bitmask_ivf_sample_filter( * const index_type* const* _inds_ptr, * const uint64_t* const _bit_mask_ptr, * const int64_t _bit_mask_stride_64) * : inds_ptr{_inds_ptr}, * bit_mask_ptr{_bit_mask_ptr}, * bit_mask_stride_64{_bit_mask_stride_64} {} - * BitMaskSampleFilter(const BitMaskSampleFilter&) = default; - * BitMaskSampleFilter(BitMaskSampleFilter&&) = default; - * BitMaskSampleFilter& operator=(const BitMaskSampleFilter&) = default; - * BitMaskSampleFilter& operator=(BitMaskSampleFilter&&) = default; + * bitmask_ivf_sample_filter(const bitmask_ivf_sample_filter&) = default; + * bitmask_ivf_sample_filter(bitmask_ivf_sample_filter&&) = default; + * bitmask_ivf_sample_filter& operator=(const bitmask_ivf_sample_filter&) = default; + * bitmask_ivf_sample_filter& operator=(bitmask_ivf_sample_filter&&) = default; * - * inline __device__ __host__ bool operator()( + * inline _RAFT_HOST_DEVICE bool operator()( * const uint32_t query_ix, * const uint32_t cluster_ix, * const uint32_t sample_ix) const { @@ -113,4 +115,4 @@ struct NoneSampleFilter { * } * }; */ -} // namespace raft::neighbors::ivf_pq::detail +} // namespace raft::neighbors::filtering diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 4dfa2a707c..a1d6cca7d5 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 2d54248e4d..514301562d 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 75fe52f3c7..32698a8e80 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -15,22 +15,28 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ + T, AccT, IdxT, IvfSampleFilterT) \ + template void \ + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const uint32_t queries_offset, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IvfSampleFilterT sample_filter, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( + uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 001281c8fc..9d39607750 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -15,21 +15,26 @@ */ #include +#include -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr, \ + IvfSampleFilterT sample_filter) -instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search( + float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); +instantiate_raft_neighbors_ivf_flat_detail_search( + uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); #undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index ac547626bb..19c3070fd2 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -41,8 +41,8 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, SampleFilterT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, IvfSampleFilterT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ const cudaDeviceProp& dev_props, \\ bool manage_local_topk, \\ int locality_hint, \\ @@ -52,10 +52,10 @@ uint32_t precomp_data_count, \\ uint32_t n_queries, \\ uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ \\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ rmm::cuda_stream_view stream, \\ uint32_t n_rows, \\ uint32_t dim, \\ @@ -75,7 +75,7 @@ const float* queries, \\ const uint32_t* index_list, \\ float* query_kths, \\ - SampleFilterT sample_filter, \\ + IvfSampleFilterT sample_filter, \\ LutT* lut_scores, \\ OutT* _out_scores, \\ uint32_t* _out_indices); @@ -104,6 +104,6 @@ path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::ivf_pq::detail::NoneSampleFilter);\n") + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::none_ivf_sample_filter);\n") f.write(trailer) print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 67b67df19f..46642b5595 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, float, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index 1c97a1c9ba..03d9fb9171 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 14e2d19fe7..221be5b4fd 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 7fd3a8d0b2..b665a37040 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + float, half, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 01df4d87e3..1acdab4c2e 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 251515a552..a8ad62c51b 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -27,54 +27,54 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::ivf_pq::detail::NoneSampleFilter); + raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index b29f4bca96..91a69b0e54 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -27,52 +27,52 @@ #include #include -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, SampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t n_rows, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - SampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ uint32_t* _out_indices); #define COMMA , instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter); + half, half, raft::neighbors::filtering::none_ivf_sample_filter); #undef COMMA From fc979fe6e4ec8030c07eda96c037e34e7c958b7a Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Tue, 6 Jun 2023 22:19:42 +0530 Subject: [PATCH 2/3] [HOTFIX] Fix distance metrics L2/cosine/correlation when X & Y are same buffer but with different shape and add unit test for such case. (#1571) -- This is how tiled_brute_force_knn may use pairwise distance API hence assuming when X == Y the buffer has same shape is incorrect. Authors: - Mahesh Doijade (https://github.com/mdoijade) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) --- cpp/include/raft/distance/detail/distance.cuh | 79 ++++++++------ cpp/test/distance/dist_correlation.cu | 23 ++++ cpp/test/distance/dist_cos.cu | 39 +++++++ cpp/test/distance/dist_l2_exp.cu | 40 +++++++ cpp/test/distance/distance_base.cuh | 102 ++++++++++++++++++ 5 files changed, 250 insertions(+), 33 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 7493c4e558..b6885808ce 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -126,9 +126,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // unused { - ASSERT( - !(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || (worksize < 2 * m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -137,9 +135,27 @@ void distance_impl(raft::resources const& handle, AccT* y_norm = workspace; AccT* sq_x_norm = workspace; AccT* sq_y_norm = workspace; - if (x != y) { + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. + if (x == y && is_row_major) { + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + sq_x_norm += std::max(m, n); + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm( + sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream); + } else { y_norm += m; - raft::linalg::reduce(x_norm, x, k, @@ -167,21 +183,6 @@ void distance_impl(raft::resources const& handle, sq_y_norm = sq_x_norm + m; raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); - } else { - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - sq_x_norm += m; - sq_y_norm = sq_x_norm; - raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } using OpT = ops::correlation_distance_op; @@ -210,23 +211,25 @@ void distance_impl(raft::resources const& handle, "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); cudaStream_t stream = raft::resource::get_cuda_stream(handle); DataT* x_norm = workspace; DataT* y_norm = workspace; - if (x != y) { + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. + if (x == y && is_row_major) { + raft::linalg::rowNorm( + x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + } else { y_norm += m; raft::linalg::rowNorm( x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); raft::linalg::rowNorm( y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - } else { - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } ops::cosine_distance_op distance_op{}; @@ -453,21 +456,29 @@ void distance_impl_l2_expanded( // NOTE: different name "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); DataT* x_norm = workspace; DataT* y_norm = workspace; - if (x != y) { + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. + if ((x == y) && is_row_major) { + raft::linalg::rowNorm(x_norm, + x, + k, + std::max(m, n), + raft::linalg::L2Norm, + is_row_major, + stream, + raft::identity_op{}); + } else { y_norm += m; raft::linalg::rowNorm( x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); raft::linalg::rowNorm( y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - } else { - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } ops::l2_exp_distance_op distance_op{perform_sqrt}; @@ -789,8 +800,10 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; if (is_allocated) { + // TODO : when X == Y allocate std::max(m, n) instead of m + n when column major input + // accuracy issue is resolved until then we allocate as m + n. worksize += numOfBuffers * m * sizeof(AccType); - if (x != y) worksize += numOfBuffers * n * sizeof(AccType); + worksize += numOfBuffers * n * sizeof(AccType); } return worksize; diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu index fc729dec1c..aa2866483a 100644 --- a/cpp/test/distance/dist_correlation.cu +++ b/cpp/test/distance/dist_correlation.cu @@ -24,6 +24,10 @@ template class DistanceCorrelation : public DistanceTest {}; +template +class DistanceCorrelationXequalY + : public DistanceTestSameBuffer {}; + const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, @@ -44,6 +48,25 @@ TEST_P(DistanceCorrelationF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf)); +typedef DistanceCorrelationXequalY DistanceCorrelationXequalYF; +TEST_P(DistanceCorrelationXequalYF, Result) +{ + int m = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + raft::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index 9e1cf5af17..caf55529ed 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -24,6 +24,10 @@ template class DistanceExpCos : public DistanceTest { }; +template +class DistanceExpCosXequalY + : public DistanceTestSameBuffer {}; + const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, @@ -34,6 +38,18 @@ const std::vector> inputsf = { {0.001f, 32, 1024, 1024, false, 1234ULL}, {0.003f, 1024, 1024, 1024, false, 1234ULL}, }; + +const std::vector> inputsXeqYf = { + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, +}; + typedef DistanceExpCos DistanceExpCosF; TEST_P(DistanceExpCosF, Result) { @@ -44,6 +60,29 @@ TEST_P(DistanceExpCosF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf)); +typedef DistanceExpCosXequalY DistanceExpCosXequalYF; +TEST_P(DistanceExpCosXequalYF, Result) +{ + int m = params.m; + int n = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + n, + raft::CompareApprox(params.tolerance), + stream)); + n = params.isRowMajor ? m : m / 2; + m = params.isRowMajor ? m / 2 : m; + + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m, + n, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsXeqYf)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu index 6b6a290386..7bdbb44362 100644 --- a/cpp/test/distance/dist_l2_exp.cu +++ b/cpp/test/distance/dist_l2_exp.cu @@ -24,6 +24,10 @@ template class DistanceEucExpTest : public DistanceTest { }; +template +class DistanceEucExpTestXequalY + : public DistanceTestSameBuffer {}; + const std::vector> inputsf = { {0.001f, 2048, 4096, 128, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -37,6 +41,21 @@ const std::vector> inputsf = { {0.003f, 1024, 1024, 1024, false, 1234ULL}, {0.003f, 1021, 1021, 1021, false, 1234ULL}, }; + +const std::vector> inputsXeqYf = { + {0.01f, 2048, 4096, 128, true, 1234ULL}, + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.03f, 1021, 1021, 1021, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, + {0.03f, 1021, 1021, 1021, false, 1234ULL}, +}; + typedef DistanceEucExpTest DistanceEucExpTestF; TEST_P(DistanceEucExpTestF, Result) { @@ -47,6 +66,27 @@ TEST_P(DistanceEucExpTestF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf)); +typedef DistanceEucExpTestXequalY DistanceEucExpTestXequalYF; +TEST_P(DistanceEucExpTestXequalYF, Result) +{ + int m = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + raft::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, + DistanceEucExpTestXequalYF, + ::testing::ValuesIn(inputsXeqYf)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 6c7cab3f7b..20d78c7bb5 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -532,6 +532,108 @@ class DistanceTest : public ::testing::TestWithParam> { rmm::device_uvector x, y, dist_ref, dist, dist2; }; +/* + * This test suite verifies the path when X and Y are same buffer, + * distance metrics which requires norms like L2 expanded/cosine/correlation + * takes a more optimal path in such case to skip norm calculation for Y buffer. + * It may happen that though both X and Y are same buffer but user passes + * different dimensions for them like in case of tiled_brute_force_knn. + */ +template +class DistanceTestSameBuffer : public ::testing::TestWithParam> { + public: + using dev_vector = rmm::device_uvector; + DistanceTestSameBuffer() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + dist_ref({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), + dist({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), + dist2({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}) + { + } + + void SetUp() override + { + auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); + common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.m; + int k = params.k; + DataType metric_arg = params.metric_arg; + bool isRowMajor = params.isRowMajor; + if (distanceType == raft::distance::DistanceType::HellingerExpanded || + distanceType == raft::distance::DistanceType::JensenShannon || + distanceType == raft::distance::DistanceType::KLDivergence) { + // Hellinger works only on positive numbers + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) { + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + // Russel rao works on boolean values. + bernoulli(handle, r, x.data(), m * k, 0.5f); + } else { + uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); + } + + for (int i = 0; i < 2; i++) { + // both X and Y are same buffer but when i = 1 + // different dimensions for x & y is passed. + m = m / (i + 1); + naiveDistance(dist_ref[i].data(), + x.data(), + x.data(), + m, + n, + k, + distanceType, + isRowMajor, + metric_arg, + stream); + + DataType threshold = -10000.f; + + if (isRowMajor) { + distanceLauncher(handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); + + } else { + distanceLauncher(handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); + } + } + resource::sync_stream(handle, stream); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + DistanceInputs params; + dev_vector x; + static const int N = 2; + std::array dist_ref, dist, dist2; +}; + template class BigMatrixDistanceTest : public ::testing::Test { public: From 64f15b6de30b24116b596531a39b744f4afc27f0 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Tue, 6 Jun 2023 14:54:48 -0500 Subject: [PATCH 3/3] Unpin `dask` and `distributed` for development and fix `merge_labels` test (#1574) This PR unpins `dask` and `distributed` to `>=2023.5.1` for `23.08` development. xref: https://github.com/rapidsai/cudf/pull/13508 The offending test was using an rmm::device_scalar for some memory that should have been a vector. Not sure how this didn't fail in the past but these changes fix it. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ray Douglass (https://github.com/raydouglass) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1574 --- .github/workflows/pr.yaml | 4 ++-- .github/workflows/test.yaml | 4 ++-- conda/environments/all_cuda-118_arch-x86_64.yaml | 6 +++--- conda/recipes/raft-dask/meta.yaml | 6 +++--- cpp/test/label/merge_labels.cu | 4 +++- dependencies.yaml | 6 +++--- python/raft-dask/pyproject.toml | 4 ++-- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index f1153e7a41..fa18b63137 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -100,7 +100,7 @@ jobs: build_type: pull-request package-name: raft_dask # Always want to test against latest dask/distributed. - test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" - test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" + test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" + test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test" test-smoketest: "python ./ci/wheel_smoke_test_raft_dask.py" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 22be2ed01c..533f540304 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -49,6 +49,6 @@ jobs: date: ${{ inputs.date }} sha: ${{ inputs.sha }} package-name: raft_dask - test-before-amd64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" - test-before-arm64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" + test-before-amd64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" + test-before-arm64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.08" test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test" diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 9a7cee2821..2125bf6683 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -18,10 +18,10 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=0.29,<0.30 -- dask-core==2023.3.2 +- dask-core>=2023.5.1 - dask-cuda==23.8.* -- dask==2023.3.2 -- distributed==2023.3.2.1 +- dask>=2023.5.1 +- distributed>=2023.5.1 - doxygen>=1.8.20 - gcc_linux-64=11.* - gmock>=1.13.0 diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index cd08deabfa..26c0eed4f9 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -46,10 +46,10 @@ requirements: run: - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} - cuda-python >=11.7.1,<12.0 - - dask ==2023.3.2 - - dask-core ==2023.3.2 + - dask >=2023.5.1 + - dask-core >=2023.5.1 - dask-cuda ={{ minor_version }} - - distributed ==2023.3.2.1 + - distributed >=2023.5.1 - joblib >=0.11 - nccl >=2.9.9 - pylibraft {{ version }} diff --git a/cpp/test/label/merge_labels.cu b/cpp/test/label/merge_labels.cu index 022581c655..3e12f9171e 100644 --- a/cpp/test/label/merge_labels.cu +++ b/cpp/test/label/merge_labels.cu @@ -75,7 +75,9 @@ class MergeLabelsTest : public ::testing::TestWithParam params; rmm::device_uvector labels_a, labels_b, expected, R; - rmm::device_scalar mask, m; + rmm::device_uvector mask; + + rmm::device_scalar m; }; using MergeLabelsTestI = MergeLabelsTest; diff --git a/dependencies.yaml b/dependencies.yaml index b81710b5f8..24b70c588a 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -270,16 +270,16 @@ dependencies: common: - output_types: [conda, pyproject] packages: - - dask==2023.3.2 + - dask>=2023.5.1 - dask-cuda==23.8.* - - distributed==2023.3.2.1 + - distributed>=2023.5.1 - joblib>=0.11 - numba>=0.57 - *numpy - ucx-py==0.33.* - output_types: conda packages: - - dask-core==2023.3.2 + - dask-core>=2023.5.1 - ucx>=1.13.0 - ucx-proc=*=gpu - output_types: pyproject diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index d2fe9b3667..285b8a86d8 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -35,8 +35,8 @@ license = { text = "Apache 2.0" } requires-python = ">=3.9" dependencies = [ "dask-cuda==23.8.*", - "dask==2023.3.2", - "distributed==2023.3.2.1", + "dask>=2023.5.1", + "distributed>=2023.5.1", "joblib>=0.11", "numba>=0.57", "numpy>=1.21",