diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index c364118fdd..21288fb58a 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -213,22 +213,33 @@ inline void search(raft::device_resources const& handle, "n_probes (number of clusters to probe in the search) must be positive."); auto n_probes = std::min(params.n_probes, index.n_lists()); - auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); + // a batch size heuristic: try to keep the workspace within the specified size + constexpr uint32_t kExpectedWsSize = 1024 * 1024 * 1024; + const uint32_t max_queries = + std::min(n_queries, + raft::div_rounding_up_safe( + kExpectedWsSize, 16ull * uint64_t{n_probes} * k + 4ull * index.dim())); + + auto pool_guard = raft::get_pool_memory_resource(mr, max_queries * n_probes * k * 16); if (pool_guard) { RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", n_queries * n_probes * k * 16ull); } - return search_impl(handle, - index, - queries, - n_queries, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors, - distances, - mr); + for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { + uint32_t queries_batch = min(max_queries, n_queries - offset_q); + + search_impl(handle, + index, + queries + offset_q * index.dim(), + queries_batch, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors + offset_q * k, + distances + offset_q * k, + mr); + } } } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 4d90c3d7e4..1f5efbf7b9 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -333,6 +333,7 @@ const std::vector> inputs = { // test splitting the big query batches (> max gridDim.y) into smaller batches {100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, false}, + {1000000, 1024, 32, 10, 256, 256, raft::distance::DistanceType::InnerProduct, false}, {98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, true}, // test radix_sort for getting the cluster selection @@ -353,4 +354,4 @@ const std::vector> inputs = { raft::distance::DistanceType::InnerProduct, false}}; -} // namespace raft::neighbors::ivf_flat \ No newline at end of file +} // namespace raft::neighbors::ivf_flat