Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ivf-flat: limit the workspace size of the search via batching #1515

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
33 changes: 22 additions & 11 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -213,22 +213,33 @@ inline void search(raft::device_resources const& handle,
"n_probes (number of clusters to probe in the search) must be positive.");
auto n_probes = std::min<uint32_t>(params.n_probes, index.n_lists());

auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16);
// a batch size heuristic: try to keep the workspace within the specified size
constexpr uint32_t kExpectedWsSize = 1024 * 1024 * 1024;
const uint32_t max_queries =
std::min<uint32_t>(n_queries,
raft::div_rounding_up_safe<uint64_t>(
kExpectedWsSize, 16ull * uint64_t{n_probes} * k + 4ull * index.dim()));
tfeher marked this conversation as resolved.
Show resolved Hide resolved

auto pool_guard = raft::get_pool_memory_resource(mr, max_queries * n_probes * k * 16);
if (pool_guard) {
RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes",
n_queries * n_probes * k * 16ull);
}

return search_impl<T, float, IdxT>(handle,
index,
queries,
n_queries,
k,
n_probes,
raft::distance::is_min_close(index.metric()),
neighbors,
distances,
mr);
for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);

search_impl<T, float, IdxT>(handle,
index,
queries + offset_q * index.dim(),
queries_batch,
k,
n_probes,
raft::distance::is_min_close(index.metric()),
neighbors + offset_q * k,
distances + offset_q * k,
mr);
}
}

} // namespace raft::neighbors::ivf_flat::detail
3 changes: 2 additions & 1 deletion cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ const std::vector<AnnIvfFlatInputs<int64_t>> inputs = {

// test splitting the big query batches (> max gridDim.y) into smaller batches
{100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, false},
{1000000, 1024, 32, 10, 256, 256, raft::distance::DistanceType::InnerProduct, false},
{98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, true},

// test radix_sort for getting the cluster selection
Expand All @@ -353,4 +354,4 @@ const std::vector<AnnIvfFlatInputs<int64_t>> inputs = {
raft::distance::DistanceType::InnerProduct,
false}};

} // namespace raft::neighbors::ivf_flat
} // namespace raft::neighbors::ivf_flat