Skip to content

Commit

Permalink
Write-side support for FAISS IVF indices (#13197)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #13197

The patch adds initial support for backing FAISS's inverted file based indices with data stored in RocksDB. It introduces a `SecondaryIndex` implementation called `FaissIVFIndex` which takes ownership of a `faiss::IndexIVF` object. During indexing, `FaissIVFIndex` treats the original value of the specified primary column as an embedding vector, and passes it to the provided FAISS index object to perform quantization. It replaces the original embedding vector with the result of the coarse quantizer (i.e. the inverted list id), and puts the result of the fine quantizer (if any) into the secondary index value. Note that this patch is only one half of the equation; it provides a way of storing FAISS inverted lists in RocksDB but there is currently no retrieval/search support (this will be a follow-up change). Also, the integration currently works only with our internal Buck build. I plan to add support for `cmake` / `make` based builds similarly to how we handle Folly.

Reviewed By: jowlyzhang

Differential Revision: D66907065

fbshipit-source-id: 63fdf29895d5feeffc230254a7ddfb0aac050967
  • Loading branch information
ltamasi authored and facebook-github-bot committed Dec 10, 2024
1 parent 5aead7a commit b339d08
Show file tree
Hide file tree
Showing 7 changed files with 472 additions and 9 deletions.
25 changes: 25 additions & 0 deletions BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,11 @@ cpp_library_wrapper(name="rocksdb_lib", srcs=[

cpp_library_wrapper(name="rocksdb_whole_archive_lib", srcs=[], deps=[":rocksdb_lib"], headers=[], link_whole=True, extra_test_libs=False)

cpp_library_wrapper(name="rocksdb_with_faiss_lib", srcs=["utilities/secondary_index/faiss_ivf_index.cc"], deps=[
"//faiss:faiss",
":rocksdb_lib",
], headers=[], link_whole=False, extra_test_libs=False)

cpp_library_wrapper(name="rocksdb_test_lib", srcs=[
"db/db_test_util.cc",
"db/db_with_timestamp_test_util.cc",
Expand All @@ -382,6 +387,20 @@ cpp_library_wrapper(name="rocksdb_test_lib", srcs=[
"utilities/cassandra/test_utils.cc",
], deps=[":rocksdb_lib"], headers=[], link_whole=False, extra_test_libs=True)

cpp_library_wrapper(name="rocksdb_with_faiss_test_lib", srcs=[
"db/db_test_util.cc",
"db/db_with_timestamp_test_util.cc",
"table/mock_table.cc",
"test_util/mock_time_env.cc",
"test_util/secondary_cache_test_util.cc",
"test_util/testharness.cc",
"test_util/testutil.cc",
"tools/block_cache_analyzer/block_cache_trace_analyzer.cc",
"tools/trace_analyzer_tool.cc",
"utilities/agg_merge/test_agg_merge.cc",
"utilities/cassandra/test_utils.cc",
], deps=[":rocksdb_with_faiss_lib"], headers=[], link_whole=False, extra_test_libs=True)

cpp_library_wrapper(name="rocksdb_tools_lib", srcs=[
"test_util/testutil.cc",
"tools/block_cache_analyzer/block_cache_trace_analyzer.cc",
Expand Down Expand Up @@ -5078,6 +5097,12 @@ cpp_unittest_wrapper(name="external_sst_file_test",
extra_compiler_flags=[])


cpp_unittest_wrapper(name="faiss_ivf_index_test",
srcs=["utilities/secondary_index/faiss_ivf_index_test.cc"],
deps=[":rocksdb_with_faiss_test_lib"],
extra_compiler_flags=[])


cpp_unittest_wrapper(name="fault_injection_test",
srcs=["db/fault_injection_test.cc"],
deps=[":rocksdb_test_lib"],
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ ifneq ($(filter check-headers, $(MAKECMDGOALS)),)
# TODO: add/support JNI headers
DEV_HEADER_DIRS := $(sort include/ $(dir $(ALL_SOURCES)))
# Some headers like in port/ are platform-specific
DEV_HEADERS_TO_CHECK := $(shell $(FIND) $(DEV_HEADER_DIRS) -type f -name '*.h' | grep -E -v 'port/|plugin/|lua/|range_tree/')
DEV_HEADERS_TO_CHECK := $(shell $(FIND) $(DEV_HEADER_DIRS) -type f -name '*.h' | grep -E -v 'port/|plugin/|lua/|range_tree/|secondary_index/')
PUBLIC_HEADERS_TO_CHECK := $(shell $(FIND) include/ -type f -name '*.h' | grep -E -v 'lua/')
else
DEV_HEADERS_TO_CHECK :=
Expand Down
50 changes: 42 additions & 8 deletions buckifier/buckify_rocksdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def generate_buck(repo_path, deps_map):
extra_external_deps="",
link_whole=True,
)
# rocksdb_with_faiss_lib
BUCK.add_library(
"rocksdb_with_faiss_lib",
src_mk.get("WITH_FAISS_LIB_SOURCES", []),
deps=[
"//faiss:faiss",
":rocksdb_lib",
],
)
# rocksdb_test_lib
BUCK.add_library(
"rocksdb_test_lib",
Expand All @@ -171,6 +180,18 @@ def generate_buck(repo_path, deps_map):
[":rocksdb_lib"],
extra_test_libs=True,
)
# rocksdb_with_faiss_test_lib
BUCK.add_library(
"rocksdb_with_faiss_test_lib",
src_mk.get("MOCK_LIB_SOURCES", [])
+ src_mk.get("TEST_LIB_SOURCES", [])
+ src_mk.get("EXP_LIB_SOURCES", [])
+ src_mk.get("ANALYZER_LIB_SOURCES", []),
deps=[
":rocksdb_with_faiss_lib",
],
extra_test_libs=True,
)
# rocksdb_tools_lib
BUCK.add_library(
"rocksdb_tools_lib",
Expand Down Expand Up @@ -278,11 +299,16 @@ def generate_buck(repo_path, deps_map):

for test_src in src_mk.get("TEST_MAIN_SOURCES", []):
test = test_src.split(".c")[0].strip().split("/")[-1].strip()
test_source_map[test] = test_src
test_source_map[test] = (test_src, False)
print("" + test + " " + test_src)

for test_src in src_mk.get("WITH_FAISS_TEST_MAIN_SOURCES", []):
test = test_src.split(".c")[0].strip().split("/")[-1].strip()
test_source_map[test] = (test_src, True)
print("" + test + " " + test_src + " [FAISS]")

for target_alias, deps in deps_map.items():
for test, test_src in sorted(test_source_map.items()):
for test, (test_src, with_faiss) in sorted(test_source_map.items()):
if len(test) == 0:
print(ColorString.warning("Failed to get test name for %s" % test_src))
continue
Expand All @@ -304,12 +330,20 @@ def generate_buck(repo_path, deps_map):
extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]),
)
else:
BUCK.register_test(
test_target_name,
test_src,
deps=json.dumps(deps["extra_deps"] + [":rocksdb_test_lib"]),
extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]),
)
if with_faiss:
BUCK.register_test(
test_target_name,
test_src,
deps=json.dumps(deps["extra_deps"] + [":rocksdb_with_faiss_test_lib"]),
extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]),
)
else:
BUCK.register_test(
test_target_name,
test_src,
deps=json.dumps(deps["extra_deps"] + [":rocksdb_test_lib"]),
extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]),
)
BUCK.export_file("tools/db_crashtest.py")

