Skip to content

Commit

Permalink
Add pagination for InMemoryDocumentStore
Browse files (browse the repository at this point in the history)
  • Loading branch information
tanaysoni committed Jan 20, 2021
1 parent de13496 commit 0597085
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions haystack/document_store/memory.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Iterator
from uuid import uuid4
from collections import defaultdict

Expand Down Expand Up @@ -187,9 +187,19 @@ def get_all_documents(
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
page_number: Optional[int] = None,
page_size: Optional[int] = None,
batch_size: int = 10_000,
) -> List[Document]:
result = self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding)
documents = list(result)
return documents

def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Iterator[Document]:
"""
Get documents from the document store.
Expand All @@ -198,11 +208,6 @@ def get_all_documents(
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param page_number: For getting a large number of documents, the results can be paginated. This
parameter defines the page number to be retrieved starting from the value 0. When using
page_number, the page_size argument must be set.
:param page_size: Number of documents to return in a single page. The page_number argument must be set when
using page_size.
"""
index = index or self.index
documents = deepcopy(list(self.indexes[index].values()))
Expand All @@ -229,12 +234,7 @@ def get_all_documents(
else:
filtered_documents = documents

if page_number is not None and page_size is not None:
start_pos = page_number * page_size
end_pos = start_pos + page_size
filtered_documents = filtered_documents[start_pos:end_pos]

return filtered_documents
yield from filtered_documents

def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
"""
Expand Down

0 comments on commit 0597085

Please sign in to comment.