Skip to content

Commit

Permalink
add a context parameter to InvertedLists and InvertedListsIterator (#…
Browse files Browse the repository at this point in the history
…3247)

Summary:

add a context parameter to be passed to InvertedLists and InvertedListsIterator. 
- add a context field in `SearchParametersIVF`, the context will be passed to `InvertedLists::get_iterator`. The user can create `InvertedListsIterator` with the context object
- add a context parameter in `IndexIVF::add_core` method. the context will be passed to `InvertedLists::add_entry`. 

The user can use the context object to pass storage handlers, store error codes from storage layer, logging information, etc.

Reviewed By: mdouze

Differential Revision: D53113911
  • Loading branch information
bladepan authored and facebook-github-bot committed Feb 8, 2024
1 parent ed3f6e5 commit 8dc293d
Show file tree
Hide file tree
Showing 14 changed files with 314 additions and 34 deletions.
25 changes: 18 additions & 7 deletions faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ void IndexIVF::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
const idx_t* coarse_idx,
void* inverted_list_context) {
// do some blocking to avoid excessive allocs
idx_t bs = 65536;
if (n > bs) {
Expand All @@ -218,7 +219,8 @@ void IndexIVF::add_core(
i1 - i0,
x + i0 * d,
xids ? xids + i0 : nullptr,
coarse_idx + i0);
coarse_idx + i0,
inverted_list_context);
}
return;
}
Expand Down Expand Up @@ -249,7 +251,10 @@ void IndexIVF::add_core(
if (list_no >= 0 && list_no % nt == rank) {
idx_t id = xids ? xids[i] : ntotal + i;
size_t ofs = invlists->add_entry(
list_no, id, flat_codes.get() + i * code_size);
list_no,
id,
flat_codes.get() + i * code_size,
inverted_list_context);

dm_adder.add(i, list_no, ofs);

Expand Down Expand Up @@ -445,6 +450,9 @@ void IndexIVF::search_preassigned(
: pmode == 1 ? nprobe > 1
: nprobe * n > 1);

void* inverted_list_context =
params ? params->inverted_list_context : nullptr;

#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
{
std::unique_ptr<InvertedListScanner> scanner(
Expand Down Expand Up @@ -507,7 +515,7 @@ void IndexIVF::search_preassigned(
nlist);

// don't waste time on empty lists
if (invlists->is_empty(key)) {
if (invlists->is_empty(key, inverted_list_context)) {
return (size_t)0;
}

Expand All @@ -520,7 +528,7 @@ void IndexIVF::search_preassigned(
size_t list_size = 0;

std::unique_ptr<InvertedListsIterator> it(
invlists->get_iterator(key));
invlists->get_iterator(key, inverted_list_context));

nheap += scanner->iterate_codes(
it.get(), simi, idxi, k, list_size);
Expand Down Expand Up @@ -783,6 +791,9 @@ void IndexIVF::range_search_preassigned(
: pmode == 1 ? nprobe > 1
: nprobe * nx > 1);

void* inverted_list_context =
params ? params->inverted_list_context : nullptr;

#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
{
RangeSearchPartialResult pres(result);
Expand All @@ -804,7 +815,7 @@ void IndexIVF::range_search_preassigned(
ik,
nlist);

if (invlists->is_empty(key)) {
if (invlists->is_empty(key, inverted_list_context)) {
return;
}

Expand All @@ -813,7 +824,7 @@ void IndexIVF::range_search_preassigned(
scanner->set_list(key, coarse_dis[i * nprobe + ik]);
if (invlists->use_iterator) {
std::unique_ptr<InvertedListsIterator> it(
invlists->get_iterator(key));
invlists->get_iterator(key, inverted_list_context));

scanner->iterate_codes_range(
it.get(), radius, qres, list_size);
Expand Down
5 changes: 4 additions & 1 deletion faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ struct SearchParametersIVF : SearchParameters {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
SearchParameters* quantizer_params = nullptr;
/// context object to pass to InvertedLists
void* inverted_list_context = nullptr;

virtual ~SearchParametersIVF() {}
};
Expand Down Expand Up @@ -232,7 +234,8 @@ struct IndexIVF : Index, IndexIVFInterface {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx);
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr);

/** Encodes a set of vectors as they would appear in the inverted lists
*
Expand Down
7 changes: 4 additions & 3 deletions faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ void IndexIVFFlat::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
const idx_t* coarse_idx,
void* inverted_list_context) {
FAISS_THROW_IF_NOT(is_trained);
FAISS_THROW_IF_NOT(coarse_idx);
FAISS_THROW_IF_NOT(!by_residual);
Expand All @@ -70,8 +71,8 @@ void IndexIVFFlat::add_core(
if (list_no >= 0 && list_no % nt == rank) {
idx_t id = xids ? xids[i] : ntotal + i;
const float* xi = x + i * d;
size_t offset =
invlists->add_entry(list_no, id, (const uint8_t*)xi);
size_t offset = invlists->add_entry(
list_no, id, (const uint8_t*)xi, inverted_list_context);
dm_adder.add(i, list_no, offset);
n_add++;
} else if (rank == 0 && list_no == -1) {
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVFFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ struct IndexIVFFlat : IndexIVF {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

void encode_vectors(
idx_t n,
Expand Down
14 changes: 9 additions & 5 deletions faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ void IndexIVFPQ::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
add_core_o(n, x, xids, nullptr, coarse_idx);
const idx_t* coarse_idx,
void* inverted_list_context) {
add_core_o(n, x, xids, nullptr, coarse_idx, inverted_list_context);
}

static std::unique_ptr<float[]> compute_residuals(
Expand Down Expand Up @@ -212,7 +213,8 @@ void IndexIVFPQ::add_core_o(
const float* x,
const idx_t* xids,
float* residuals_2,
const idx_t* precomputed_idx) {
const idx_t* precomputed_idx,
void* inverted_list_context) {
idx_t bs = index_ivfpq_add_core_o_bs;
if (n > bs) {
for (idx_t i0 = 0; i0 < n; i0 += bs) {
Expand All @@ -229,7 +231,8 @@ void IndexIVFPQ::add_core_o(
x + i0 * d,
xids ? xids + i0 : nullptr,
residuals_2 ? residuals_2 + i0 * d : nullptr,
precomputed_idx ? precomputed_idx + i0 : nullptr);
precomputed_idx ? precomputed_idx + i0 : nullptr,
inverted_list_context);
}
return;
}
Expand Down Expand Up @@ -281,7 +284,8 @@ void IndexIVFPQ::add_core_o(
}

uint8_t* code = xcodes.get() + i * code_size;
size_t offset = invlists->add_entry(key, id, code);
size_t offset =
invlists->add_entry(key, id, code, inverted_list_context);

if (residuals_2) {
float* res2 = residuals_2 + i * d;
Expand Down
6 changes: 4 additions & 2 deletions faiss/IndexIVFPQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ struct IndexIVFPQ : IndexIVF {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

/// same as add_core, also:
/// - output 2nd level residuals if residuals_2 != NULL
Expand All @@ -81,7 +82,8 @@ struct IndexIVFPQ : IndexIVF {
const float* x,
const idx_t* xids,
float* residuals_2,
const idx_t* precomputed_idx = nullptr);
const idx_t* precomputed_idx = nullptr,
void* inverted_list_context = nullptr);

/// trains the product quantizer
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVFPQR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void IndexIVFPQR::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) {
const idx_t* precomputed_idx,
void* /*inverted_list_context*/) {
std::unique_ptr<float[]> residual_2(new float[n * d]);

idx_t n0 = ntotal;
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVFPQR.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ struct IndexIVFPQR : IndexIVFPQ {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
const override;
Expand Down
6 changes: 4 additions & 2 deletions faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ void IndexIVFScalarQuantizer::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
const idx_t* coarse_idx,
void* inverted_list_context) {
FAISS_THROW_IF_NOT(is_trained);

std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
Expand Down Expand Up @@ -236,7 +237,8 @@ void IndexIVFScalarQuantizer::add_core(
memset(one_code.data(), 0, code_size);
squant->encode_vector(xi, one_code.data());

size_t ofs = invlists->add_entry(list_no, id, one_code.data());
size_t ofs = invlists->add_entry(
list_no, id, one_code.data(), inverted_list_context);

dm_add.add(i, list_no, ofs);

Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

InvertedListScanner* get_InvertedListScanner(
bool store_pairs,
Expand Down
18 changes: 11 additions & 7 deletions faiss/invlists/InvertedLists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ InvertedLists::InvertedLists(size_t nlist, size_t code_size)

InvertedLists::~InvertedLists() {}

bool InvertedLists::is_empty(size_t list_no) const {
return use_iterator
? !std::unique_ptr<InvertedListsIterator>(get_iterator(list_no))
->is_available()
: list_size(list_no) == 0;
bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context)
const {
return use_iterator ? !std::unique_ptr<InvertedListsIterator>(
get_iterator(list_no, inverted_list_context))
->is_available()
: list_size(list_no) == 0;
}

idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
Expand All @@ -58,7 +59,8 @@ const uint8_t* InvertedLists::get_single_code(size_t list_no, size_t offset)
size_t InvertedLists::add_entry(
size_t list_no,
idx_t theid,
const uint8_t* code) {
const uint8_t* code,
void* /*inverted_list_context*/) {
return add_entries(list_no, 1, &theid, code);
}

Expand All @@ -76,7 +78,9 @@ void InvertedLists::reset() {
}
}

InvertedListsIterator* InvertedLists::get_iterator(size_t /*list_no*/) const {
InvertedListsIterator* InvertedLists::get_iterator(
size_t /*list_no*/,
void* /*inverted_list_context*/) const {
FAISS_THROW_MSG("get_iterator is not supported");
}

Expand Down
12 changes: 9 additions & 3 deletions faiss/invlists/InvertedLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ struct InvertedLists {
* Read only functions */

// check if the list is empty
bool is_empty(size_t list_no) const;
bool is_empty(size_t list_no, void* inverted_list_context) const;

/// get the size of a list
virtual size_t list_size(size_t list_no) const = 0;

/// get iterable for lists that use_iterator
virtual InvertedListsIterator* get_iterator(size_t list_no) const;
virtual InvertedListsIterator* get_iterator(
size_t list_no,
void* inverted_list_context) const;

/** get the codes for an inverted list
* must be released by release_codes
Expand Down Expand Up @@ -94,7 +96,11 @@ struct InvertedLists {
* writing functions */

/// add one entry to an inverted list
virtual size_t add_entry(size_t list_no, idx_t theid, const uint8_t* code);
virtual size_t add_entry(
size_t list_no,
idx_t theid,
const uint8_t* code,
void* inverted_list_context = nullptr);

virtual size_t add_entries(
size_t list_no,
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(FAISS_TEST_SRC
test_ivfpq_codec.cpp
test_ivfpq_indexing.cpp
test_lowlevel_ivf.cpp
test_ivf_index.cpp
test_merge.cpp
test_omp_threads.cpp
test_ondisk_ivf.cpp
Expand Down
Loading

0 comments on commit 8dc293d

Please sign in to comment.