fix: Send batches of query-doc pairs to inference_from_objects (#5125)
* Send batches of query-doc pairs to inference_from_objects

* Use absolute import path

* Add separate preprocessing_batch_size parameter
bogdankostic authored Jun 26, 2023
1 parent f4e18e9 commit 82291b5
Showing 10 changed files with 58 additions and 25 deletions.
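
The gist of the change: FARMReader used to pass every query-document pair to QAInferencer.inference_from_objects in a single call, so all pairs were tokenized and converted to tensors at once. This commit chunks the pairs with get_batches_from_generator (moved to haystack/utils/batching.py) and runs inference one batch at a time, which bounds peak preprocessing memory. A minimal sketch of the pattern, with a hypothetical inference_from_objects stand-in in place of the real QAInferencer:

from itertools import islice


def get_batches_from_generator(iterable, n):
    # The same helper this commit moves to haystack/utils/batching.py.
    it = iter(iterable)
    batch = tuple(islice(it, n))
    while batch:
        yield batch
        batch = tuple(islice(it, n))


def inference_from_objects(objects):
    # Hypothetical stand-in for QAInferencer.inference_from_objects.
    return [f"prediction for {obj}" for obj in objects]


inputs = [("sample query", f"doc_{i}") for i in range(5)]

# Before: a single call preprocessed all five pairs at once.
# After: fixed-size chunks keep per-call preprocessing memory bounded.
predictions = []
for input_batch in get_batches_from_generator(inputs, 2):
    predictions.extend(inference_from_objects(input_batch))

assert len(predictions) == 5  # three batches: 2 + 2 + 1
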
12 changes: 0 additions & 12 deletions haystack/document_stores/base.py
@@ -5,7 +5,6 @@
 import logging
 import collections
 from pathlib import Path
-from itertools import islice
 from abc import abstractmethod

 import numpy as np
@@ -897,14 +896,3 @@ def query_batch(
         Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
         """
         pass
-
-
-def get_batches_from_generator(iterable, n):
-    """
-    Batch elements of an iterable into fixed-length chunks or blocks.
-    """
-    it = iter(iterable)
-    x = tuple(islice(it, n))
-    while x:
-        yield x
-        x = tuple(islice(it, n))

2 changes: 1 addition & 1 deletion haystack/document_stores/faiss.py
@@ -12,7 +12,7 @@
 from tqdm.auto import tqdm

 from haystack.schema import Document, FilterType
-from haystack.document_stores.base import get_batches_from_generator
+from haystack.utils.batching import get_batches_from_generator
 from haystack.nodes.retriever import DenseRetriever
 from haystack.document_stores.sql import SQLDocumentStore
 from haystack.lazy_imports import LazyImport

2 changes: 1 addition & 1 deletion haystack/document_stores/memory.py
@@ -15,7 +15,7 @@
 from haystack.schema import Document, FilterType, Label
 from haystack.errors import DuplicateDocumentError, DocumentStoreError
 from haystack.document_stores import KeywordDocumentStore
-from haystack.document_stores.base import get_batches_from_generator
+from haystack.utils.batching import get_batches_from_generator
 from haystack.document_stores.filter_utils import LogicalFilterClause
 from haystack.nodes.retriever.dense import DenseRetriever
 from haystack.utils.scipy_utils import expit

2 changes: 1 addition & 1 deletion haystack/document_stores/opensearch.py
@@ -7,7 +7,7 @@
 from tenacity import retry, wait_exponential, retry_if_not_result

 from haystack.schema import Document, FilterType
-from haystack.document_stores.base import get_batches_from_generator
+from haystack.utils.batching import get_batches_from_generator
 from haystack.document_stores.filter_utils import LogicalFilterClause
 from haystack.errors import DocumentStoreError
 from haystack.nodes.retriever import DenseRetriever

4 changes: 2 additions & 2 deletions haystack/document_stores/search_engine.py
@@ -15,7 +15,7 @@

 from haystack.document_stores import KeywordDocumentStore
 from haystack.schema import Document, FilterType, Label
-from haystack.document_stores.base import get_batches_from_generator
+from haystack.utils.batching import get_batches_from_generator
 from haystack.document_stores.filter_utils import LogicalFilterClause
 from haystack.errors import DocumentStoreError, HaystackError
 from haystack.nodes.retriever import DenseRetriever
@@ -1034,7 +1034,7 @@ def query_batch(

         body = []
         all_documents = []
-        for query, cur_filters in zip(queries, filters):
+        for query, cur_filters in tqdm(zip(queries, filters)):
             cur_query_body = self._construct_query_body(
                 query=query,
                 filters=cur_filters,

2 changes: 1 addition & 1 deletion haystack/document_stores/weaviate.py
@@ -10,7 +10,7 @@

 from haystack.schema import Document, FilterType, Label
 from haystack.document_stores import KeywordDocumentStore
-from haystack.document_stores.base import get_batches_from_generator
+from haystack.utils.batching import get_batches_from_generator
 from haystack.document_stores.filter_utils import LogicalFilterClause
 from haystack.document_stores.utils import convert_date_to_rfc3339
 from haystack.errors import DocumentStoreError, HaystackError

24 changes: 18 additions & 6 deletions haystack/nodes/reader/farm.py
@@ -16,6 +16,7 @@
 from haystack.schema import Document, Answer, Span
 from haystack.document_stores.base import BaseDocumentStore
 from haystack.nodes.reader.base import BaseReader
+from haystack.utils import get_batches_from_generator
 from haystack.utils.early_stopping import EarlyStopping
 from haystack.telemetry import send_event
 from haystack.lazy_imports import LazyImport
@@ -75,6 +76,7 @@ def __init__(
         force_download=False,
         use_auth_token: Optional[Union[str, bool]] = None,
         max_query_length: int = 64,
+        preprocessing_batch_size: Optional[int] = None,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bert-base-cased',
@@ -135,6 +137,9 @@ def __init__(
                                Additional information can be found here
                                https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
         :param max_query_length: Maximum length of the question in number of tokens.
+        :param preprocessing_batch_size: Number of query-document pairs to be preprocessed (= tokenized, put into
+                                         tensors, etc.) at once. If `None` (default), all query-document pairs are
+                                         preprocessed at once.
         """
         torch_and_transformers_import.check()

@@ -175,6 +180,7 @@ def __init__(
         self.use_confidence_scores = use_confidence_scores
         self.confidence_threshold = confidence_threshold
         self.model_name_or_path = model_name_or_path  # Used in distillation, see DistillationDataSilo._get_checksum()
+        self.preprocessing_batch_size = preprocessing_batch_size

     def _training_procedure(
         self,
@@ -851,9 +857,12 @@ def predict_batch(
         if batch_size is not None:
             self.inferencer.batch_size = batch_size
         # Make predictions on all document-query pairs
-        predictions = self.inferencer.inference_from_objects(
-            objects=inputs, return_json=False, multiprocessing_chunksize=10
-        )
+        predictions = []
+        for input_batch in get_batches_from_generator(inputs, self.preprocessing_batch_size):
+            cur_predictions = self.inferencer.inference_from_objects(
+                objects=input_batch, return_json=False, multiprocessing_chunksize=10
+            )
+            predictions.extend(cur_predictions)

         # Group predictions together
         grouped_predictions = []
@@ -917,9 +926,12 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =

         # get answers from QA model
         # TODO: Need fix in FARM's `to_dict` function of `QAInput` class
-        predictions = self.inferencer.inference_from_objects(
-            objects=inputs, return_json=False, multiprocessing_chunksize=1
-        )
+        predictions = []
+        for input_batch in get_batches_from_generator(inputs, self.preprocessing_batch_size):
+            cur_predictions = self.inferencer.inference_from_objects(
+                objects=input_batch, return_json=False, multiprocessing_chunksize=1
+            )
+            predictions.extend(cur_predictions)
         # Deduplicate same answers resulting from Document split overlap
         predictions = self._deduplicate_predictions(predictions, documents)
         # assemble answers from all the different documents & format them.

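For reference, a hypothetical usage sketch of the new parameter (the model name and numbers are placeholders, not part of this commit):

from haystack.nodes import FARMReader
from haystack.schema import Document

# preprocessing_batch_size caps how many query-document pairs are tokenized and
# put into tensors per inference_from_objects call; None (the default) keeps the
# old all-at-once behaviour.
reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",  # placeholder checkpoint
    preprocessing_batch_size=32,
)
docs = [Document(content=f"Passage number {i}.") for i in range(100)]
result = reader.predict(query="Which passage mentions the number 7?", documents=docs, top_k=3)
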
1 change: 1 addition & 0 deletions haystack/utils/__init__.py
@@ -23,3 +23,4 @@
 )
 from haystack.utils.early_stopping import EarlyStopping
 from haystack.utils.labels import aggregate_labels
+from haystack.utils.batching import get_batches_from_generator

12 changes: 12 additions & 0 deletions haystack/utils/batching.py
@@ -0,0 +1,12 @@
+from itertools import islice
+
+
+def get_batches_from_generator(iterable, n):
+    """
+    Batch elements of an iterable into fixed-length chunks or blocks.
+    """
+    it = iter(iterable)
+    x = tuple(islice(it, n))
+    while x:
+        yield x
+        x = tuple(islice(it, n))

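For illustration, a quick check of the helper's behaviour (a sketch, not part of the commit):

>>> from haystack.utils.batching import get_batches_from_generator
>>> list(get_batches_from_generator(range(5), 2))
[(0, 1), (2, 3), (4,)]
>>> list(get_batches_from_generator(range(5), None))
[(0, 1, 2, 3, 4)]

Since islice(it, None) consumes the whole iterable, n=None yields a single batch containing everything, which is how preprocessing_batch_size=None reproduces the old single-call behaviour.
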
22 changes: 21 additions & 1 deletion test/nodes/test_reader.py
@@ -1,7 +1,7 @@
 import math
 import os
 from pathlib import Path
-from shutil import rmtree
+from unittest.mock import patch

 import pytest

@@ -496,3 +496,23 @@ def test_reader_long_document(reader):
     res = reader.predict(query="Where does Christelle live?", documents=docs)
     assert res["answers"][0].offsets_in_document[0].start >= 0
     assert res["answers"][0].offsets_in_document[0].end >= 0
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.reader.farm.QAInferencer")
+def test_farmreader_predict_preprocessor_batching(mocked_qa_inferencer, docs):
+    reader = FARMReader(model_name_or_path="mocked_model", preprocessing_batch_size=2)
+    reader.predict(query="sample query", documents=docs)
+
+    # We expect 3 calls to the QAInferencer (5 docs / 2 batch_size)
+    assert reader.inferencer.inference_from_objects.call_count == 3
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.reader.farm.QAInferencer")
+def test_farmreader_predict_batch_preprocessor_batching(mocked_qa_inferencer, docs):
+    reader = FARMReader(model_name_or_path="mocked_model", preprocessing_batch_size=2)
+    reader.predict_batch(queries=["sample query 1", "sample_query_2"], documents=docs)
+
+    # We expect 5 calls to the QAInferencer (2 queries * 5 docs / 2 batch_size)
+    assert reader.inferencer.inference_from_objects.call_count == 5
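
A note on these tests: @patch swaps haystack.nodes.reader.farm.QAInferencer for a MagicMock, so FARMReader loads no real model and the assertions can simply count calls to inference_from_objects. With the five documents from the docs fixture and preprocessing_batch_size=2, predict issues 3 calls (batches of 2, 2 and 1), and predict_batch issues (2 * 5) / 2 = 5 calls.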
