From cf4e03d0b952c1baac73f695f94d6482d8c391d8 Mon Sep 17 00:00:00 2001 From: Rui Lan Date: Fri, 8 Dec 2023 11:47:07 -0800 Subject: [PATCH] Add subsample support for PQ codebook generation. More benchmark needed. --- .../raft/neighbors/detail/ivf_pq_build.cuh | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index e57133fc23..c2d1b94e0f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -404,8 +404,11 @@ void train_per_subset(raft::resources const& handle, auto device_memory = resource::get_workspace_resource(handle); rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); - rmm::device_uvector sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory); - rmm::device_uvector sub_labels(n_rows, stream, device_memory); + // Subsampling the train set for codebook generation. Using similar subsampling strategy as train_per_cluster + size_t big_enough = 256ul * std::max(index.pq_book_size(), index.pq_dim()); + auto pq_n_rows = uint32_t(std::min(big_enough, n_rows)); + rmm::device_uvector sub_trainset(pq_n_rows * size_t(index.pq_len()), stream, device_memory); + rmm::device_uvector sub_labels(pq_n_rows, stream, device_memory); rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); @@ -416,7 +419,7 @@ void train_per_subset(raft::resources const& handle, // Get the rotated cluster centers for each training vector. // This will be subtracted from the input vectors afterwards. utils::copy_selected( - n_rows, + pq_n_rows, index.pq_len(), index.centers_rot().data_handle() + index.pq_len() * j, labels, @@ -432,7 +435,7 @@ void train_per_subset(raft::resources const& handle, true, false, index.pq_len(), - n_rows, + pq_n_rows, index.dim(), &alpha, index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j, @@ -446,12 +449,12 @@ void train_per_subset(raft::resources const& handle, // train PQ codebook for this subspace auto sub_trainset_view = - raft::make_device_matrix_view(sub_trainset.data(), n_rows, index.pq_len()); + raft::make_device_matrix_view(sub_trainset.data(), pq_n_rows, index.pq_len()); auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, index.pq_book_size(), index.pq_len()); - auto sub_labels_view = raft::make_device_vector_view(sub_labels.data(), n_rows); + auto sub_labels_view = raft::make_device_vector_view(sub_labels.data(), pq_n_rows); auto cluster_sizes_view = raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params;