Skip to content

Commit

Permalink
Improve speed for SQLDocumentStore (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanaysoni authored Aug 21, 2020
1 parent a54d6a5 commit 7d2a8f1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 29 deletions.
6 changes: 3 additions & 3 deletions haystack/database/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def query_by_embedding(
_, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]

documents = [
self.get_all_documents(filters={"vector_id": [vector_id]})[0] for vector_id in vector_ids_for_query
]
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
# sort the documents as per query results
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore

return documents

Expand Down
34 changes: 8 additions & 26 deletions haystack/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class DocumentORM(ORMBase):
class MetaORM(ORMBase):
    """ORM mapping for a single meta key/value pair attached to a document.

    Rows are joined to documents through the ``document_meta`` association
    table, giving each document an arbitrary set of (name, value) pairs.
    """

    __tablename__ = "meta"

    # Both columns are indexed: get_all_documents() filters with
    # MetaORM.name.in_(...) / MetaORM.value.in_(...), and without these
    # indexes every filtered query is a full table scan (see #330).
    name = Column(String, index=True)
    value = Column(String, index=True)

    # Many-to-many link to documents via the "document_meta" secondary table.
    documents = relationship(DocumentORM, secondary="document_meta", backref="Meta")

Expand Down Expand Up @@ -80,36 +80,18 @@ def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> Li

return documents

def get_all_documents(
    self,
    index: Optional[str] = None,
    filters: Optional[Dict[str, List[str]]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> List[Document]:
    """Fetch documents from the store, optionally filtered by meta fields.

    :param index: Document index to query; defaults to ``self.index``.
    :param filters: Mapping of meta field name -> list of accepted values.
                    A document matches when, for every key, it has a meta
                    entry whose name is the key and whose value is in the
                    given list.
    :param offset: Skip this many matching rows (applied in SQL, not in Python).
    :param limit: Return at most this many rows (applied in SQL, not in Python).
    :return: Matching documents converted from their ORM rows.
    """
    index = index or self.index

    # Build a single query and let the database do the filtering — one
    # round trip instead of one query per filter key (see #330).
    query = self.session.query(DocumentORM).filter_by(index=index)

    if filters:
        for key, values in filters.items():
            # Each key narrows the result further (AND across keys,
            # OR across the values of a single key).
            query = query.filter(
                DocumentORM.meta.any(MetaORM.name.in_([key]))
            ).filter(DocumentORM.meta.any(MetaORM.value.in_(values)))

    # Pagination must be applied to the Query before .all(); calling
    # .offset()/.limit() on the materialized list would fail.
    if offset:
        query = query.offset(offset)
    if limit:
        query = query.limit(limit)

    return [self._convert_sql_row_to_document(row) for row in query.all()]

def get_all_labels(self, index=None, filters: Optional[dict] = None):
Expand Down

0 comments on commit 7d2a8f1

Please sign in to comment.