Skip to content

Commit

Permalink
Fix indexing of metadata for FAISS/SQL Document Store (#310)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanaysoni authored Aug 13, 2020
1 parent 397dcf9 commit 089fecf
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
4 changes: 3 additions & 1 deletion haystack/database/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,9 @@ def query_by_embedding(self,
def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document:
# We put all additional data of the doc into meta_data and return it in the API
meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)}
meta_data["name"] = meta_data.pop(self.name_field, None)
name = meta_data.pop(self.name_field, None)
if name:
meta_data["name"] = name

document = Document(
id=hit["_id"],
Expand Down
15 changes: 8 additions & 7 deletions haystack/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ def __init__(self, url: str = "sqlite://", index="document"):
self.index = index
self.label_index = "label"

def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
index = index or self.index
document_row = self.session.query(DocumentORM).filter_by(index=index, id=id).first()
document = document_row or self._convert_sql_row_to_document(document_row)
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
documents = self.get_documents_by_id([id], index)
document = documents[0] if documents else None
return document

def get_documents_by_id(self, ids: List[str]) -> List[Document]:
results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids)).all()
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
index = index or self.index
results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids), DocumentORM.index == index).all()
documents = [self._convert_sql_row_to_document(row) for row in results]

return documents
Expand Down Expand Up @@ -138,7 +138,8 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
index = index or self.index
for doc in document_objects:
meta_orms = [MetaORM(name=key, value=value) for key, value in doc.meta.items()]
meta_fields = doc.meta or {}
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index)
self.session.add(doc_orm)
self.session.commit()
Expand Down
34 changes: 34 additions & 0 deletions test/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from haystack.database.base import Document, Label
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.database.faiss import FAISSDocumentStore


def test_get_all_documents_without_filters(document_store_with_docs):
Expand Down Expand Up @@ -42,6 +43,39 @@ def test_get_documents_by_id(document_store_with_docs):
assert doc.text == documents[0].text


def test_write_document_meta(document_store):
    """Write dict and Document inputs, with and without meta, and check what is read back.

    Four entries are written in one call: two plain dicts and two Document
    objects. After the write, the store must return four documents, and each
    one fetched by id must expose either empty meta or the expected
    "meta_field" value.
    """
    docs_to_write = [
        {"text": "dict_without_meta", "id": "1"},
        {"text": "dict_with_meta", "meta_field": "test2", "name": "filename2", "id": "2"},
        Document(text="document_object_without_meta", id="3"),
        Document(text="document_object_with_meta", meta={"meta_field": "test4", "name": "filename3"}, id="4"),
    ]
    document_store.write_documents(docs_to_write)
    stored_docs = document_store.get_all_documents()
    assert len(stored_docs) == 4

    # (document id, expected "meta_field" value); None means the meta must be empty.
    expectations = [
        ("1", None),
        ("2", "test2"),
        ("3", None),
        ("4", "test4"),
    ]
    for doc_id, expected_value in expectations:
        fetched_meta = document_store.get_document_by_id(doc_id).meta
        if expected_value is None:
            assert not fetched_meta
        else:
            assert fetched_meta["meta_field"] == expected_value


# Checks that documents written with an explicit `index` argument are returned
# only when that same index is queried.
def test_write_document_index(document_store):
documents = [
{"text": "text1", "id": "1"},
{"text": "text2", "id": "2"},
]
# Write only the first document into "haystack_test_1" and confirm it is the
# single entry returned for that index.
document_store.write_documents([documents[0]], index="haystack_test_1")
assert len(document_store.get_all_documents(index="haystack_test_1")) == 1

# NOTE(review): indentation is lost in this view; presumably only the two
# statements directly below the isinstance check are guarded by it - confirm
# against the original file.
if not isinstance(document_store, FAISSDocumentStore): # addition of more documents is not supported in FAISS
document_store.write_documents([documents[1]], index="haystack_test_2")
assert len(document_store.get_all_documents(index="haystack_test_2")) == 1

# "haystack_test_1" must still hold exactly one document, and a call without
# an `index` argument must return no documents.
assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
assert len(document_store.get_all_documents()) == 0


def test_labels(document_store):
label = Label(
question="question",
Expand Down

0 comments on commit 089fecf

Please sign in to comment.