print(ColorString.info("Generated BUCK Summary:"))
Expand Down
6 changes: 6 additions & 0 deletions src.mk
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ LIB_SOURCES_ASM =
LIB_SOURCES_C =
endif

WITH_FAISS_LIB_SOURCES = \
utilities/secondary_index/faiss_ivf_index.cc \

RANGE_TREE_SOURCES =\
utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc \
utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc \
Expand Down Expand Up @@ -651,6 +654,9 @@ TEST_MAIN_SOURCES = \
TEST_MAIN_SOURCES_C = \
db/c_test.c \

WITH_FAISS_TEST_MAIN_SOURCES = \
utilities/secondary_index/faiss_ivf_index_test.cc \

MICROBENCH_SOURCES = \
microbench/ribbon_bench.cc \
microbench/db_basic_bench.cc \
Expand Down
214 changes: 214 additions & 0 deletions utilities/secondary_index/faiss_ivf_index.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).

#include "utilities/secondary_index/faiss_ivf_index.h"

#include <cassert>

#include "faiss/invlists/InvertedLists.h"
#include "util/coding.h"

namespace ROCKSDB_NAMESPACE {

class FaissIVFIndex::Adapter : public faiss::InvertedLists {
public:
Adapter(size_t num_lists, size_t code_size)
: faiss::InvertedLists(num_lists, code_size) {
use_iterator = true;
}

// Non-iterator-based read interface; not implemented/used since use_iterator
// is true
size_t list_size(size_t /* list_no */) const override {
assert(false);
return 0;
}

const uint8_t* get_codes(size_t /* list_no */) const override {
assert(false);
return nullptr;
}

const faiss::idx_t* get_ids(size_t /* list_no */) const override {
assert(false);
return nullptr;
}

// Iterator-based read interface; not yet implemented
faiss::InvertedListsIterator* get_iterator(
size_t /* list_no */,
void* /* inverted_list_context */ = nullptr) const override {
// TODO: implement this

assert(false);
return nullptr;
}

// Write interface; only add_entry is implemented/required for now
size_t add_entry(size_t /* list_no */, faiss::idx_t /* id */,
const uint8_t* code,
void* inverted_list_context = nullptr) override {
std::string* const code_str =
static_cast<std::string*>(inverted_list_context);
assert(code_str);

code_str->assign(reinterpret_cast<const char*>(code), code_size);

return 0;
}

size_t add_entries(size_t /* list_no */, size_t /* num_entries */,
const faiss::idx_t* /* ids */,
const uint8_t* /* code */) override {
assert(false);
return 0;
}

void update_entry(size_t /* list_no */, size_t /* offset */,
faiss::idx_t /* id */, const uint8_t* /* code */) override {
assert(false);
}

void update_entries(size_t /* list_no */, size_t /* offset */,
size_t /* num_entries */, const faiss::idx_t* /* ids */,
const uint8_t* /* code */) override {
assert(false);
}

void resize(size_t /* list_no */, size_t /* new_size */) override {
assert(false);
}
};

std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) {
std::string label_str;
PutVarsignedint64(&label_str, label);

return label_str;
}

