Skip to content

Commit

Permalink
Calculate max cluster size correctly for IVF-PQ (#938)
Browse files Browse the repository at this point in the history
This PR fixes the calculation of the maximum cluster size.

The previous calculation could return values much larger than the actual maximum cluster size, which could lead to out-of-memory (OOM) errors when allocating temporary buffers while building the index for large datasets.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #938
  • Loading branch information
tfeher authored Oct 24, 2022
1 parent 3cc4737 commit 73cd988
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <thrust/binary_search.h>
#include <thrust/extrema.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
Expand Down Expand Up @@ -430,26 +431,18 @@ auto calculate_offsets_and_indices(IdxT n_rows,
IdxT* data_indices,
rmm::cuda_stream_view stream) -> uint32_t
{
auto exec_policy = rmm::exec_policy(stream);
uint32_t max_cluster_size = 0;
rmm::device_scalar<uint32_t> max_cluster_size_dev_buf(stream);
auto max_cluster_size_dev = max_cluster_size_dev_buf.data();
update_device(max_cluster_size_dev, &max_cluster_size, 1, stream);
auto exec_policy = rmm::exec_policy(stream);
// Calculate the offsets
IdxT cumsum = 0;
update_device(cluster_offsets, &cumsum, 1, stream);
thrust::inclusive_scan(exec_policy,
cluster_sizes,
cluster_sizes + n_lists,
cluster_offsets + 1,
[max_cluster_size_dev] __device__(IdxT s, uint32_t l) {
atomicMax(max_cluster_size_dev, l);
return s + l;
});
thrust::inclusive_scan(
exec_policy, cluster_sizes, cluster_sizes + n_lists, cluster_offsets + 1, thrust::plus<IdxT>{});
update_host(&cumsum, cluster_offsets + n_lists, 1, stream);
update_host(&max_cluster_size, max_cluster_size_dev, 1, stream);
uint32_t max_cluster_size =
*thrust::max_element(exec_policy, cluster_sizes, cluster_sizes + n_lists);
stream.synchronize();
RAFT_EXPECTS(cumsum == n_rows, "cluster sizes do not add up.");
RAFT_LOG_DEBUG("Max cluster size %d", max_cluster_size);
rmm::device_uvector<IdxT> data_offsets_buf(n_lists, stream);
auto data_offsets = data_offsets_buf.data();
copy(data_offsets, cluster_offsets, n_lists, stream);
Expand Down

0 comments on commit 73cd988

Please sign in to comment.