diff --git a/BUCK b/BUCK index 7e12501725d..bdc7423b6b9 100644 --- a/BUCK +++ b/BUCK @@ -368,6 +368,11 @@ cpp_library_wrapper(name="rocksdb_lib", srcs=[ cpp_library_wrapper(name="rocksdb_whole_archive_lib", srcs=[], deps=[":rocksdb_lib"], headers=[], link_whole=True, extra_test_libs=False) +cpp_library_wrapper(name="rocksdb_with_faiss_lib", srcs=["utilities/secondary_index/faiss_ivf_index.cc"], deps=[ + "//faiss:faiss", + ":rocksdb_lib", + ], headers=[], link_whole=False, extra_test_libs=False) + cpp_library_wrapper(name="rocksdb_test_lib", srcs=[ "db/db_test_util.cc", "db/db_with_timestamp_test_util.cc", @@ -382,6 +387,20 @@ cpp_library_wrapper(name="rocksdb_test_lib", srcs=[ "utilities/cassandra/test_utils.cc", ], deps=[":rocksdb_lib"], headers=[], link_whole=False, extra_test_libs=True) +cpp_library_wrapper(name="rocksdb_with_faiss_test_lib", srcs=[ + "db/db_test_util.cc", + "db/db_with_timestamp_test_util.cc", + "table/mock_table.cc", + "test_util/mock_time_env.cc", + "test_util/secondary_cache_test_util.cc", + "test_util/testharness.cc", + "test_util/testutil.cc", + "tools/block_cache_analyzer/block_cache_trace_analyzer.cc", + "tools/trace_analyzer_tool.cc", + "utilities/agg_merge/test_agg_merge.cc", + "utilities/cassandra/test_utils.cc", + ], deps=[":rocksdb_with_faiss_lib"], headers=[], link_whole=False, extra_test_libs=True) + cpp_library_wrapper(name="rocksdb_tools_lib", srcs=[ "test_util/testutil.cc", "tools/block_cache_analyzer/block_cache_trace_analyzer.cc", @@ -5078,6 +5097,12 @@ cpp_unittest_wrapper(name="external_sst_file_test", extra_compiler_flags=[]) +cpp_unittest_wrapper(name="faiss_ivf_index_test", + srcs=["utilities/secondary_index/faiss_ivf_index_test.cc"], + deps=[":rocksdb_with_faiss_test_lib"], + extra_compiler_flags=[]) + + cpp_unittest_wrapper(name="fault_injection_test", srcs=["db/fault_injection_test.cc"], deps=[":rocksdb_test_lib"], diff --git a/Makefile b/Makefile index 0bedd667b37..e2dee455a4c 100644 --- a/Makefile +++ b/Makefile @@ -659,7 +659,7 @@ ifneq ($(filter check-headers, $(MAKECMDGOALS)),) # TODO: add/support JNI headers DEV_HEADER_DIRS := $(sort include/ $(dir $(ALL_SOURCES))) # Some headers like in port/ are platform-specific - DEV_HEADERS_TO_CHECK := $(shell $(FIND) $(DEV_HEADER_DIRS) -type f -name '*.h' | grep -E -v 'port/|plugin/|lua/|range_tree/') + DEV_HEADERS_TO_CHECK := $(shell $(FIND) $(DEV_HEADER_DIRS) -type f -name '*.h' | grep -E -v 'port/|plugin/|lua/|range_tree/|secondary_index/') PUBLIC_HEADERS_TO_CHECK := $(shell $(FIND) include/ -type f -name '*.h' | grep -E -v 'lua/') else DEV_HEADERS_TO_CHECK := diff --git a/buckifier/buckify_rocksdb.py b/buckifier/buckify_rocksdb.py index 92fcb8a7bb3..035254b5ad1 100755 --- a/buckifier/buckify_rocksdb.py +++ b/buckifier/buckify_rocksdb.py @@ -161,6 +161,15 @@ def generate_buck(repo_path, deps_map): extra_external_deps="", link_whole=True, ) + # rocksdb_with_faiss_lib + BUCK.add_library( + "rocksdb_with_faiss_lib", + src_mk.get("WITH_FAISS_LIB_SOURCES", []), + deps=[ + "//faiss:faiss", + ":rocksdb_lib", + ], + ) # rocksdb_test_lib BUCK.add_library( "rocksdb_test_lib", @@ -171,6 +180,18 @@ def generate_buck(repo_path, deps_map): [":rocksdb_lib"], extra_test_libs=True, ) + # rocksdb_with_faiss_test_lib + BUCK.add_library( + "rocksdb_with_faiss_test_lib", + src_mk.get("MOCK_LIB_SOURCES", []) + + src_mk.get("TEST_LIB_SOURCES", []) + + src_mk.get("EXP_LIB_SOURCES", []) + + src_mk.get("ANALYZER_LIB_SOURCES", []), + deps=[ + ":rocksdb_with_faiss_lib", + ], + extra_test_libs=True, + ) # rocksdb_tools_lib BUCK.add_library( "rocksdb_tools_lib", @@ -278,11 +299,16 @@ def generate_buck(repo_path, deps_map): for test_src in src_mk.get("TEST_MAIN_SOURCES", []): test = test_src.split(".c")[0].strip().split("/")[-1].strip() - test_source_map[test] = test_src + test_source_map[test] = (test_src, False) print("" + test + " " + test_src) + for test_src in src_mk.get("WITH_FAISS_TEST_MAIN_SOURCES", []): + test = test_src.split(".c")[0].strip().split("/")[-1].strip() + test_source_map[test] = (test_src, True) + print("" + test + " " + test_src + " [FAISS]") + for target_alias, deps in deps_map.items(): - for test, test_src in sorted(test_source_map.items()): + for test, (test_src, with_faiss) in sorted(test_source_map.items()): if len(test) == 0: print(ColorString.warning("Failed to get test name for %s" % test_src)) continue @@ -304,12 +330,20 @@ def generate_buck(repo_path, deps_map): extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), ) else: - BUCK.register_test( - test_target_name, - test_src, - deps=json.dumps(deps["extra_deps"] + [":rocksdb_test_lib"]), - extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), - ) + if with_faiss: + BUCK.register_test( + test_target_name, + test_src, + deps=json.dumps(deps["extra_deps"] + [":rocksdb_with_faiss_test_lib"]), + extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), + ) + else: + BUCK.register_test( + test_target_name, + test_src, + deps=json.dumps(deps["extra_deps"] + [":rocksdb_test_lib"]), + extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), + ) BUCK.export_file("tools/db_crashtest.py") print(ColorString.info("Generated BUCK Summary:")) diff --git a/src.mk b/src.mk index fbe9ba1ea8e..121a08e928a 100644 --- a/src.mk +++ b/src.mk @@ -341,6 +341,9 @@ LIB_SOURCES_ASM = LIB_SOURCES_C = endif +WITH_FAISS_LIB_SOURCES = \ + utilities/secondary_index/faiss_ivf_index.cc \ + RANGE_TREE_SOURCES =\ utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc \ utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc \ @@ -651,6 +654,9 @@ TEST_MAIN_SOURCES = \ TEST_MAIN_SOURCES_C = \ db/c_test.c \ +WITH_FAISS_TEST_MAIN_SOURCES = \ + utilities/secondary_index/faiss_ivf_index_test.cc \ + MICROBENCH_SOURCES = \ microbench/ribbon_bench.cc \ microbench/db_basic_bench.cc \ diff --git a/utilities/secondary_index/faiss_ivf_index.cc b/utilities/secondary_index/faiss_ivf_index.cc new file mode 100644 index 00000000000..c419b98a2d1 --- /dev/null +++ b/utilities/secondary_index/faiss_ivf_index.cc @@ -0,0 +1,214 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/secondary_index/faiss_ivf_index.h" + +#include + +#include "faiss/invlists/InvertedLists.h" +#include "util/coding.h" + +namespace ROCKSDB_NAMESPACE { + +class FaissIVFIndex::Adapter : public faiss::InvertedLists { + public: + Adapter(size_t num_lists, size_t code_size) + : faiss::InvertedLists(num_lists, code_size) { + use_iterator = true; + } + + // Non-iterator-based read interface; not implemented/used since use_iterator + // is true + size_t list_size(size_t /* list_no */) const override { + assert(false); + return 0; + } + + const uint8_t* get_codes(size_t /* list_no */) const override { + assert(false); + return nullptr; + } + + const faiss::idx_t* get_ids(size_t /* list_no */) const override { + assert(false); + return nullptr; + } + + // Iterator-based read interface; not yet implemented + faiss::InvertedListsIterator* get_iterator( + size_t /* list_no */, + void* /* inverted_list_context */ = nullptr) const override { + // TODO: implement this + + assert(false); + return nullptr; + } + + // Write interface; only add_entry is implemented/required for now + size_t add_entry(size_t /* list_no */, faiss::idx_t /* id */, + const uint8_t* code, + void* inverted_list_context = nullptr) override { + std::string* const code_str = + static_cast(inverted_list_context); + assert(code_str); + + code_str->assign(reinterpret_cast(code), code_size); + + return 0; + } + + size_t add_entries(size_t /* list_no */, size_t /* num_entries */, + const faiss::idx_t* /* ids */, + const uint8_t* /* code */) override { + assert(false); + return 0; + } + + void update_entry(size_t /* list_no */, size_t /* offset */, + faiss::idx_t /* id */, const uint8_t* /* code */) override { + assert(false); + } + + void update_entries(size_t /* list_no */, size_t /* offset */, + size_t /* num_entries */, const faiss::idx_t* /* ids */, + const uint8_t* /* code */) override { + assert(false); + } + + void resize(size_t /* list_no */, size_t /* new_size */) override { + assert(false); + } +}; + +std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) { + std::string label_str; + PutVarsignedint64(&label_str, label); + + return label_str; +} + +faiss::idx_t FaissIVFIndex::DeserializeLabel(Slice label_slice) { + faiss::idx_t label = -1; + [[maybe_unused]] const bool ok = GetVarsignedint64(&label_slice, &label); + assert(ok); + + return label; +} + +FaissIVFIndex::FaissIVFIndex(std::unique_ptr&& index, + std::string primary_column_name) + : adapter_(std::make_unique(index->nlist, index->code_size)), + index_(std::move(index)), + primary_column_name_(std::move(primary_column_name)) { + assert(index_); + assert(index_->quantizer); + + index_->replace_invlists(adapter_.get()); +} + +FaissIVFIndex::~FaissIVFIndex() = default; + +void FaissIVFIndex::SetPrimaryColumnFamily(ColumnFamilyHandle* column_family) { + assert(column_family); + primary_column_family_ = column_family; +} + +void FaissIVFIndex::SetSecondaryColumnFamily( + ColumnFamilyHandle* column_family) { + assert(column_family); + secondary_column_family_ = column_family; +} + +ColumnFamilyHandle* FaissIVFIndex::GetPrimaryColumnFamily() const { + return primary_column_family_; +} + +ColumnFamilyHandle* FaissIVFIndex::GetSecondaryColumnFamily() const { + return secondary_column_family_; +} + +Slice FaissIVFIndex::GetPrimaryColumnName() const { + return primary_column_name_; +} + +Status FaissIVFIndex::UpdatePrimaryColumnValue( + const Slice& /* primary_key */, const Slice& primary_column_value, + std::optional>* updated_column_value) + const { + assert(updated_column_value); + + if (primary_column_value.size() != index_->d * sizeof(float)) { + return Status::InvalidArgument( + "Incorrectly sized vector passed to FaissIVFIndex"); + } + + constexpr faiss::idx_t n = 1; + faiss::idx_t label = -1; + + try { + index_->quantizer->assign( + n, reinterpret_cast(primary_column_value.data()), &label); + } catch (const std::exception& e) { + return Status::InvalidArgument(e.what()); + } + + if (label < 0 || label >= index_->nlist) { + return Status::InvalidArgument( + "Unexpected label returned by coarse quantizer"); + } + + updated_column_value->emplace(SerializeLabel(label)); + + return Status::OK(); +} + +Status FaissIVFIndex::GetSecondaryKeyPrefix( + const Slice& /* primary_key */, const Slice& primary_column_value, + std::variant* secondary_key_prefix) const { + assert(secondary_key_prefix); + + [[maybe_unused]] const faiss::idx_t label = + DeserializeLabel(primary_column_value); + assert(label >= 0); + assert(label < index_->nlist); + + *secondary_key_prefix = primary_column_value; + + return Status::OK(); +} + +Status FaissIVFIndex::GetSecondaryValue( + const Slice& /* primary_key */, const Slice& primary_column_value, + const Slice& original_column_value, + std::optional>* secondary_value) const { + assert(secondary_value); + + const faiss::idx_t label = DeserializeLabel(primary_column_value); + assert(label >= 0); + assert(label < index_->nlist); + + constexpr faiss::idx_t n = 1; + constexpr faiss::idx_t* xids = nullptr; + std::string code_str; + + try { + index_->add_core( + n, reinterpret_cast(original_column_value.data()), xids, + &label, &code_str); + } catch (const std::exception& e) { + return Status::InvalidArgument(e.what()); + } + + if (code_str.size() != index_->code_size) { + return Status::InvalidArgument( + "Unexpected code returned by fine quantizer"); + } + + secondary_value->emplace(std::move(code_str)); + + return Status::OK(); +} + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index.h b/utilities/secondary_index/faiss_ivf_index.h new file mode 100644 index 00000000000..956dba7762e --- /dev/null +++ b/utilities/secondary_index/faiss_ivf_index.h @@ -0,0 +1,60 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include +#include + +#include "faiss/IndexIVF.h" +#include "rocksdb/utilities/secondary_index.h" + +namespace ROCKSDB_NAMESPACE { + +// A SecondaryIndex implementation that wraps a FAISS inverted file index. +class FaissIVFIndex : public SecondaryIndex { + public: + explicit FaissIVFIndex(std::unique_ptr&& index, + std::string primary_column_name); + ~FaissIVFIndex() override; + + void SetPrimaryColumnFamily(ColumnFamilyHandle* column_family) override; + void SetSecondaryColumnFamily(ColumnFamilyHandle* column_family) override; + + ColumnFamilyHandle* GetPrimaryColumnFamily() const override; + ColumnFamilyHandle* GetSecondaryColumnFamily() const override; + + Slice GetPrimaryColumnName() const override; + + Status UpdatePrimaryColumnValue( + const Slice& primary_key, const Slice& primary_column_value, + std::optional>* updated_column_value) + const override; + + Status GetSecondaryKeyPrefix( + const Slice& primary_key, const Slice& primary_column_value, + std::variant* secondary_key_prefix) const override; + + Status GetSecondaryValue(const Slice& primary_key, + const Slice& primary_column_value, + const Slice& original_column_value, + std::optional>* + secondary_value) const override; + + private: + class Adapter; + + static std::string SerializeLabel(faiss::idx_t label); + static faiss::idx_t DeserializeLabel(Slice label_slice); + + std::unique_ptr adapter_; + std::unique_ptr index_; + std::string primary_column_name_; + ColumnFamilyHandle* primary_column_family_{}; + ColumnFamilyHandle* secondary_column_family_{}; +}; + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc new file mode 100644 index 00000000000..5d2008a47a7 --- /dev/null +++ b/utilities/secondary_index/faiss_ivf_index_test.cc @@ -0,0 +1,124 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/secondary_index/faiss_ivf_index.h" + +#include +#include +#include +#include + +#include "faiss/IndexFlat.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/utils/random.h" +#include "rocksdb/utilities/transaction_db.h" +#include "test_util/testharness.h" +#include "util/coding.h" + +namespace ROCKSDB_NAMESPACE { + +TEST(FaissIVFIndexTest, Basic) { + constexpr size_t dim = 128; + auto quantizer = std::make_unique(dim); + + constexpr size_t num_lists = 16; + auto index = + std::make_unique(quantizer.get(), dim, num_lists); + + constexpr faiss::idx_t num_vectors = 1024; + std::vector embeddings(dim * num_vectors); + faiss::float_rand(embeddings.data(), dim * num_vectors, 42); + + index->train(num_vectors, embeddings.data()); + + index->nprobe = 2; + + const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test"); + EXPECT_OK(DestroyDB(db_name, Options())); + + Options options; + options.create_if_missing = true; + + TransactionDBOptions txn_db_options; + const std::string primary_column_name = "embedding"; + txn_db_options.secondary_indices.emplace_back( + std::make_shared(std::move(index), primary_column_name)); + + TransactionDB* db = nullptr; + ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db)); + + std::unique_ptr db_guard(db); + + ColumnFamilyOptions cf1_opts; + ColumnFamilyHandle* cfh1 = nullptr; + ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1)); + std::unique_ptr cfh1_guard(cfh1); + + ColumnFamilyOptions cf2_opts; + ColumnFamilyHandle* cfh2 = nullptr; + ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2)); + std::unique_ptr cfh2_guard(cfh2); + + const auto& secondary_index = txn_db_options.secondary_indices.back(); + secondary_index->SetPrimaryColumnFamily(cfh1); + secondary_index->SetSecondaryColumnFamily(cfh2); + + { + std::unique_ptr txn(db->BeginTransaction(WriteOptions())); + + for (faiss::idx_t i = 0; i < num_vectors; ++i) { + const std::string primary_key = std::to_string(i); + + ASSERT_OK(txn->PutEntity( + cfh1, primary_key, + WideColumns{ + {primary_column_name, + Slice(reinterpret_cast(embeddings.data() + i * dim), + dim * sizeof(float))}})); + } + + ASSERT_OK(txn->Commit()); + } + + { + size_t num_found = 0; + + std::unique_ptr it(db->NewIterator(ReadOptions(), cfh2)); + + for (it->SeekToFirst(); it->Valid(); it->Next()) { + Slice key = it->key(); + faiss::idx_t label = -1; + ASSERT_TRUE(GetVarsignedint64(&key, &label)); + ASSERT_GE(label, 0); + ASSERT_LT(label, num_lists); + + faiss::idx_t id = -1; + ASSERT_EQ(std::from_chars(key.data(), key.data() + key.size(), id).ec, + std::errc()); + ASSERT_GE(id, 0); + ASSERT_LT(id, num_vectors); + + // Since we use IndexIVFFlat, there is no fine quantization, so the code + // is actually just the original embedding + ASSERT_EQ( + it->value(), + Slice(reinterpret_cast(embeddings.data() + id * dim), + dim * sizeof(float))); + + ++num_found; + } + + ASSERT_OK(it->status()); + ASSERT_EQ(num_found, num_vectors); + } +} + +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}