Add batch_size and generators to document stores. (#733)
* Add batch update of embeddings in document stores

* Resolve merge conflict

* Remove document ordering dependency in tests

* Adjust index buffer size for tests

* Adjust ES Scroll Slice

* Use generator for document store pagination

* Add pagination for InMemoryDocumentStore

* Fix missing index parameter in FAISS update_embeddings()

* Fix FAISS update_embeddings()

* Update FAISS tests

* Update eval tests

* Revert code formatting change

* Fix document count in FAISS update embeddings

* Fix vector_ids reset in SQLDocumentStore

* Update docstrings

* Update docstring
tanaysoni authored Jan 21, 2021
1 parent 0b583b8 commit 337376c
Showing 10 changed files with 473 additions and 317 deletions.
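Taken together, these changes give the document stores (Elasticsearch, FAISS, SQL, InMemory) a batch_size argument and a streaming get_all_documents_generator(). A minimal usage sketch against the new API, assuming a local Elasticsearch instance and Haystack's EmbeddingRetriever with an example model name:

    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
    from haystack.retriever.dense import EmbeddingRetriever

    document_store = ElasticsearchDocumentStore(host="localhost", embedding_field="embedding", embedding_dim=768)
    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert")

    # Embeddings are now computed and written back in batches of batch_size
    # documents instead of materializing the whole index in memory.
    document_store.update_embeddings(retriever, batch_size=10_000)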
140 changes: 89 additions & 51 deletions haystack/document_store/elasticsearch.py
@@ -3,7 +3,7 @@
import time
from copy import deepcopy
from string import Template
-from typing import List, Optional, Union, Dict, Any
+from typing import List, Optional, Union, Dict, Any, Generator
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan
from elasticsearch.exceptions import RequestError
@@ -13,6 +13,7 @@
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.retriever.base import BaseRetriever
+from haystack.utils import get_batches_from_generator

logger = logging.getLogger(__name__)

@@ -232,8 +233,9 @@ def get_documents_by_id(self, ids: List[str], index=None) -> List[Document]:
documents = [self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
return documents

-def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
-                    batch_size: Optional[int] = None):
+def write_documents(
+    self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
+):
"""
Indexes documents for later queries in Elasticsearch.
@@ -253,7 +255,6 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
should be changed to what you have set for self.text_field and self.name_field.
:param index: Elasticsearch index where the documents should be indexed. If not supplied, self.index will be used.
:param batch_size: Number of documents that are passed to Elasticsearch's bulk function at a time.
-    If `None`, all documents will be passed to bulk at once.
:return: None
"""

@@ -298,22 +299,21 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
_doc.pop("meta")
documents_to_index.append(_doc)

-        if batch_size is not None:
-            # Pass batch_size number of documents to bulk
-            if len(documents_to_index) % batch_size == 0:
-                bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
-                documents_to_index = []
+        # Pass batch_size number of documents to bulk
+        if len(documents_to_index) % batch_size == 0:
+            bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
+            documents_to_index = []

if documents_to_index:
bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)

-def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None,
-                 batch_size: Optional[int] = None):
+def write_labels(
+    self, labels: Union[List[Label], List[dict]], index: Optional[str] = None, batch_size: int = 10_000
+):
"""Write annotation labels into document store.
:param labels: A list of Python dictionaries or a list of Haystack Label objects.
:param batch_size: Number of labels that are passed to Elasticsearch's bulk function at a time.
-    If `None`, all labels will be passed to bulk at once.
"""
index = index or self.label_index
if index and not self.client.indices.exists(index=index):
@@ -339,11 +339,10 @@ def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[s

labels_to_index.append(_label)

-        if batch_size is not None:
-            # Pass batch_size number of labels to bulk
-            if len(labels_to_index) % batch_size == 0:
-                bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
-                labels_to_index = []
+        # Pass batch_size number of labels to bulk
+        if len(labels_to_index) % batch_size == 0:
+            bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
+            labels_to_index = []

if labels_to_index:
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
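For illustration, a minimal sketch of calling the batched write path above. The dict layout follows the docstring's default field names; the documents themselves and the document_store handle are made-up examples:

    docs = [
        {"text": "Some text", "meta": {"name": "doc_1", "category": ["only_one"]}},
        {"text": "More text", "meta": {"name": "doc_2", "category": ["some"]}},
    ]
    # Documents are flushed to Elasticsearch's bulk() every batch_size items
    # instead of being sent in one potentially huge request.
    document_store.write_documents(docs, index="document", batch_size=10_000)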
@@ -387,10 +386,11 @@ def get_label_count(self, index: Optional[str] = None) -> int:
return self.get_document_count(index=index)

def get_all_documents(
-    self,
-    index: Optional[str] = None,
-    filters: Optional[Dict[str, List[str]]] = None,
-    return_embedding: Optional[bool] = None
+    self,
+    index: Optional[str] = None,
+    filters: Optional[Dict[str, List[str]]] = None,
+    return_embedding: Optional[bool] = None,
+    batch_size: int = 10_000,
) -> List[Document]:
"""
Get documents from the document store.
@@ -400,28 +400,62 @@ def get_all_documents(
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
+    :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
"""
+    result = self.get_all_documents_generator(
+        index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
+    )
+    documents = list(result)
+    return documents

+def get_all_documents_generator(
+    self,
+    index: Optional[str] = None,
+    filters: Optional[Dict[str, List[str]]] = None,
+    return_embedding: Optional[bool] = None,
+    batch_size: int = 10_000,
+) -> Generator[Document, None, None]:
+    """
+    Get documents from the document store. Under the hood, documents are fetched in batches from the
+    document store and yielded as individual documents. This method can be used to iteratively process
+    a large number of documents without having to load all documents in memory.
+    :param index: Name of the index to get the documents from. If None, the
+                  DocumentStore's default index (self.index) will be used.
+    :param filters: Optional filters to narrow down the documents to return.
+                    Example: {"name": ["some", "more"], "category": ["only_one"]}
+    :param return_embedding: Whether to return the document embeddings.
+    :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
+    """

if index is None:
index = self.index

-    result = self.get_all_documents_in_index(index=index, filters=filters)
    if return_embedding is None:
        return_embedding = self.return_embedding
-    documents = [self._convert_es_hit_to_document(hit, return_embedding=return_embedding) for hit in result]
-    return documents
+    result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
+    for hit in result:
+        document = self._convert_es_hit_to_document(hit, return_embedding=return_embedding)
+        yield document

-def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
+def get_all_labels(
+    self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, batch_size: int = 10_000
+) -> List[Label]:
"""
Return all labels in the document store
"""
index = index or self.label_index
-    result = self.get_all_documents_in_index(index=index, filters=filters)
+    result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
labels = [Label.from_dict(hit["_source"]) for hit in result]
return labels

-def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]:
+def _get_all_documents_in_index(
+    self,
+    index: str,
+    filters: Optional[Dict[str, List[str]]] = None,
+    batch_size: int = 10_000,
+) -> Generator[dict, None, None]:
"""
Return all documents in a specific index in the document store
"""
@@ -444,9 +478,9 @@ def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, Lis
}
)
body["query"]["bool"]["filter"] = filter_clause
-    result = list(scan(self.client, query=body, index=index))
-
-    return result
+    result = scan(self.client, query=body, index=index, size=batch_size, scroll="1d")
+    yield from result
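The change above passes the page size straight to Elasticsearch's scroll API and keeps the scroll context alive for a day, which suits slow, long-running consumers. A hedged sketch of the equivalent low-level call (host and query are assumptions):

    from elasticsearch import Elasticsearch
    from elasticsearch.helpers import scan

    client = Elasticsearch(hosts=["localhost:9200"])  # assumed host
    # Each scroll page returns up to `size` hits; scroll="1d" keeps the server-side
    # cursor valid while the consumer works through the results.
    for hit in scan(client, query={"query": {"match_all": {}}}, index="document", size=10_000, scroll="1d"):
        print(hit["_source"])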

def query(
self,
@@ -683,13 +717,14 @@ def describe_documents(self, index=None):
}
return stats

-def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
+def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
"""
Updates the embeddings in the document store using the encoding model specified in the retriever.
This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).
-    :param retriever: Retriever
+    :param retriever: Retriever to use to update the embeddings.
:param index: Index name to update
+    :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
:return: None
"""
if index is None:
@@ -698,26 +733,29 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non
if not self.embedding_field:
raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")

-    # TODO Index embeddings every X batches to avoid OOM for huge document collections
-    docs = self.get_all_documents(index)
-    logger.info(f"Updating embeddings for {len(docs)} docs ...")
-    embeddings = retriever.embed_passages(docs)  # type: ignore
-    assert len(docs) == len(embeddings)
-
-    if embeddings[0].shape[0] != self.embedding_dim:
-        raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
-                           f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
-                           "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
-    doc_updates = []
-    for doc, emb in zip(docs, embeddings):
-        update = {"_op_type": "update",
-                  "_index": index,
-                  "_id": doc.id,
-                  "doc": {self.embedding_field: emb.tolist()},
-                  }
-        doc_updates.append(update)
-
-    bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
+    logger.info(f"Updating embeddings for {self.get_document_count(index=index)} docs ...")
+
+    result = self.get_all_documents_generator(index, batch_size=batch_size)
+    for document_batch in get_batches_from_generator(result, batch_size):
+        if len(document_batch) == 0:
+            break
+        embeddings = retriever.embed_passages(document_batch)  # type: ignore
+        assert len(document_batch) == len(embeddings)
+
+        if embeddings[0].shape[0] != self.embedding_dim:
+            raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+                               f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
+                               "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
+        doc_updates = []
+        for doc, emb in zip(document_batch, embeddings):
+            update = {"_op_type": "update",
+                      "_index": index,
+                      "_id": doc.id,
+                      "doc": {self.embedding_field: emb.tolist()},
+                      }
+            doc_updates.append(update)
+
+        bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)

def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
"""
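The batched update relies on get_batches_from_generator, imported from haystack.utils at the top of the file but not shown in this diff. A plausible sketch of such a helper, assuming it only chunks an iterable into batch-sized pieces:

    from itertools import islice

    def get_batches_from_generator(iterable, n):
        # Yield successive chunks of up to n items from any iterable/generator.
        it = iter(iterable)
        chunk = list(islice(it, n))
        while chunk:
            yield chunk
            chunk = list(islice(it, n))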
