Skip to content

Commit

Permalink
IVF-Flat: make adaptive-centers behavior optional (#1019)
Browse files Browse the repository at this point in the history
Add an indexing parameter `adaptive_centers` (false by default) to control whether `index.centers()` should be kept up-to-date with the cluster data or remain immutable.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1019
  • Loading branch information
achirkin authored Nov 16, 2022
1 parent 464a11b commit c7e74bd
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 56 deletions.
29 changes: 27 additions & 2 deletions cpp/include/raft/neighbors/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/util/integer_utils.hpp>

#include <optional>
#include <type_traits>

namespace raft::neighbors::ivf_flat {

Expand All @@ -37,6 +38,19 @@ struct index_params : ann::index_params {
uint32_t kmeans_n_iters = 20;
/** The fraction of data to use during iterative kmeans building. */
double kmeans_trainset_fraction = 0.5;
/**
* By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`,
* and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index
* from scratch after invoking `ivf_flat::extend` a few times with new data, the distribution of
* which is no longer representative of the original training set.
*
* The alternative behavior (adaptive_centers = true) is to update the cluster centers for new
* data when it is added. In this case, `index.centers()` are always exactly the centroids of the
* data in the corresponding clusters. The drawback of this behavior is that the centroids depend
* on the order of adding new data (through the classification of the added data); that is,
* `index.centers()` "drift" together with the changing distribution of the newly added data.
*/
bool adaptive_centers = false;
};

struct search_params : ann::search_params {
Expand Down Expand Up @@ -72,6 +86,11 @@ struct index : ann::index {
{
return metric_;
}
/** Whether `centers()` change upon extending the index (ivf_flat::extend). */
[[nodiscard]] constexpr inline auto adaptive_centers() const noexcept -> bool
{
return adaptive_centers_;
}
/**
* Inverted list data [size, dim].
*
Expand Down Expand Up @@ -200,10 +219,15 @@ struct index : ann::index {
~index() = default;

/** Construct an empty index. It needs to be trained and then populated. */
index(const handle_t& handle, raft::distance::DistanceType metric, uint32_t n_lists, uint32_t dim)
index(const handle_t& handle,
raft::distance::DistanceType metric,
uint32_t n_lists,
bool adaptive_centers,
uint32_t dim)
: ann::index(),
veclen_(calculate_veclen(dim)),
metric_(metric),
adaptive_centers_(adaptive_centers),
data_(make_device_mdarray<T>(handle, make_extents<IdxT>(0, dim))),
indices_(make_device_mdarray<IdxT>(handle, make_extents<IdxT>(0))),
list_sizes_(make_device_mdarray<uint32_t>(handle, make_extents<uint32_t>(n_lists))),
Expand All @@ -216,7 +240,7 @@ struct index : ann::index {

/** Construct an empty index. It needs to be trained and then populated. */
index(const handle_t& handle, const index_params& params, uint32_t dim)
: index(handle, params.metric, params.n_lists, dim)
: index(handle, params.metric, params.n_lists, params.adaptive_centers, dim)
{
}

Expand All @@ -242,6 +266,7 @@ struct index : ann::index {
*/
uint32_t veclen_;
raft::distance::DistanceType metric_;
bool adaptive_centers_;
device_mdarray<T, extent_2d<IdxT>, row_major> data_;
device_mdarray<IdxT, extent_1d<IdxT>, row_major> indices_;
device_mdarray<uint32_t, extent_1d<uint32_t>, row_major> list_sizes_;
Expand Down
62 changes: 42 additions & 20 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/linalg/add.cuh>
#include <raft/stats/histogram.cuh>
#include <raft/util/pow2_utils.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -133,27 +135,40 @@ inline auto extend(const handle_t& handle,
orig_index.metric(),
stream);

index<T, IdxT> ext_index(handle, orig_index.metric(), n_lists, dim);
index<T, IdxT> ext_index(
handle, orig_index.metric(), n_lists, orig_index.adaptive_centers(), dim);

auto list_sizes_ptr = ext_index.list_sizes().data_handle();
auto list_offsets_ptr = ext_index.list_offsets().data_handle();
auto centers_ptr = ext_index.centers().data_handle();

// Calculate the centers and sizes on the new data, starting from the original values
raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream);
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);

kmeans::calc_centers_and_sizes(handle,
centers_ptr,
list_sizes_ptr,
n_lists,
dim,
new_vectors,
n_rows,
new_labels.data(),
false,
stream);
if (ext_index.adaptive_centers()) {
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);
kmeans::calc_centers_and_sizes(handle,
centers_ptr,
list_sizes_ptr,
n_lists,
dim,
new_vectors,
n_rows,
new_labels.data(),
false,
stream);
} else {
raft::stats::histogram<uint32_t, IdxT>(raft::stats::HistTypeAuto,
reinterpret_cast<int32_t*>(list_sizes_ptr),
IdxT(n_lists),
new_labels.data(),
n_rows,
1,
stream);
raft::linalg::add(
list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream);
}

// Calculate new offsets
IdxT index_size = 0;
Expand Down Expand Up @@ -210,13 +225,20 @@ inline auto extend(const handle_t& handle,

// Precompute the centers vector norms for L2Expanded distance
if (ext_index.center_norms().has_value()) {
// todo(lsugy): use other prim and remove this one
utils::dots_along_rows(n_lists,
dim,
ext_index.centers().data_handle(),
ext_index.center_norms()->data_handle(),
stream);
RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
if (!ext_index.adaptive_centers() && orig_index.center_norms().has_value()) {
raft::copy(ext_index.center_norms()->data_handle(),
orig_index.center_norms()->data_handle(),
orig_index.center_norms()->size(),
stream);
} else {
// todo(lsugy): use other prim and remove this one
utils::dots_along_rows(n_lists,
dim,
ext_index.centers().data_handle(),
ext_index.center_norms()->data_handle(),
stream);
RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}
}

// assemble the index
Expand Down
111 changes: 77 additions & 34 deletions cpp/test/neighbors/ann_ivf_flat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/spatial/knn/ann.cuh>
#include <raft/spatial/knn/ivf_flat.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/stats/mean.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
Expand Down Expand Up @@ -53,6 +54,7 @@ struct AnnIvfFlatInputs {
IdxT nprobe;
IdxT nlist;
raft::distance::DistanceType metric;
bool adaptive_centers;
};

template <typename T, typename DataT, typename IdxT>
Expand Down Expand Up @@ -198,6 +200,45 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_);
update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_);
handle_.sync_stream(stream_);

// Test the centroid invariants
if (index_2.adaptive_centers()) {
// The centers must be up-to-date with the corresponding data
std::vector<uint32_t> list_sizes(index_2.n_lists());
std::vector<IdxT> list_offsets(index_2.n_lists());
rmm::device_uvector<float> centroid(ps.dim, stream_);
raft::copy(
list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_);
raft::copy(
list_offsets.data(), index_2.list_offsets().data_handle(), index_2.n_lists(), stream_);
handle_.sync_stream(stream_);
for (uint32_t l = 0; l < index_2.n_lists(); l++) {
rmm::device_uvector<float> cluster_data(list_sizes[l] * ps.dim, stream_);
raft::spatial::knn::detail::utils::copy_selected<float>(
(IdxT)list_sizes[l],
(IdxT)ps.dim,
database.data(),
index_2.indices().data_handle() + list_offsets[l],
(IdxT)ps.dim,
cluster_data.data(),
(IdxT)ps.dim,
stream_);
raft::stats::mean<float, uint32_t>(
centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_);
ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle() + ps.dim * l,
centroid.data(),
ps.dim,
raft::CompareApprox<float>(0.001),
stream_));
}
} else {
// The centers must be immutable
ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(),
index.centers().data_handle(),
index_2.centers().size(),
raft::Compare<float>(),
stream_));
}
}
ASSERT_TRUE(eval_neighbours(indices_naive,
indices_ivfflat,
Expand Down Expand Up @@ -243,44 +284,44 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {

const std::vector<AnnIvfFlatInputs<int64_t>> inputs = {
// test various dims (aligned and not aligned to vector sizes)
{1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 2, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 3, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true},
{1000, 10000, 2, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 3, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true},
{1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false},
{1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true},

// test dims that do not fit into kernel shared memory limits
{1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 2049, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 2050, 16, 40, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 2051, 16, 40, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 2052, 16, 40, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 2053, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 2056, 16, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 2049, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 2050, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false},
{1000, 10000, 2051, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true},
{1000, 10000, 2052, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false},
{1000, 10000, 2053, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true},
{1000, 10000, 2056, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true},

// various random combinations
{1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded},
{100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::L2Expanded},
{20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded},
{1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded},
{10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded},

{1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::InnerProduct},
{100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::InnerProduct},
{20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct},
{10000, 131072, 8, 10, 50, 1024, raft::distance::DistanceType::InnerProduct},

{1000, 10000, 4096, 20, 50, 1024, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded, false},
{100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::L2Expanded, false},
{20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true},
{1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true},
{10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false},

{1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::InnerProduct, true},
{1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true},
{1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::InnerProduct, false},
{100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::InnerProduct, true},
{20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct, true},
{1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct, false},
{10000, 131072, 8, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true},

{1000, 10000, 4096, 20, 50, 1024, raft::distance::DistanceType::InnerProduct, false},

// test splitting the big query batches (> max gridDim.y) into smaller batches
{100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct},
{98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct},
{100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, false},
{98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, true},

// test radix_sort for getting the cluster selection
{1000,
Expand All @@ -289,14 +330,16 @@ const std::vector<AnnIvfFlatInputs<int64_t>> inputs = {
10,
raft::spatial::knn::detail::topk::kMaxCapacity * 2,
raft::spatial::knn::detail::topk::kMaxCapacity * 4,
raft::distance::DistanceType::L2Expanded},
raft::distance::DistanceType::L2Expanded,
false},
{1000,
10000,
16,
10,
raft::spatial::knn::detail::topk::kMaxCapacity * 4,
raft::spatial::knn::detail::topk::kMaxCapacity * 4,
raft::distance::DistanceType::InnerProduct}};
raft::distance::DistanceType::InnerProduct,
false}};

typedef AnnIVFFlatTest<float, float, std::int64_t> AnnIVFFlatTestF;
TEST_P(AnnIVFFlatTestF, AnnIVFFlat) { this->testIVFFlat(); }
Expand Down

0 comments on commit c7e74bd

Please sign in to comment.