Skip to content

Commit

Permalink
add skip_storage flag to HNSW (facebookresearch#3487)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3487

Sometimes it is not useful to serialize the storage index along with an HNSW index. This diff adds a flag that skips writing the storage index when the HNSW index is serialized.

Searching and adding to the index are not possible until a storage index is added back in.

Reviewed By: junjieqi

Differential Revision: D57911060

fbshipit-source-id: 5a4ceee4a8f53f6f746df59af3942b813a99c14f
  • Loading branch information
mdouze authored and facebook-github-bot committed May 31, 2024
1 parent 2230434 commit bf73e38
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 23 deletions.
5 changes: 2 additions & 3 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/IndexHNSW.h>

#include <omp.h>
Expand Down Expand Up @@ -251,7 +249,8 @@ void hnsw_search(
const SearchParameters* params_in) {
FAISS_THROW_IF_NOT_MSG(
index->storage,
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
"No storage index, please use IndexHNSWFlat (or variants) "
"instead of IndexHNSW directly");
const SearchParametersHNSW* params = nullptr;
const HNSW& hnsw = index->hnsw;

Expand Down
10 changes: 6 additions & 4 deletions faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/index_io.h>

#include <faiss/impl/io_macros.h>
Expand Down Expand Up @@ -531,7 +529,11 @@ Index* read_index(IOReader* f, int io_flags) {
Index* idx = nullptr;
uint32_t h;
READ1(h);
if (h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) {
if (h == fourcc("null")) {
// denotes a missing index, useful for some cases
return nullptr;
} else if (
h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) {
IndexFlat* idxf;
if (h == fourcc("IxFI")) {
idxf = new IndexFlatIP();
Expand Down Expand Up @@ -961,7 +963,7 @@ Index* read_index(IOReader* f, int io_flags) {
read_index_header(idxhnsw, f);
read_HNSW(&idxhnsw->hnsw, f);
idxhnsw->storage = read_index(f, io_flags);
idxhnsw->own_fields = true;
idxhnsw->own_fields = idxhnsw->storage != nullptr;
if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table();
}
Expand Down
25 changes: 16 additions & 9 deletions faiss/impl/index_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/index_io.h>

#include <faiss/impl/io.h>
Expand Down Expand Up @@ -390,8 +388,12 @@ static void write_ivf_header(const IndexIVF* ivf, IOWriter* f) {
write_direct_map(&ivf->direct_map, f);
}

void write_index(const Index* idx, IOWriter* f) {
if (const IndexFlat* idxf = dynamic_cast<const IndexFlat*>(idx)) {
void write_index(const Index* idx, IOWriter* f, int io_flags) {
if (idx == nullptr) {
// eg. for a storage component of HNSW that is set to nullptr
uint32_t h = fourcc("null");
WRITE1(h);
} else if (const IndexFlat* idxf = dynamic_cast<const IndexFlat*>(idx)) {
uint32_t h =
fourcc(idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI"
: idxf->metric_type == METRIC_L2 ? "IxF2"
Expand Down Expand Up @@ -765,7 +767,12 @@ void write_index(const Index* idx, IOWriter* f) {
WRITE1(h);
write_index_header(idxhnsw, f);
write_HNSW(&idxhnsw->hnsw, f);
write_index(idxhnsw->storage, f);
if (io_flags & IO_FLAG_SKIP_STORAGE) {
uint32_t n4 = fourcc("null");
WRITE1(n4);
} else {
write_index(idxhnsw->storage, f);
}
} else if (const IndexNSG* idxnsg = dynamic_cast<const IndexNSG*>(idx)) {
uint32_t h = dynamic_cast<const IndexNSGFlat*>(idx) ? fourcc("INSf")
: dynamic_cast<const IndexNSGPQ*>(idx) ? fourcc("INSp")
Expand Down Expand Up @@ -841,14 +848,14 @@ void write_index(const Index* idx, IOWriter* f) {
}
}

void write_index(const Index* idx, FILE* f) {
void write_index(const Index* idx, FILE* f, int io_flags) {
FileIOWriter writer(f);
write_index(idx, &writer);
write_index(idx, &writer, io_flags);
}

void write_index(const Index* idx, const char* fname) {
void write_index(const Index* idx, const char* fname, int io_flags) {
FileIOWriter writer(fname);
write_index(idx, &writer);
write_index(idx, &writer, io_flags);
}

void write_VectorTransform(const VectorTransform* vt, const char* fname) {
Expand Down
11 changes: 6 additions & 5 deletions faiss/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

// I/O code for indexes

#ifndef FAISS_INDEX_IO_H
Expand Down Expand Up @@ -35,9 +33,12 @@ struct IOReader;
struct IOWriter;
struct InvertedLists;

void write_index(const Index* idx, const char* fname);
void write_index(const Index* idx, FILE* f);
void write_index(const Index* idx, IOWriter* writer);
/// skip the storage for graph-based indexes
const int IO_FLAG_SKIP_STORAGE = 1;

void write_index(const Index* idx, const char* fname, int io_flags = 0);
void write_index(const Index* idx, FILE* f, int io_flags = 0);
void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);

void write_index_binary(const IndexBinary* idx, const char* fname);
void write_index_binary(const IndexBinary* idx, FILE* f);
Expand Down
4 changes: 2 additions & 2 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,10 @@ def range_search_with_parameters(index, x, radius, params=None, output_stats=Fal
###########################################


def serialize_index(index):
def serialize_index(index, io_flags=0):
""" convert an index to a numpy uint8 array """
writer = VectorIOWriter()
write_index(index, writer)
write_index(index, writer, io_flags)
return vector_to_array(writer.data)


Expand Down
36 changes: 36 additions & 0 deletions tests/test_graph_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,42 @@ def test_ndis_stats(self):
Dhnsw, Ihnsw = index.search(self.xq, 1)
self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch)

def test_io_no_storage(self):
    """Check serialization of an HNSW index with IO_FLAG_SKIP_STORAGE.

    Builds an IndexHNSWFlat, serializes it with the skip-storage flag and
    verifies that the deserialized index has no storage and refuses to
    search. Also verifies that an index with empty storage can itself be
    serialized again, and that search results are unchanged once the
    storage of the reference index has been replaced by a clone.
    """
    d = self.xq.shape[1]
    index = faiss.IndexHNSWFlat(d, 16)
    index.add(self.xb)

    # reference results from the complete index
    Dref, Iref = index.search(self.xq, 5)

    # test writing without storage
    index2 = faiss.deserialize_index(
        faiss.serialize_index(index, faiss.IO_FLAG_SKIP_STORAGE)
    )
    # assertEqual rather than the assertEquals alias, which was removed
    # from unittest in Python 3.12
    self.assertEqual(index2.storage, None)
    self.assertRaises(
        RuntimeError,
        index2.search, self.xb, 1)

    # make sure we can store an index with empty storage
    # (smoke check only: the round trip must not raise)
    index4 = faiss.deserialize_index(
        faiss.serialize_index(index2))

    # add storage afterwards
    # NOTE(review): this re-attaches storage to `index`, not to the
    # storage-less `index2`; presumably `index2` was intended -- confirm
    # before changing, since ownership of the cloned storage matters here.
    index.storage = faiss.clone_index(index.storage)
    index.own_fields = True

    Dnew, Inew = index.search(self.xq, 5)
    np.testing.assert_array_equal(Dnew, Dref)
    np.testing.assert_array_equal(Inew, Iref)

    if False:
        # test reading without storage
        # not implemented because it is hard to skip over an index
        index3 = faiss.deserialize_index(
            faiss.serialize_index(index), faiss.IO_FLAG_SKIP_STORAGE
        )
        self.assertEqual(index3.storage, None)


class TestNSG(unittest.TestCase):

Expand Down

0 comments on commit bf73e38

Please sign in to comment.