From 10ac7a5ecf913796d5ec716c9fd0fab163e587bd Mon Sep 17 00:00:00 2001 From: Levi Tamasi Date: Mon, 30 Dec 2024 11:24:42 -0800 Subject: [PATCH] Support KNN search for FAISS IVF indices (#13258) Summary: Pull Request resolved: https://github.com/facebook/rocksdb/pull/13258 Differential Revision: D67684898 --- include/rocksdb/options.h | 18 ++ utilities/secondary_index/faiss_ivf_index.cc | 246 +++++++++++++++++- utilities/secondary_index/faiss_ivf_index.h | 1 + .../secondary_index/faiss_ivf_index_test.cc | 166 +++++++++++- 4 files changed, 418 insertions(+), 13 deletions(-) diff --git a/include/rocksdb/options.h b/include/rocksdb/options.h index f34e34e59ccc..00b1afdc6f8c 100644 --- a/include/rocksdb/options.h +++ b/include/rocksdb/options.h @@ -1960,6 +1960,24 @@ struct ReadOptions { // Default: false bool allow_unprepared_value = false; + // The maximum number of neighbors K to return when performing a + // K-nearest-neighbors vector similarity search. The number of neighbors + // returned can be smaller if there are not enough vectors in the inverted + // lists probed. Only applicable to FAISS IVF secondary indices, where it must + // be specified and positive. See also `SecondaryIndex::NewIterator` and + // `similarity_search_probes` below. + // + // Default: none + std::optional similarity_search_neighbors; + + // The number of inverted lists to probe when performing a K-nearest-neighbors + // vector similarity search. Only applicable to FAISS IVF secondary indices, + // where it must be specified and positive. See also + // `SecondaryIndex::NewIterator` and `similarity_search_neighbors` above. + // + // Default: none + std::optional similarity_search_probes; + // *** END options only relevant to iterators or scans *** // *** BEGIN options for RocksDB internal use only *** diff --git a/utilities/secondary_index/faiss_ivf_index.cc b/utilities/secondary_index/faiss_ivf_index.cc index 1f897c516978..27c390f64344 100644 --- a/utilities/secondary_index/faiss_ivf_index.cc +++ b/utilities/secondary_index/faiss_ivf_index.cc @@ -6,12 +6,162 @@ #include "utilities/secondary_index/faiss_ivf_index.h" #include +#include +#include #include "faiss/invlists/InvertedLists.h" +#include "util/autovector.h" #include "util/coding.h" +#include "utilities/secondary_index/secondary_index_iterator.h" namespace ROCKSDB_NAMESPACE { +class FaissIVFIndex::KNNIterator : public Iterator { + public: + KNNIterator(faiss::IndexIVF* index, + std::unique_ptr&& secondary_index_it, size_t k, + size_t probes) + : index_(index), + secondary_index_it_(std::move(secondary_index_it)), + k_(k), + probes_(probes), + distances_(k, 0.0f), + labels_(k, -1), + pos_(0) { + assert(index_); + assert(secondary_index_it_); + assert(k_ > 0); + assert(probes_ > 0); + } + + Iterator* GetSecondaryIndexIterator() const { + return secondary_index_it_.get(); + } + + faiss::idx_t AddKey(std::string&& key) { + keys_.emplace_back(std::move(key)); + + return static_cast(keys_.size()) - 1; + } + + bool Valid() const override { + assert(k_ > 0); + assert(labels_.size() == k_); + assert(distances_.size() == k_); + + return status_.ok() && pos_ >= 0 && pos_ < k_ && labels_[pos_] >= 0; + } + + void SeekToFirst() override { + status_ = + Status::NotSupported("SeekToFirst not supported for FaissIVFIndex"); + } + + void SeekToLast() override { + status_ = + Status::NotSupported("SeekToLast not supported for FaissIVFIndex"); + } + + void Seek(const Slice& target) override { + distances_.assign(k_, 0.0f); + labels_.assign(k_, -1); + status_ = Status::OK(); + pos_ = 0; + keys_.clear(); + + faiss::SearchParametersIVF params; + params.nprobe = probes_; + params.inverted_list_context = this; + + constexpr faiss::idx_t n = 1; + + try { + index_->search(n, reinterpret_cast(target.data()), k_, + distances_.data(), labels_.data(), ¶ms); + } catch (const std::exception& e) { + status_ = Status::InvalidArgument(e.what()); + } + } + + void SeekForPrev(const Slice& /* target */) override { + status_ = + Status::NotSupported("SeekForPrev not supported for FaissIVFIndex"); + } + + void Next() override { + assert(Valid()); + + ++pos_; + } + + void Prev() override { + assert(Valid()); + + --pos_; + } + + Status status() const override { return status_; } + + Slice key() const override { + assert(Valid()); + assert(labels_[pos_] >= 0); + assert(labels_[pos_] < keys_.size()); + + return keys_[labels_[pos_]]; + } + + Slice value() const override { + assert(Valid()); + + return Slice(); + } + + const WideColumns& columns() const override { + assert(Valid()); + + return kNoWideColumns; + } + + Slice timestamp() const override { + assert(Valid()); + + return Slice(); + } + + Status GetProperty(std::string prop_name, std::string* prop) override { + if (!prop) { + return Status::InvalidArgument("No property pointer provided"); + } + + if (!Valid()) { + return Status::InvalidArgument("Iterator is not valid"); + } + + if (prop_name == kPropertyName_) { + *prop = std::to_string(distances_[pos_]); + return Status::OK(); + } + + return Iterator::GetProperty(std::move(prop_name), prop); + } + + private: + faiss::IndexIVF* index_; + std::unique_ptr secondary_index_it_; + size_t k_; + size_t probes_; + std::vector distances_; + std::vector labels_; + Status status_; + faiss::idx_t pos_; + autovector keys_; + + static const std::string kPropertyName_; +}; + +const std::string FaissIVFIndex::KNNIterator::kPropertyName_ = + "rocksdb.faiss.ivf.index.distance"; + class FaissIVFIndex::Adapter : public faiss::InvertedLists { public: Adapter(size_t num_lists, size_t code_size) @@ -36,14 +186,13 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists { return nullptr; } - // Iterator-based read interface; not yet implemented + // Iterator-based read interface faiss::InvertedListsIterator* get_iterator( - size_t /* list_no */, - void* /* inverted_list_context */ = nullptr) const override { - // TODO: implement this + size_t list_no, void* inverted_list_context = nullptr) const override { + KNNIterator* const it = static_cast(inverted_list_context); + assert(it); - assert(false); - return nullptr; + return new IteratorAdapter(it, list_no, code_size); } // Write interface; only add_entry is implemented/required for now @@ -80,6 +229,67 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists { void resize(size_t /* list_no */, size_t /* new_size */) override { assert(false); } + + private: + class IteratorAdapter : public faiss::InvertedListsIterator { + public: + IteratorAdapter(KNNIterator* it, size_t list_no, size_t code_size) + : it_(it), + secondary_index_it_(it->GetSecondaryIndexIterator()), + label_(FaissIVFIndex::SerializeLabel(list_no)), + code_size_(code_size) { + assert(it_); + assert(secondary_index_it_); + + secondary_index_it_->Seek(label_); + Update(); + } + + bool is_available() const override { return id_and_codes_.has_value(); } + + void next() override { + secondary_index_it_->Next(); + Update(); + } + + std::pair get_id_and_codes() override { + assert(is_available()); + + return *id_and_codes_; + } + + private: + void Update() { + id_and_codes_.reset(); + + if (!secondary_index_it_->Valid()) { + return; + } + + if (!secondary_index_it_->PrepareValue()) { + throw std::runtime_error( + "Failed to prepare value during iteration in FaissIVFIndex"); + } + + const Slice& value = secondary_index_it_->value(); + if (value.size() != code_size_) { + throw std::runtime_error( + "Code with unexpected size encountered during iteration in " + "FaissIVFIndex"); + } + + const Slice key = secondary_index_it_->key(); + const faiss::idx_t id = it_->AddKey(key.ToString()); + + id_and_codes_.emplace(id, reinterpret_cast(value.data())); + } + + KNNIterator* it_; + Iterator* secondary_index_it_; + std::string label_; + size_t code_size_; + std::optional> id_and_codes_; + }; }; std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) { @@ -105,6 +315,7 @@ FaissIVFIndex::FaissIVFIndex(std::unique_ptr&& index, assert(index_); assert(index_->quantizer); + index_->parallel_mode = 0; index_->replace_invlists(adapter_.get()); } @@ -202,7 +413,7 @@ Status FaissIVFIndex::GetSecondaryValue( if (code_str.size() != index_->code_size) { return Status::InvalidArgument( - "Unexpected code returned by fine quantizer"); + "Code with unexpected size returned by fine quantizer"); } secondary_value->emplace(std::move(code_str)); @@ -211,10 +422,23 @@ Status FaissIVFIndex::GetSecondaryValue( } std::unique_ptr FaissIVFIndex::NewIterator( - const ReadOptions& /* read_options */, - Iterator* /* underlying_it */) const { - // TODO: implement this - return std::unique_ptr(NewErrorIterator(Status::NotSupported())); + const ReadOptions& read_options, Iterator* it) const { + if (!read_options.similarity_search_neighbors.has_value() || + *read_options.similarity_search_neighbors == 0) { + return std::unique_ptr(NewErrorIterator( + Status::InvalidArgument("Invalid number of neighbors"))); + } + + if (!read_options.similarity_search_probes.has_value() || + *read_options.similarity_search_probes == 0) { + return std::unique_ptr( + NewErrorIterator(Status::InvalidArgument("Invalid number of probes"))); + } + + return std::make_unique( + index_.get(), std::make_unique(this, it), + *read_options.similarity_search_neighbors, + *read_options.similarity_search_probes); } } // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index.h b/utilities/secondary_index/faiss_ivf_index.h index ab601552057f..ac8d2e953ac6 100644 --- a/utilities/secondary_index/faiss_ivf_index.h +++ b/utilities/secondary_index/faiss_ivf_index.h @@ -47,6 +47,7 @@ class FaissIVFIndex : public SecondaryIndex { Iterator* underlying_it) const override; private: + class KNNIterator; class Adapter; static std::string SerializeLabel(faiss::idx_t label); diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc index 5d2008a47a7c..c301f533d697 100644 --- a/utilities/secondary_index/faiss_ivf_index_test.cc +++ b/utilities/secondary_index/faiss_ivf_index_test.cc @@ -33,8 +33,6 @@ TEST(FaissIVFIndexTest, Basic) { index->train(num_vectors, embeddings.data()); - index->nprobe = 2; - const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test"); EXPECT_OK(DestroyDB(db_name, Options())); @@ -65,6 +63,8 @@ TEST(FaissIVFIndexTest, Basic) { secondary_index->SetPrimaryColumnFamily(cfh1); secondary_index->SetSecondaryColumnFamily(cfh2); + // Write the embeddings to the primary column family, indexing them in the + // process { std::unique_ptr txn(db->BeginTransaction(WriteOptions())); @@ -82,6 +82,7 @@ TEST(FaissIVFIndexTest, Basic) { ASSERT_OK(txn->Commit()); } + // Verify the raw index data in the secondary column family { size_t num_found = 0; @@ -113,6 +114,167 @@ TEST(FaissIVFIndexTest, Basic) { ASSERT_OK(it->status()); ASSERT_EQ(num_found, num_vectors); } + + // Query the index with some of the original embeddings + ReadOptions read_options; + read_options.similarity_search_neighbors = 8; + read_options.similarity_search_probes = num_lists; + + std::unique_ptr underlying_it(db->NewIterator(read_options, cfh2)); + + std::unique_ptr it = + txn_db_options.secondary_indices.back()->NewIterator(read_options, + underlying_it.get()); + + auto get_id = [&]() -> faiss::idx_t { + Slice key = it->key(); + faiss::idx_t id = -1; + + if (std::from_chars(key.data(), key.data() + key.size(), id).ec != + std::errc()) { + return -1; + } + + return id; + }; + + auto get_distance = [&]() -> float { + std::string distance_str; + float distance = 0.0f; + + if (!it->GetProperty("rocksdb.faiss.ivf.index.distance", &distance_str) + .ok()) { + return -1.0f; + } + + if (std::from_chars(distance_str.data(), + distance_str.data() + distance_str.size(), distance) + .ec != std::errc()) { + return -1.0f; + } + + return distance; + }; + + auto verify = [&](faiss::idx_t id) { + // Search for a vector from the original set; we expect to find the vector + // itself as the closest match, since we're performing an exhaustive search + { + it->Seek( + Slice(reinterpret_cast(embeddings.data() + id * dim), + dim * sizeof(float))); + ASSERT_TRUE(it->Valid()); + ASSERT_OK(it->status()); + ASSERT_EQ(get_id(), id); + ASSERT_TRUE(it->value().empty()); + ASSERT_TRUE(it->columns().empty()); + ASSERT_EQ(get_distance(), 0.0f); + } + + // Take a step forward then a step back to get back to our original position + { + it->Next(); + ASSERT_TRUE(it->Valid()); + ASSERT_OK(it->status()); + + it->Prev(); + ASSERT_TRUE(it->Valid()); + ASSERT_OK(it->status()); + ASSERT_EQ(get_id(), id); + ASSERT_TRUE(it->value().empty()); + ASSERT_TRUE(it->columns().empty()); + ASSERT_EQ(get_distance(), 0.0f); + } + + // Iterate over the rest of the results + float prev_distance = 0.0f; + size_t num_found = 1; + + for (it->Next(); it->Valid(); it->Next()) { + ASSERT_OK(it->status()); + + const faiss::idx_t id = get_id(); + ASSERT_GE(id, 0); + ASSERT_LT(id, num_vectors); + + ASSERT_TRUE(it->value().empty()); + ASSERT_TRUE(it->columns().empty()); + + const float distance = get_distance(); + ASSERT_GE(distance, prev_distance); + + prev_distance = distance; + ++num_found; + } + + ASSERT_OK(it->status()); + ASSERT_EQ(num_found, *read_options.similarity_search_neighbors); + }; + + verify(0); + verify(16); + verify(32); + verify(64); + + // Sanity check unsupported APIs + it->SeekToFirst(); + ASSERT_FALSE(it->Valid()); + ASSERT_NOK(it->status()); + + it->SeekToLast(); + ASSERT_FALSE(it->Valid()); + ASSERT_NOK(it->status()); + + it->SeekForPrev(Slice(reinterpret_cast(embeddings.data()), + dim * sizeof(float))); + ASSERT_FALSE(it->Valid()); + ASSERT_NOK(it->status()); + + { + ReadOptions bad_options; + bad_options.similarity_search_probes = 1; + + // similarity_search_neighbors not set + { + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, underlying_it.get()); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + + // similarity_search_neighbors set to zero + bad_options.similarity_search_neighbors = 0; + + { + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, underlying_it.get()); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + } + + { + ReadOptions bad_options; + bad_options.similarity_search_neighbors = 1; + + // similarity_search_probes not set + { + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, underlying_it.get()); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + + // similarity_search_probes set to zero + bad_options.similarity_search_probes = 0; + + { + std::unique_ptr bad_it = + txn_db_options.secondary_indices.back()->NewIterator( + bad_options, underlying_it.get()); + ASSERT_TRUE(bad_it->status().IsInvalidArgument()); + } + } } } // namespace ROCKSDB_NAMESPACE