diff --git a/haystack/database/faiss.py b/haystack/database/faiss.py index 7627dc7c3b..7f23352991 100644 --- a/haystack/database/faiss.py +++ b/haystack/database/faiss.py @@ -150,9 +150,9 @@ def query_by_embedding( _, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k) vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1] - documents = [ - self.get_all_documents(filters={"vector_id": [vector_id]})[0] for vector_id in vector_ids_for_query - ] + documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index) + # sort the documents as per query results + documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore return documents diff --git a/haystack/database/sql.py b/haystack/database/sql.py index d7e5a234a9..1b4bed14df 100644 --- a/haystack/database/sql.py +++ b/haystack/database/sql.py @@ -31,8 +31,8 @@ class DocumentORM(ORMBase): class MetaORM(ORMBase): __tablename__ = "meta" - name = Column(String) - value = Column(String) + name = Column(String, index=True) + value = Column(String, index=True) documents = relationship(DocumentORM, secondary="document_meta", backref="Meta") @@ -80,36 +80,18 @@ def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> Li return documents - def get_all_documents( # type: ignore - self, - limit: Optional[int] = None, - offset: Optional[int] = None, - index: Optional[str] = None, - filters: Optional[Dict[str, List[str]]] = None, + def get_all_documents( + self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None ) -> List[Document]: index = index or self.index - document_rows = self.session.query(DocumentORM).filter_by(index=index).all() - if offset: - document_rows = document_rows.offset(offset) - if limit: - document_rows = document_rows.limit(limit) - - documents = [] - for row in document_rows: - documents.append(self._convert_sql_row_to_document(row)) + query = self.session.query(DocumentORM).filter_by(index=index) if filters: for key, values in filters.items(): - results = ( - self.session.query(DocumentORM) - .filter(DocumentORM.meta.any(MetaORM.name.in_([key]))) - .filter(DocumentORM.meta.any(MetaORM.value.in_(values))) - .all() - ) - else: - results = self.session.query(DocumentORM).filter_by(index=index).all() + query = query.filter(DocumentORM.meta.any(MetaORM.name.in_([key])))\ + .filter(DocumentORM.meta.any(MetaORM.value.in_(values))) - documents = [self._convert_sql_row_to_document(row) for row in results] + documents = [self._convert_sql_row_to_document(row) for row in query.all()] return documents def get_all_labels(self, index=None, filters: Optional[dict] = None):