Add document update for SQL and FAISS Document Store (#584)
lalitpagaria authored Nov 16, 2020
1 parent 3e095dd commit 3f81c93
Showing 3 changed files with 92 additions and 11 deletions.
22 changes: 21 additions & 1 deletion haystack/document_store/faiss.py
@@ -37,6 +37,8 @@ def __init__(
faiss_index_factory_str: str = "Flat",
faiss_index: Optional[faiss.swigfaiss.Index] = None,
return_embedding: Optional[bool] = True,
update_existing_documents: bool = False,
index: str = "document",
**kwargs,
):
"""
@@ -63,6 +65,11 @@
:param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
or one with docs that you used in Haystack before and want to load again.
:param return_embedding: Whether to return the document embedding.
:param update_existing_documents: Whether to update any existing documents with the same ID when adding
documents. When set to True, any document with an existing ID gets updated.
If set to False, an error is raised if the ID of the document being
added already exists.
:param index: Name of the index in the document store to use.
"""
self.vector_dim = vector_dim

@@ -73,7 +80,11 @@ def __init__(

self.index_buffer_size = index_buffer_size
self.return_embedding = return_embedding
super().__init__(url=sql_url)
super().__init__(
url=sql_url,
update_existing_documents=update_existing_documents,
index=index
)

def _create_new_index(self, vector_dim: int, index_factory: str = "Flat", metric_type=faiss.METRIC_INNER_PRODUCT, **kwargs):
if index_factory == "HNSW" and metric_type == faiss.METRIC_INNER_PRODUCT:
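
For orientation, a minimal construction sketch (not part of the diff) showing how the two new parameters are passed through to the underlying SQLDocumentStore; the sqlite URL is illustrative:

    from haystack.document_store.faiss import FAISSDocumentStore

    # A minimal sketch, assuming the default Flat index; the sqlite path is hypothetical.
    document_store = FAISSDocumentStore(
        sql_url="sqlite:///faiss_docs.db",   # hypothetical database location
        faiss_index_factory_str="Flat",
        update_existing_documents=True,      # new parameter added in this commit
        index="document",                    # new parameter added in this commit
    )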
Expand All @@ -99,12 +110,18 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
# vector index
if not self.faiss_index:
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

# doc + metadata index
index = index or self.index
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]

add_vectors = document_objects[0].embedding is not None

if self.update_existing_documents and add_vectors:
logger.warning("You have enabled `update_existing_documents` feature and "
"`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
"Please call `update_embeddings` method to repopulate `faiss_index`")

for i in range(0, len(document_objects), self.index_buffer_size):
vector_id = self.faiss_index.ntotal
if add_vectors:
@@ -134,6 +151,9 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non
if not self.faiss_index:
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

# FAISS does not support updating data in an existing index, so clear all existing data from it
self.faiss_index.reset()

index = index or self.index
documents = self.get_all_documents(index=index)

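Taken together, the warning in write_documents and the reset in update_embeddings imply the following flow for updating documents that carry embeddings. A sketch, where `retriever` stands in for any already-configured retriever (e.g. a DensePassageRetriever) and is not defined in this commit:

    # Sketch of the implied update flow; `retriever` is an assumption here.
    document_store.write_documents(updated_docs)  # SQL rows are updated via merge
    document_store.update_embeddings(retriever)   # FAISS index is reset and repopulated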
45 changes: 39 additions & 6 deletions haystack/document_store/sql.py
@@ -1,3 +1,4 @@
import logging
from typing import Any, Dict, Union, List, Optional
from uuid import uuid4

@@ -10,6 +11,10 @@
from haystack import Document, Label
from haystack.preprocessor.utils import eval_data_from_file


logger = logging.getLogger(__name__)


Base = declarative_base() # type: Any


@@ -37,15 +42,15 @@ class MetaORM(ORMBase):

name = Column(String(100), index=True)
value = Column(String(1000), index=True)
document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)

documents = relationship(DocumentORM, backref="Meta")


class LabelORM(ORMBase):
__tablename__ = "label"

document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
index = Column(String(100), nullable=False)
no_answer = Column(Boolean, nullable=False)
origin = Column(String(100), nullable=False)
@@ -58,13 +63,30 @@ class LabelORM(ORMBase):


class SQLDocumentStore(BaseDocumentStore):
def __init__(self, url: str = "sqlite://", index="document"):
def __init__(
self,
url: str = "sqlite://",
index: str = "document",
label_index: str = "label",
update_existing_documents: bool = False,
):
"""
:param url: URL for the SQL database as expected by SQLAlchemy. More info here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
:param index: The documents are scoped to an index attribute that can be used when writing, querying, or deleting documents.
This parameter sets the default value for the document index.
:param label_index: The default value of the index attribute for labels.
:param update_existing_documents: Whether to update any existing documents with the same ID when adding
documents. When set to True, any document with an existing ID gets updated.
If set to False, an error is raised if the ID of the document being
added already exists. Note that using this parameter could degrade
document-insertion performance.
"""
engine = create_engine(url)
ORMBase.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
self.session = Session()
self.index = index
self.label_index = "label"
self.label_index = label_index
self.update_existing_documents = update_existing_documents

def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
documents = self.get_documents_by_id([id], index)
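
A minimal construction sketch with the expanded signature (illustrative values, not taken from the commit):

    from haystack.document_store.sql import SQLDocumentStore

    # A minimal sketch; the sqlite URL is illustrative.
    document_store = SQLDocumentStore(
        url="sqlite:///qa.db",
        index="document",
        label_index="label",
        update_existing_documents=True,
    )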
@@ -132,8 +154,19 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
vector_id = meta_fields.get("vector_id")
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
self.session.add(doc_orm)
self.session.commit()
if self.update_existing_documents:
# The old metadata rows must be deleted first, before merging the updated document
self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
self.session.merge(doc_orm)
else:
self.session.add(doc_orm)
try:
self.session.commit()
except Exception as ex:
logger.error(f"Transaction rollback: {ex.__cause__}")
# Rollback is important here; otherwise self.session will be left in an inconsistent state and the next call will fail
self.session.rollback()
raise ex

def write_labels(self, labels, index=None):

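The two write paths above behave differently on a duplicate ID: `session.add` fails at commit time (hence the rollback handling), while `session.merge` selects by primary key and then updates or inserts, which is the insertion overhead the docstring warns about. A sketch, assuming a fresh in-memory store:

    # A sketch of both paths, assuming a fresh in-memory store.
    store = SQLDocumentStore(url="sqlite://", update_existing_documents=False)
    store.write_documents([{"text": "v1", "id": "1"}])
    # store.write_documents([{"text": "v2", "id": "1"}])  # would raise on the duplicate ID

    store.update_existing_documents = True
    store.write_documents([{"text": "v2", "id": "1"}])    # row "1" now holds "v2"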
36 changes: 32 additions & 4 deletions test/test_db.py
@@ -4,7 +4,6 @@

from haystack import Document, Label
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore


@pytest.mark.elasticsearch
@@ -85,6 +84,36 @@ def test_get_document_count(document_store):
assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3


@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
original_docs = [
{"text": "text1_orig", "id": "1", "meta_field_for_count": "a"},
]

updated_docs = [
{"text": "text1_new", "id": "1", "meta_field_for_count": "a"},
]

document_store.update_existing_documents = update_existing_documents
document_store.write_documents(original_docs)
assert document_store.get_document_count() == 1

if update_existing_documents:
document_store.write_documents(updated_docs)
else:
with pytest.raises(Exception):
document_store.write_documents(updated_docs)

stored_docs = document_store.get_all_documents()
assert len(stored_docs) == 1
if update_existing_documents:
assert stored_docs[0].text == updated_docs[0]["text"]
else:
assert stored_docs[0].text == original_docs[0]["text"]


@pytest.mark.elasticsearch
def test_write_document_meta(document_store):
documents = [
@@ -112,9 +141,8 @@ def test_write_document_index(document_store):
document_store.write_documents([documents[0]], index="haystack_test_1")
assert len(document_store.get_all_documents(index="haystack_test_1")) == 1

if not isinstance(document_store, FAISSDocumentStore): # addition of more documents is not supported in FAISS
document_store.write_documents([documents[1]], index="haystack_test_2")
assert len(document_store.get_all_documents(index="haystack_test_2")) == 1
document_store.write_documents([documents[1]], index="haystack_test_2")
assert len(document_store.get_all_documents(index="haystack_test_2")) == 1

assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
assert len(document_store.get_all_documents()) == 0
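To exercise just the new test, something like the following should work (a hedged example; it assumes the repository's usual test setup, including a running Elasticsearch instance for the marked backends):

    pytest test/test_db.py -k test_update_existing_documents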
