Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RouteDocuments and JoinAnswers nodes #2256

Merged
merged 19 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
594101d
Add SplitDocumentList and JoinAnswer nodes
bogdankostic Feb 25, 2022
840fede
Update Documentation & Code Style
github-actions[bot] Feb 27, 2022
598af88
Add tests + adapt tutorial
bogdankostic Feb 28, 2022
511f16e
Merge remote-tracking branch 'origin/split_tables_and_texts' into spl…
bogdankostic Feb 28, 2022
e199546
Update Documentation & Code Style
github-actions[bot] Feb 28, 2022
d24fb22
Remove branch from installation path in Tutorial
bogdankostic Feb 28, 2022
bf55469
Merge remote-tracking branch 'origin/split_tables_and_texts' into spl…
bogdankostic Feb 28, 2022
a56532c
Merge branch 'master' into split_tables_and_texts
bogdankostic Feb 28, 2022
5674eff
Update Documentation & Code Style
github-actions[bot] Feb 28, 2022
48198b7
Fix typing
bogdankostic Feb 28, 2022
e25834e
Merge remote-tracking branch 'origin/split_tables_and_texts' into spl…
bogdankostic Feb 28, 2022
665133e
Update Documentation & Code Style
github-actions[bot] Feb 28, 2022
867d5ef
Change name of SplitDocumentList to RouteDocuments
bogdankostic Mar 1, 2022
4b4c6b0
Update Documentation & Code Style
github-actions[bot] Mar 1, 2022
1842da3
Adapt tutorials to new name
bogdankostic Mar 1, 2022
13b0297
Add test for JoinAnswers
bogdankostic Mar 1, 2022
2dec1db
Merge remote-tracking branch 'origin/split_tables_and_texts' into spl…
bogdankostic Mar 1, 2022
a6042b6
Update Documentation & Code Style
github-actions[bot] Mar 1, 2022
2ad75f5
Adapt name of test for JoinAnswers node
bogdankostic Mar 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions haystack/nodes/other/join_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
:param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores
of individual `Answer`s.
:param weights: A node-wise list (length of list must be equal to the number of input nodes) of weights for
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assignef to each
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
:param top_k_join: Limit `Answer`s to top_k based on the resulting scored of the join.
"""
Expand All @@ -36,16 +36,17 @@ def __init__(
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
reader_results = [inp["answers"] for inp in inputs]

if not top_k_join:
top_k_join = self.top_k_join

if self.join_mode == "concatenate":
concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
concatenated_answers = sorted(concatenated_answers, reverse=True)
concatenated_answers = sorted(concatenated_answers, reverse=True)[:top_k_join]
return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1"

elif self.join_mode == "merge":
merged_answers = self._merge_answers(reader_results)

if not top_k_join:
top_k_join = self.top_k_join if self.top_k_join is not None else len(merged_answers)
merged_answers = merged_answers[:top_k_join]
return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1"

Expand Down
34 changes: 25 additions & 9 deletions test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import responses

from haystack import __version__, Document
from haystack import __version__, Document, Answer, JoinAnswers
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
Expand All @@ -19,7 +19,7 @@
from haystack.nodes.retriever.sparse import ElasticsearchRetriever
from haystack.pipelines import Pipeline, DocumentSearchPipeline, RootNode, ExtractiveQAPipeline
from haystack.pipelines.base import _PipelineCodeGen
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever, SplitDocumentList
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever, RouteDocuments

from conftest import MOCK_DC, DC_API_ENDPOINT, DC_API_KEY, DC_TEST_INDEX, SAMPLES_PATH, deepset_cloud_fixture

Expand Down Expand Up @@ -1043,8 +1043,8 @@ def test_documentsearch_document_store_authentication(retriever_with_docs, docum
assert kwargs["headers"] == auth_headers


def test_split_document_list_content_type(test_docs_xs):
# Test splitting by content_type
def test_route_documents_by_content_type():
# Test routing by content_type
docs = [
Document(content="text document", content_type="text"),
Document(
Expand All @@ -1053,17 +1053,19 @@ def test_split_document_list_content_type(test_docs_xs):
),
]

split_documents = SplitDocumentList()
result, _ = split_documents.run(documents=docs)
route_documents = RouteDocuments()
result, _ = route_documents.run(documents=docs)
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1
assert result["output_1"][0].content_type == "text"
assert result["output_2"][0].content_type == "table"

# Test splitting by metadata field

def test_route_documents_by_metafield(test_docs_xs):
# Test routing by metadata field
docs = [Document.from_dict(doc) if isinstance(doc, dict) else doc for doc in test_docs_xs]
split_documents = SplitDocumentList(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
result, _ = split_documents.run(docs)
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
result, _ = route_documents.run(docs)
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1
assert len(result["output_3"]) == 1
Expand All @@ -1072,6 +1074,20 @@ def test_split_document_list_content_type(test_docs_xs):
assert result["output_3"][0].meta["meta_field"] == "test5"


@pytest.mark.parametrize("join_mode", ["concatenate", "merge"])
def test_join_answers_concatenate(join_mode):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_join_answers_concatenate is a little bit misleading as you test for "concatenate" and "merge".

inputs =[{"answers": [Answer(answer="answer 1", score=0.7)]}, {"answers": [Answer(answer="answer 2", score=0.8)]}]

join_answers = JoinAnswers(join_mode=join_mode)
result, _ = join_answers.run(inputs)
assert len(result["answers"]) == 2
assert result["answers"] == sorted(result["answers"], reverse=True)

result, _ = join_answers.run(inputs, top_k_join=1)
assert len(result["answers"]) == 1
assert result["answers"][0].answer == "answer 2"


def clean_faiss_document_store():
if Path("existing_faiss_document_store").exists():
os.remove("existing_faiss_document_store")
Expand Down