diff --git a/haystack/database/elasticsearch.py b/haystack/database/elasticsearch.py index 1ca6b0eaec..1ca0ab3cd2 100644 --- a/haystack/database/elasticsearch.py +++ b/haystack/database/elasticsearch.py @@ -412,7 +412,9 @@ def query_by_embedding(self, def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document: # We put all additional data of the doc into meta_data and return it in the API meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)} - meta_data["name"] = meta_data.pop(self.name_field, None) + name = meta_data.pop(self.name_field, None) + if name: + meta_data["name"] = name document = Document( id=hit["_id"], diff --git a/haystack/database/sql.py b/haystack/database/sql.py index ff4e1e9131..d7e5a234a9 100644 --- a/haystack/database/sql.py +++ b/haystack/database/sql.py @@ -68,14 +68,14 @@ def __init__(self, url: str = "sqlite://", index="document"): self.index = index self.label_index = "label" - def get_document_by_id(self, id: str, index=None) -> Optional[Document]: - index = index or self.index - document_row = self.session.query(DocumentORM).filter_by(index=index, id=id).first() - document = document_row or self._convert_sql_row_to_document(document_row) + def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]: + documents = self.get_documents_by_id([id], index) + document = documents[0] if documents else None return document - def get_documents_by_id(self, ids: List[str]) -> List[Document]: - results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids)).all() + def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]: + index = index or self.index + results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids), DocumentORM.index == index).all() documents = [self._convert_sql_row_to_document(row) for row in results] return documents @@ -138,7 +138,8 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents] index = index or self.index for doc in document_objects: - meta_orms = [MetaORM(name=key, value=value) for key, value in doc.meta.items()] + meta_fields = doc.meta or {} + meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()] doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index) self.session.add(doc_orm) self.session.commit() diff --git a/test/test_db.py b/test/test_db.py index 5b0a6a7ef2..230ea772df 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -4,6 +4,7 @@ from haystack.database.base import Document, Label from haystack.database.elasticsearch import ElasticsearchDocumentStore +from haystack.database.faiss import FAISSDocumentStore def test_get_all_documents_without_filters(document_store_with_docs): @@ -42,6 +43,39 @@ def test_get_documents_by_id(document_store_with_docs): assert doc.text == documents[0].text +def test_write_document_meta(document_store): + documents = [ + {"text": "dict_without_meta", "id": "1"}, + {"text": "dict_with_meta", "meta_field": "test2", "name": "filename2", "id": "2"}, + Document(text="document_object_without_meta", id="3"), + Document(text="document_object_with_meta", meta={"meta_field": "test4", "name": "filename3"}, id="4"), + ] + document_store.write_documents(documents) + documents_in_store = document_store.get_all_documents() + assert len(documents_in_store) == 4 + + assert not document_store.get_document_by_id("1").meta + assert document_store.get_document_by_id("2").meta["meta_field"] == "test2" + assert not document_store.get_document_by_id("3").meta + assert document_store.get_document_by_id("4").meta["meta_field"] == "test4" + + +def test_write_document_index(document_store): + documents = [ + {"text": "text1", "id": "1"}, + {"text": "text2", "id": "2"}, + ] + document_store.write_documents([documents[0]], index="haystack_test_1") + assert len(document_store.get_all_documents(index="haystack_test_1")) == 1 + + if not isinstance(document_store, FAISSDocumentStore): # addition of more documents is not supported in FAISS + document_store.write_documents([documents[1]], index="haystack_test_2") + assert len(document_store.get_all_documents(index="haystack_test_2")) == 1 + + assert len(document_store.get_all_documents(index="haystack_test_1")) == 1 + assert len(document_store.get_all_documents()) == 0 + + def test_labels(document_store): label = Label( question="question",