Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calculate max cluster size correctly for IVF-PQ #938

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <thrust/binary_search.h>
+#include <thrust/extrema.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
Expand Down Expand Up @@ -430,26 +431,21 @@ auto calculate_offsets_and_indices(IdxT n_rows,
IdxT* data_indices,
rmm::cuda_stream_view stream) -> uint32_t
{
-auto exec_policy = rmm::exec_policy(stream);
-uint32_t max_cluster_size = 0;
-rmm::device_scalar<uint32_t> max_cluster_size_dev_buf(stream);
-auto max_cluster_size_dev = max_cluster_size_dev_buf.data();
-update_device(max_cluster_size_dev, &max_cluster_size, 1, stream);
+auto exec_policy = rmm::exec_policy(stream);
// Calculate the offsets
IdxT cumsum = 0;
update_device(cluster_offsets, &cumsum, 1, stream);
thrust::inclusive_scan(exec_policy,
cluster_sizes,
cluster_sizes + n_lists,
cluster_offsets + 1,
tfeher marked this conversation as resolved.
Show resolved Hide resolved
-[max_cluster_size_dev] __device__(IdxT s, uint32_t l) {
-atomicMax(max_cluster_size_dev, l);
-return s + l;
-});
+[] __device__(IdxT s, uint32_t l) { return s + l; });
tfeher marked this conversation as resolved.
Show resolved Hide resolved
update_host(&cumsum, cluster_offsets + n_lists, 1, stream);
-update_host(&max_cluster_size, max_cluster_size_dev, 1, stream);
+uint32_t max_cluster_size =
+*thrust::max_element(exec_policy, cluster_sizes, cluster_sizes + n_lists);
stream.synchronize();
RAFT_EXPECTS(cumsum == n_rows, "cluster sizes do not add up.");
+RAFT_LOG_DEBUG("Max cluster size %d", max_cluster_size);
rmm::device_uvector<IdxT> data_offsets_buf(n_lists, stream);
auto data_offsets = data_offsets_buf.data();
copy(data_offsets, cluster_offsets, n_lists, stream);
Expand Down Expand Up @@ -554,7 +550,7 @@ void train_per_cluster(const handle_t& handle,
auto cluster_offsets = offsets_buf.data();
auto indices = indices_buf.data();
uint32_t max_cluster_size = calculate_offsets_and_indices(
-n_rows, index.n_lists(), labels, cluster_sizes.data(), cluster_offsets, indices, stream);
+IdxT(n_rows), index.n_lists(), labels, cluster_sizes.data(), cluster_offsets, indices, stream);

rmm::device_uvector<uint32_t> pq_labels(max_cluster_size * index.pq_dim(), stream, device_memory);
rmm::device_uvector<uint32_t> pq_cluster_sizes(index.pq_book_size(), stream, device_memory);
Expand Down