diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index c7e3798f5d..44b88a0b23 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -24,6 +24,7 @@ #include #include +#include namespace raft::neighbors::ivf_flat { @@ -37,6 +38,19 @@ struct index_params : ann::index_params { uint32_t kmeans_n_iters = 20; /** The fraction of data to use during iterative kmeans building. */ double kmeans_trainset_fraction = 0.5; + /** + * By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`, + * and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index + * from scratch after invoking (`ivf_flat::extend`) a few times with new data, the distribution of + * which is no longer representative of the original training set. + * + * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new + * data when it is added. In this case, `index.centers()` are always exactly the centroids of the + * data in the corresponding clusters. The drawback of this behavior is that the centroids depend + * on the order of adding new data (through the classification of the added data); that is, + * `index.centers()` "drift" together with the changing distribution of the newly added data. + */ + bool adaptive_centers = false; }; struct search_params : ann::search_params { @@ -72,6 +86,11 @@ struct index : ann::index { { return metric_; } + /** Whether `centers()` change upon extending the index (ivf_flat::extend). */ + [[nodiscard]] constexpr inline auto adaptive_centers() const noexcept -> bool + { + return adaptive_centers_; + } /** * Inverted list data [size, dim]. * @@ -200,10 +219,15 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. It needs to be trained and then populated. 
*/ - index(const handle_t& handle, raft::distance::DistanceType metric, uint32_t n_lists, uint32_t dim) + index(const handle_t& handle, + raft::distance::DistanceType metric, + uint32_t n_lists, + bool adaptive_centers, + uint32_t dim) : ann::index(), veclen_(calculate_veclen(dim)), metric_(metric), + adaptive_centers_(adaptive_centers), data_(make_device_mdarray(handle, make_extents(0, dim))), indices_(make_device_mdarray(handle, make_extents(0))), list_sizes_(make_device_mdarray(handle, make_extents(n_lists))), @@ -216,7 +240,7 @@ struct index : ann::index { /** Construct an empty index. It needs to be trained and then populated. */ index(const handle_t& handle, const index_params& params, uint32_t dim) - : index(handle, params.metric, params.n_lists, dim) + : index(handle, params.metric, params.n_lists, params.adaptive_centers, dim) { } @@ -242,6 +266,7 @@ struct index : ann::index { */ uint32_t veclen_; raft::distance::DistanceType metric_; + bool adaptive_centers_; device_mdarray, row_major> data_; device_mdarray, row_major> indices_; device_mdarray, row_major> list_sizes_; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 44972825e0..82d498a789 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include @@ -133,7 +135,8 @@ inline auto extend(const handle_t& handle, orig_index.metric(), stream); - index ext_index(handle, orig_index.metric(), n_lists, dim); + index ext_index( + handle, orig_index.metric(), n_lists, orig_index.adaptive_centers(), dim); auto list_sizes_ptr = ext_index.list_sizes().data_handle(); auto list_offsets_ptr = ext_index.list_offsets().data_handle(); @@ -141,19 +144,31 @@ inline auto extend(const handle_t& handle, // Calculate the centers and sizes on the new data, starting from the original values 
raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream); - raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); - kmeans::calc_centers_and_sizes(handle, - centers_ptr, - list_sizes_ptr, - n_lists, - dim, - new_vectors, - n_rows, - new_labels.data(), - false, - stream); + if (ext_index.adaptive_centers()) { + raft::copy( + list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); + kmeans::calc_centers_and_sizes(handle, + centers_ptr, + list_sizes_ptr, + n_lists, + dim, + new_vectors, + n_rows, + new_labels.data(), + false, + stream); + } else { + raft::stats::histogram(raft::stats::HistTypeAuto, + reinterpret_cast(list_sizes_ptr), + IdxT(n_lists), + new_labels.data(), + n_rows, + 1, + stream); + raft::linalg::add( + list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream); + } // Calculate new offsets IdxT index_size = 0; @@ -210,13 +225,20 @@ inline auto extend(const handle_t& handle, // Precompute the centers vector norms for L2Expanded distance if (ext_index.center_norms().has_value()) { - // todo(lsugy): use other prim and remove this one - utils::dots_along_rows(n_lists, - dim, - ext_index.centers().data_handle(), - ext_index.center_norms()->data_handle(), - stream); - RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); + if (!ext_index.adaptive_centers() && orig_index.center_norms().has_value()) { + raft::copy(ext_index.center_norms()->data_handle(), + orig_index.center_norms()->data_handle(), + orig_index.center_norms()->size(), + stream); + } else { + // todo(lsugy): use other prim and remove this one + utils::dots_along_rows(n_lists, + dim, + ext_index.centers().data_handle(), + ext_index.center_norms()->data_handle(), + stream); + RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); + } } // assemble the index diff --git 
a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 9a430e14f2..c57e8c7548 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -53,6 +54,7 @@ struct AnnIvfFlatInputs { IdxT nprobe; IdxT nlist; raft::distance::DistanceType metric; + bool adaptive_centers; }; template @@ -198,6 +200,45 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); handle_.sync_stream(stream_); + + // Test the centroid invariants + if (index_2.adaptive_centers()) { + // The centers must be up-to-date with the corresponding data + std::vector list_sizes(index_2.n_lists()); + std::vector list_offsets(index_2.n_lists()); + rmm::device_uvector centroid(ps.dim, stream_); + raft::copy( + list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); + raft::copy( + list_offsets.data(), index_2.list_offsets().data_handle(), index_2.n_lists(), stream_); + handle_.sync_stream(stream_); + for (uint32_t l = 0; l < index_2.n_lists(); l++) { + rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); + raft::spatial::knn::detail::utils::copy_selected( + (IdxT)list_sizes[l], + (IdxT)ps.dim, + database.data(), + index_2.indices().data_handle() + list_offsets[l], + (IdxT)ps.dim, + cluster_data.data(), + (IdxT)ps.dim, + stream_); + raft::stats::mean( + centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); + ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle() + ps.dim * l, + centroid.data(), + ps.dim, + raft::CompareApprox(0.001), + stream_)); + } + } else { + // The centers must be immutable + ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(), + index.centers().data_handle(), + index_2.centers().size(), + 
raft::Compare(), + stream_)); + } } ASSERT_TRUE(eval_neighbours(indices_naive, indices_ivfflat, @@ -243,44 +284,44 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) - {1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 3, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 3, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, // test dims that do not fit into kernel shared memory limits - {1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2049, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2050, 16, 40, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 2051, 16, 40, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 2052, 16, 40, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 2053, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 2056, 16, 40, 1024, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2049, 16, 40, 1024, 
raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2050, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2051, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2052, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2053, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2056, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, // various random combinations - {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded}, - {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::L2Expanded}, - {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded}, - {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded}, - {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded}, - - {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::InnerProduct}, - {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::InnerProduct}, - {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct}, - {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct}, - {10000, 131072, 8, 10, 50, 1024, raft::distance::DistanceType::InnerProduct}, - - {1000, 10000, 4096, 20, 50, 1024, raft::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, 20, 512, 
raft::distance::DistanceType::L2Expanded, false}, + {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, + {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, + {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, + + {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::InnerProduct, false}, + {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::InnerProduct, true}, + {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct, true}, + {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct, false}, + {10000, 131072, 8, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true}, + + {1000, 10000, 4096, 20, 50, 1024, raft::distance::DistanceType::InnerProduct, false}, // test splitting the big query batches (> max gridDim.y) into smaller batches - {100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct}, - {98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct}, + {100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, false}, + {98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, true}, // test radix_sort for getting the cluster selection {1000, @@ -289,14 +330,16 @@ const std::vector> inputs = { 10, raft::spatial::knn::detail::topk::kMaxCapacity * 2, raft::spatial::knn::detail::topk::kMaxCapacity * 4, - raft::distance::DistanceType::L2Expanded}, + raft::distance::DistanceType::L2Expanded, + false}, {1000, 10000, 16, 10, raft::spatial::knn::detail::topk::kMaxCapacity * 4, raft::spatial::knn::detail::topk::kMaxCapacity * 4, - raft::distance::DistanceType::InnerProduct}}; + raft::distance::DistanceType::InnerProduct, + false}}; typedef AnnIVFFlatTest AnnIVFFlatTestF; 
TEST_P(AnnIVFFlatTestF, AnnIVFFlat) { this->testIVFFlat(); }