Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-Flat: make adaptive-centers behavior optional #1019

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions cpp/include/raft/neighbors/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/util/integer_utils.hpp>

#include <optional>
#include <type_traits>

namespace raft::neighbors::ivf_flat {

Expand All @@ -37,6 +38,19 @@ struct index_params : ann::index_params {
uint32_t kmeans_n_iters = 20;
/** The fraction of data to use during iterative kmeans building. */
double kmeans_trainset_fraction = 0.5;
/**
* By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`,
* and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index
* from scratch after adding (`ivf_flat::extend`) a few times new data, distribution of which
achirkin marked this conversation as resolved.
Show resolved Hide resolved
* is not well-represented by the original training set.
achirkin marked this conversation as resolved.
Show resolved Hide resolved
*
* The alternative behavior (adaptive_centers = true) is to update cluster centers every time
achirkin marked this conversation as resolved.
Show resolved Hide resolved
* the new data is added. In this case, `index.centers()` are always exactly the centroids of
achirkin marked this conversation as resolved.
Show resolved Hide resolved
* the data in the corresponding clusters. The drawback of this behavior is that the centroids
* depend on the order of adding new data (through the classification of the added data); that is,
* `index.centers()` "drift" together with the changing distribution of the newly added data.
*/
bool adaptive_centers = false;
};

struct search_params : ann::search_params {
Expand Down Expand Up @@ -72,6 +86,11 @@ struct index : ann::index {
{
return metric_;
}
/** Whethe `centers()` change upon extending the index (ivf_pq::extend). */
[[nodiscard]] constexpr inline auto adaptive_centers() const noexcept -> bool
{
return adaptive_centers_;
}
/**
* Inverted list data [size, dim].
*
Expand Down Expand Up @@ -200,10 +219,15 @@ struct index : ann::index {
~index() = default;

/** Construct an empty index. It needs to be trained and then populated. */
index(const handle_t& handle, raft::distance::DistanceType metric, uint32_t n_lists, uint32_t dim)
index(const handle_t& handle,
raft::distance::DistanceType metric,
uint32_t n_lists,
bool adaptive_centers,
Copy link
Member

@cjnolet cjnolet Nov 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than making this an option on the index params object itself, why not just make it an optional argument on the extend() function? That would enable the user to determine whether it should be done outside of the creation of the index. It's really not an index option, but a design detail based on a specific use-case / usage-pattern.

I was thinking of something like extend(....., update_centers=false);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I see it sounds more intuitive to keep the parameter in the extend, I'd disagree for two reasons:

  1. The update-centers routine does not go through the whole data (which would be costly), but only aggregates newly added data. For this to work, the centers() must always be up-to-date. Hence, we cannot allow doing an "updating" extend after a "non-updating" one. Also from the logical point of view, this would break both alternative invariants (centers are neither constant nor up-to-date).
  2. I hope to preserve the public API the same for all our (ANN) models: build gets all parameters through the struct, extend never accepts extra parameters (and gets them from the index). Besides, the build function would need to take update_centers as well, because it optionally calls extend.

uint32_t dim)
: ann::index(),
veclen_(calculate_veclen(dim)),
metric_(metric),
adaptive_centers_(adaptive_centers),
data_(make_device_mdarray<T>(handle, make_extents<IdxT>(0, dim))),
indices_(make_device_mdarray<IdxT>(handle, make_extents<IdxT>(0))),
list_sizes_(make_device_mdarray<uint32_t>(handle, make_extents<uint32_t>(n_lists))),
Expand All @@ -216,7 +240,7 @@ struct index : ann::index {

/** Construct an empty index. It needs to be trained and then populated. */
index(const handle_t& handle, const index_params& params, uint32_t dim)
: index(handle, params.metric, params.n_lists, dim)
: index(handle, params.metric, params.n_lists, params.adaptive_centers, dim)
{
}

Expand All @@ -242,6 +266,7 @@ struct index : ann::index {
*/
uint32_t veclen_;
raft::distance::DistanceType metric_;
bool adaptive_centers_;
device_mdarray<T, extent_2d<IdxT>, row_major> data_;
device_mdarray<IdxT, extent_1d<IdxT>, row_major> indices_;
device_mdarray<uint32_t, extent_1d<uint32_t>, row_major> list_sizes_;
Expand Down
62 changes: 42 additions & 20 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/linalg/add.cuh>
#include <raft/stats/histogram.cuh>
#include <raft/util/pow2_utils.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -133,27 +135,40 @@ inline auto extend(const handle_t& handle,
orig_index.metric(),
stream);

index<T, IdxT> ext_index(handle, orig_index.metric(), n_lists, dim);
index<T, IdxT> ext_index(
handle, orig_index.metric(), n_lists, orig_index.adaptive_centers(), dim);

auto list_sizes_ptr = ext_index.list_sizes().data_handle();
auto list_offsets_ptr = ext_index.list_offsets().data_handle();
auto centers_ptr = ext_index.centers().data_handle();

// Calculate the centers and sizes on the new data, starting from the original values
raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream);
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);

kmeans::calc_centers_and_sizes(handle,
centers_ptr,
list_sizes_ptr,
n_lists,
dim,
new_vectors,
n_rows,
new_labels.data(),
false,
stream);
if (ext_index.adaptive_centers()) {
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);
kmeans::calc_centers_and_sizes(handle,
centers_ptr,
list_sizes_ptr,
n_lists,
dim,
new_vectors,
n_rows,
new_labels.data(),
false,
stream);
} else {
raft::stats::histogram<uint32_t, IdxT>(raft::stats::HistTypeAuto,
reinterpret_cast<int32_t*>(list_sizes_ptr),
IdxT(n_lists),
new_labels.data(),
n_rows,
1,
stream);
raft::linalg::add(
list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream);
}

// Calculate new offsets
IdxT index_size = 0;
Expand Down Expand Up @@ -210,13 +225,20 @@ inline auto extend(const handle_t& handle,

// Precompute the centers vector norms for L2Expanded distance
if (ext_index.center_norms().has_value()) {
// todo(lsugy): use other prim and remove this one
utils::dots_along_rows(n_lists,
dim,
ext_index.centers().data_handle(),
ext_index.center_norms()->data_handle(),
stream);
RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
if (!ext_index.adaptive_centers() && orig_index.center_norms().has_value()) {
raft::copy(ext_index.center_norms()->data_handle(),
orig_index.center_norms()->data_handle(),
orig_index.center_norms()->size(),
stream);
} else {
// todo(lsugy): use other prim and remove this one
utils::dots_along_rows(n_lists,
dim,
ext_index.centers().data_handle(),
ext_index.center_norms()->data_handle(),
stream);
RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}
}

// assemble the index
Expand Down