From 059708590344afad06898a64117bc4ac6b2ec907 Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Wed, 20 Jan 2021 17:29:25 +0100 Subject: [PATCH] Add pagination for InMemoryDocumentStore --- haystack/document_store/memory.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py index 6f6c69aff5..e726b96ee7 100644 --- a/haystack/document_store/memory.py +++ b/haystack/document_store/memory.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Iterator from uuid import uuid4 from collections import defaultdict @@ -187,9 +187,19 @@ def get_all_documents( index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, return_embedding: Optional[bool] = None, - page_number: Optional[int] = None, - page_size: Optional[int] = None, + batch_size: int = 10_000, ) -> List[Document]: + result = self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding) + documents = list(result) + return documents + + def get_all_documents_generator( + self, + index: Optional[str] = None, + filters: Optional[Dict[str, List[str]]] = None, + return_embedding: Optional[bool] = None, + batch_size: int = 10_000, + ) -> Iterator[Document]: """ Get documents from the document store. @@ -198,11 +208,6 @@ def get_all_documents( :param filters: Optional filters to narrow down the documents to return. Example: {"name": ["some", "more"], "category": ["only_one"]} :param return_embedding: Whether to return the document embeddings. - :param page_number: For getting a large number of documents, the results can be paginated. This - parameter defines the page number to be retrieved starting from the value 0. When using - page_number, the page_size argument must be set. - :param page_size: Number of documents to return in a single page. The page_number argument must be set when - using page_size. """ index = index or self.index documents = deepcopy(list(self.indexes[index].values())) @@ -229,12 +234,7 @@ def get_all_documents( else: filtered_documents = documents - if page_number is not None and page_size is not None: - start_pos = page_number * page_size - end_pos = start_pos + page_size - filtered_documents = filtered_documents[start_pos:end_pos] - - return filtered_documents + yield from filtered_documents def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]: """