From edb40b6c1b453f186c41f9de43f7054bc723a386 Mon Sep 17 00:00:00 2001 From: pandasar13 <146428662+pandasar13@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:27:24 +0100 Subject: [PATCH] refactor: add batch_size to FAISS __init__ (#6401) * refactor: add batch_size to FAISS __init__ * refactor: add batch_size to FAISS __init__ * add release note to refactor: add batch_size to FAISS __init__ * fix release note * add batch_size to docstrings --------- Co-authored-by: anakin87 --- haystack/document_stores/faiss.py | 21 ++++++++++++++----- ...atch_size_faiss_init-5e97c1fb9409f873.yaml | 5 +++++ 2 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/add-batch_size_faiss_init-5e97c1fb9409f873.yaml diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index fe8464be82..258b874553 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -57,6 +57,7 @@ def __init__( ef_search: int = 20, ef_construction: int = 80, validate_index_sync: bool = True, + batch_size: int = 10_000, ): """ :param sql_url: SQL connection URL for the database. The default value is "sqlite:///faiss_document_store.db"`. It defaults to a local, file-based SQLite DB. For large scale deployment, we recommend Postgres. @@ -103,6 +104,8 @@ def __init__( :param ef_search: Used only if `index_factory == "HNSW"`. :param ef_construction: Used only if `index_factory == "HNSW"`. :param validate_index_sync: Checks if the document count equals the embedding count at initialization time. + :param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face + memory issues, decrease the batch_size. """ faiss_import.check() # special case if we want to load an existing index from disk @@ -152,6 +155,7 @@ def __init__( self.return_embedding = return_embedding self.embedding_field = embedding_field + self.batch_size = batch_size self.progress_bar = progress_bar @@ -216,7 +220,7 @@ def write_documents( self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, duplicate_documents: Optional[str] = None, headers: Optional[Dict[str, str]] = None, ) -> None: @@ -240,6 +244,8 @@ def write_documents( raise NotImplementedError("FAISSDocumentStore does not support headers.") index = index or self.index + batch_size = batch_size or self.batch_size + duplicate_documents = duplicate_documents or self.duplicate_documents assert ( duplicate_documents in self.duplicate_documents_options @@ -324,7 +330,7 @@ def update_embeddings( index: Optional[str] = None, update_existing_embeddings: bool = True, filters: Optional[FilterType] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, ): """ Updates the embeddings in the the document store using the encoding model specified in the retriever. @@ -342,6 +348,7 @@ def update_embeddings( :return: None """ index = index or self.index + batch_size = batch_size or self.batch_size if update_existing_embeddings is True: if filters is None: @@ -404,9 +411,10 @@ def get_all_documents( index: Optional[str] = None, filters: Optional[FilterType] = None, return_embedding: Optional[bool] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ) -> List[Document]: + batch_size = batch_size or self.batch_size if headers: raise NotImplementedError("FAISSDocumentStore does not support headers.") @@ -421,7 +429,7 @@ def get_all_documents_generator( index: Optional[str] = None, filters: Optional[FilterType] = None, return_embedding: Optional[bool] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ) -> Generator[Document, None, None]: """ @@ -440,6 +448,7 @@ def get_all_documents_generator( raise NotImplementedError("FAISSDocumentStore does not support headers.") index = index or self.index + batch_size = batch_size or self.batch_size documents = super(FAISSDocumentStore, self).get_all_documents_generator( index=index, filters=filters, batch_size=batch_size, return_embedding=False ) @@ -455,13 +464,15 @@ def get_documents_by_id( self, ids: List[str], index: Optional[str] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ) -> List[Document]: if headers: raise NotImplementedError("FAISSDocumentStore does not support headers.") index = index or self.index + batch_size = batch_size or self.batch_size + documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index, batch_size=batch_size) if self.return_embedding: for doc in documents: diff --git a/releasenotes/notes/add-batch_size_faiss_init-5e97c1fb9409f873.yaml b/releasenotes/notes/add-batch_size_faiss_init-5e97c1fb9409f873.yaml new file mode 100644 index 0000000000..7d0956d682 --- /dev/null +++ b/releasenotes/notes/add-batch_size_faiss_init-5e97c1fb9409f873.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Add batch_size to the __init__ method of FAISS Document Store. This works as the default value for all methods of + FAISS Document Store that support batch_size. \ No newline at end of file