faiss::idx_t FaissIVFIndex::DeserializeLabel(Slice label_slice) {
faiss::idx_t label = -1;
[[maybe_unused]] const bool ok = GetVarsignedint64(&label_slice, &label);
assert(ok);

return label;
}

FaissIVFIndex::FaissIVFIndex(std::unique_ptr<faiss::IndexIVF>&& index,
std::string primary_column_name)
: adapter_(std::make_unique<Adapter>(index->nlist, index->code_size)),
index_(std::move(index)),
primary_column_name_(std::move(primary_column_name)) {
assert(index_);
assert(index_->quantizer);

index_->replace_invlists(adapter_.get());
}

FaissIVFIndex::~FaissIVFIndex() = default;

void FaissIVFIndex::SetPrimaryColumnFamily(ColumnFamilyHandle* column_family) {
assert(column_family);
primary_column_family_ = column_family;
}

void FaissIVFIndex::SetSecondaryColumnFamily(
ColumnFamilyHandle* column_family) {
assert(column_family);
secondary_column_family_ = column_family;
}

ColumnFamilyHandle* FaissIVFIndex::GetPrimaryColumnFamily() const {
return primary_column_family_;
}

ColumnFamilyHandle* FaissIVFIndex::GetSecondaryColumnFamily() const {
return secondary_column_family_;
}

Slice FaissIVFIndex::GetPrimaryColumnName() const {
return primary_column_name_;
}

Status FaissIVFIndex::UpdatePrimaryColumnValue(
const Slice& /* primary_key */, const Slice& primary_column_value,
std::optional<std::variant<Slice, std::string>>* updated_column_value)
const {
assert(updated_column_value);

if (primary_column_value.size() != index_->d * sizeof(float)) {
return Status::InvalidArgument(
"Incorrectly sized vector passed to FaissIVFIndex");
}

constexpr faiss::idx_t n = 1;
faiss::idx_t label = -1;

try {
index_->quantizer->assign(
n, reinterpret_cast<const float*>(primary_column_value.data()), &label);
} catch (const std::exception& e) {
return Status::InvalidArgument(e.what());
}

if (label < 0 || label >= index_->nlist) {
return Status::InvalidArgument(
"Unexpected label returned by coarse quantizer");
}

updated_column_value->emplace(SerializeLabel(label));

return Status::OK();
}

Status FaissIVFIndex::GetSecondaryKeyPrefix(
const Slice& /* primary_key */, const Slice& primary_column_value,
std::variant<Slice, std::string>* secondary_key_prefix) const {
assert(secondary_key_prefix);

[[maybe_unused]] const faiss::idx_t label =
DeserializeLabel(primary_column_value);
assert(label >= 0);
assert(label < index_->nlist);

*secondary_key_prefix = primary_column_value;

return Status::OK();
}

Status FaissIVFIndex::GetSecondaryValue(
const Slice& /* primary_key */, const Slice& primary_column_value,
const Slice& original_column_value,
std::optional<std::variant<Slice, std::string>>* secondary_value) const {
assert(secondary_value);

const faiss::idx_t label = DeserializeLabel(primary_column_value);
assert(label >= 0);
assert(label < index_->nlist);

constexpr faiss::idx_t n = 1;
constexpr faiss::idx_t* xids = nullptr;
std::string code_str;

try {
index_->add_core(
n, reinterpret_cast<const float*>(original_column_value.data()), xids,
&label, &code_str);
} catch (const std::exception& e) {
return Status::InvalidArgument(e.what());
}

if (code_str.size() != index_->code_size) {
return Status::InvalidArgument(
"Unexpected code returned by fine quantizer");
}

secondary_value->emplace(std::move(code_str));

return Status::OK();
}

} // namespace ROCKSDB_NAMESPACE
Loading

0 comments on commit b339d08

Please sign in to comment.