Fix: FAISSDocumentStore - make write_documents properly work in combination w update_embeddings #5221

Merged: 6 commits, Jul 3, 2023
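In short: when documents are re-written without embeddings under `duplicate_documents="overwrite"`, `write_documents` previously dropped their `vector_id` metadata, severing the link to vectors already stored in the FAISS index. The diff below makes `write_documents` look up and preserve those existing `vector_id`s, and it now emits the "please call `update_embeddings`" warning only when vectors are actually appended to a non-empty `faiss_index`. A usage sketch of the fixed workflow follows the diff.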
haystack/document_stores/faiss.py: 55 additions & 46 deletions
```diff
@@ -257,54 +257,63 @@ def write_documents(
         document_objects = self._handle_duplicate_documents(
             documents=document_objects, index=index, duplicate_documents=duplicate_documents
         )
-        if len(document_objects) > 0:
-            add_vectors = all(doc.embedding is not None for doc in document_objects)
-
-            if self.duplicate_documents == "overwrite" and add_vectors:
-                logger.warning(
-                    "You have to provide `duplicate_documents = 'overwrite'` arg and "
-                    "`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
-                    "Please call `update_embeddings` method to repopulate `faiss_index`"
-                )
-
-            vector_id = self.faiss_indexes[index].ntotal
-            with tqdm(
-                total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
-            ) as progress_bar:
-                for i in range(0, len(document_objects), batch_size):
-                    if add_vectors:
-                        if not self.faiss_indexes[index].is_trained:
-                            raise ValueError(
-                                "FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
-                                "method before adding the vectors. For details, refer to the documentation: "
-                                "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
-                                "".format(self.faiss_index_factory_str)
-                            )
-
-                        embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
-                        embeddings_to_index = np.array(embeddings, dtype="float32")
-
-                        if self.similarity == "cosine":
-                            self.normalize_embedding(embeddings_to_index)
-
-                        self.faiss_indexes[index].add(embeddings_to_index)
-
-                    docs_to_write_in_sql = []
-                    for doc in document_objects[i : i + batch_size]:
-                        meta = doc.meta
-                        if add_vectors:
-                            meta["vector_id"] = vector_id
-                            vector_id += 1
-                        docs_to_write_in_sql.append(doc)
-
-                    super(FAISSDocumentStore, self).write_documents(
-                        docs_to_write_in_sql,
-                        index=index,
-                        duplicate_documents=duplicate_documents,
-                        batch_size=batch_size,
-                    )
-                    progress_bar.update(batch_size)
-            progress_bar.close()
+        if len(document_objects) == 0:
+            return
+
+        vector_id = self.faiss_indexes[index].ntotal
+        add_vectors = all(doc.embedding is not None for doc in document_objects)
+
+        if vector_id > 0 and self.duplicate_documents == "overwrite" and add_vectors:
+            logger.warning(
+                "`FAISSDocumentStore` is adding new vectors to an existing `faiss_index`.\n"
+                "Please call `update_embeddings` method to correctly repopulate `faiss_index`"
+            )
+
+        with tqdm(
+            total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
+        ) as progress_bar:
+            for i in range(0, len(document_objects), batch_size):
+                batch_documents = document_objects[i : i + batch_size]
+                if add_vectors:
+                    if not self.faiss_indexes[index].is_trained:
+                        raise ValueError(
+                            f"FAISS index of type {self.faiss_index_factory_str} must be trained before adding vectors. Call `train_index()` "
+                            "method before adding the vectors. For details, refer to the documentation: "
+                            "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
+                        )
+
+                    embeddings = [doc.embedding for doc in batch_documents]
+                    embeddings_to_index = np.array(embeddings, dtype="float32")
+
+                    if self.similarity == "cosine":
+                        self.normalize_embedding(embeddings_to_index)
+
+                    self.faiss_indexes[index].add(embeddings_to_index)
+
+                # write_documents method (duplicate_documents="overwrite") should properly work in combination with
+                # update_embeddings method (update_existing_embeddings=False).
+                # If no new embeddings are provided, we save the existing FAISS vector ids
+                elif self.duplicate_documents == "overwrite":
+                    existing_docs = self.get_documents_by_id(ids=[doc.id for doc in batch_documents], index=index)
+                    existing_docs_vector_ids = {
+                        doc.id: doc.meta["vector_id"] for doc in existing_docs if doc.meta and "vector_id" in doc.meta
+                    }
+
+                docs_to_write_in_sql = []
+                for doc in batch_documents:
+                    meta = doc.meta
+                    if add_vectors:
+                        meta["vector_id"] = vector_id
+                        vector_id += 1
+                    elif self.duplicate_documents == "overwrite" and doc.id in existing_docs_vector_ids:
+                        meta["vector_id"] = existing_docs_vector_ids[doc.id]
+                    docs_to_write_in_sql.append(doc)
+
+                super(FAISSDocumentStore, self).write_documents(
+                    docs_to_write_in_sql, index=index, duplicate_documents=duplicate_documents, batch_size=batch_size
+                )
+                progress_bar.update(batch_size)
 
     def _create_document_field_map(self) -> Dict:
         return {self.index: self.embedding_field}
```
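For context, the end-to-end flow this hunk fixes can be sketched as follows. This is a minimal, illustrative sketch assuming Haystack 1.x with the FAISS extra installed; the SQLite URL and the sentence-transformers model are placeholder choices, not taken from the PR:

```python
# Minimal sketch of the workflow fixed by this PR (assumptions: Haystack 1.x,
# `pip install farm-haystack[faiss]`; the SQL URL and model name are illustrative).
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document

document_store = FAISSDocumentStore(
    sql_url="sqlite:///faiss_doc_store.db",  # document metadata lives in SQL
    embedding_dim=768,                       # must match the retriever model
    duplicate_documents="overwrite",
)
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-mpnet-base-v2",
)

docs = [Document(content="FAISS holds the vectors; SQL holds the metadata.")]

# First write: the documents carry no embeddings, so no vectors are added yet.
document_store.write_documents(docs)
# Embed only the documents that have no embedding so far.
document_store.update_embeddings(retriever=retriever, update_existing_embeddings=False)

# Writing the same (embedding-less) documents again with "overwrite" used to drop
# their `vector_id`, orphaning the vectors already in the FAISS index. With this
# PR, the existing `vector_id`s are looked up and carried over.
document_store.write_documents(docs)
assert document_store.get_embedding_count() == 1  # vector still linked, not duplicated
```

The key design choice in the hunk is to reuse the SQL-side `vector_id` rather than touch the FAISS index at all when no new embeddings are supplied, so a later `update_embeddings(update_existing_embeddings=False)` can skip those documents.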
test/document_stores/test_faiss.py: 19 additions & 0 deletions
```diff
@@ -204,6 +204,25 @@ def test_update_docs_different_indexes(self, ds, documents_with_embeddings):
         assert len(docs_from_index_b) == len(docs_b)
         assert {int(doc.meta["vector_id"]) for doc in docs_from_index_b} == {0, 1, 2, 3}
 
+    @pytest.mark.integration
+    def test_dont_update_existing_embeddings(self, ds, docs):
+        retriever = MockDenseRetriever(document_store=ds)
+        first_doc_id = docs[0].id
+
+        for i in range(1, 4):
+            ds.write_documents(docs[:i])
+            ds.update_embeddings(retriever=retriever, update_existing_embeddings=False)
+
+            assert ds.get_document_count() == i
+            assert ds.get_embedding_count() == i
+            assert ds.get_document_by_id(id=first_doc_id).meta["vector_id"] == "0"
+
+            # Check if the embeddings of the first document remain unchanged after multiple updates
+            if i == 1:
+                first_doc_embedding = ds.get_document_by_id(id=first_doc_id).embedding
+            else:
+                assert np.array_equal(ds.get_document_by_id(id=first_doc_id).embedding, first_doc_embedding)
+
     @pytest.mark.integration
     def test_passing_index_from_outside(self, documents_with_embeddings, tmp_path):
         d = 768
```
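Read together with the store change above, this test pins down the intended contract: alternating `write_documents` (relying on the store's duplicate policy, which defaults to "overwrite" for `FAISSDocumentStore`) with `update_embeddings(update_existing_embeddings=False)` must keep the first document's `vector_id` (stored as the string `"0"` in its SQL metadata) and its embedding unchanged across iterations.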