From ec16a1400ae676b46c6bfd3c55f38339ba86790b Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 10 Apr 2024 14:22:20 +0200 Subject: [PATCH] Make join nodes work when no docs are provided --- haystack/nodes/other/join_docs.py | 9 ++++++++- test/nodes/test_join_documents.py | 32 ++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/haystack/nodes/other/join_docs.py b/haystack/nodes/other/join_docs.py index 3d682c0f64..a74472010a 100644 --- a/haystack/nodes/other/join_docs.py +++ b/haystack/nodes/other/join_docs.py @@ -60,6 +60,11 @@ def __init__( def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: results = [inp["documents"] for inp in inputs] + + # Return early if no input provided any documents (all None or all empty lists) + if all(res is None for res in results) or all(res == [] for res in results): + return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1" + document_map = {doc.id: doc for result in results for doc in result} if self.join_mode == "concatenate": @@ -100,7 +105,9 @@ def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # Join single document lists - if isinstance(inputs[0]["documents"][0], Document): + if inputs[0]["documents"] is None or inputs[0]["documents"] == []: + return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1" + elif isinstance(inputs[0]["documents"][0], Document): return self.run(inputs=inputs, top_k_join=top_k_join) # Join lists of document lists else: diff --git a/test/nodes/test_join_documents.py b/test/nodes/test_join_documents.py index ae809f4994..ac92616fb7 100644 --- a/test/nodes/test_join_documents.py +++ b/test/nodes/test_join_documents.py @@ -1,7 +1,7 @@ import pytest -from haystack import Document +from haystack import Document, Pipeline from haystack.nodes.other.join_docs import 
JoinDocuments from copy import deepcopy @@ -149,3 +149,33 @@ def test_joindocuments_rrf_weights(): assert result_none["documents"] == result_even["documents"] assert result_uneven["documents"] != result_none["documents"] assert result_uneven["documents"][0].score > result_none["documents"][0].score + + +@pytest.mark.unit +def test_join_node_empty_documents(): + pipe = Pipeline() + join_node = JoinDocuments(join_mode="concatenate") + pipe.add_node(component=join_node, name="Join", inputs=["Query"]) + + # Test single document lists + output = pipe.run(query="test", documents=[]) + assert len(output["documents"]) == 0 + + # Test lists of document lists + output = join_node.run_batch(queries=["test"], documents=[]) + assert len(output[0]["documents"]) == 0 + + +@pytest.mark.unit +def test_join_node_none_documents(): + pipe = Pipeline() + join_node = JoinDocuments(join_mode="concatenate") + pipe.add_node(component=join_node, name="Join", inputs=["Query"]) + + # Test single document lists + output = pipe.run(query="test", documents=None) + assert len(output["documents"]) == 0 + + # Test lists of document lists + output = join_node.run_batch(queries=["test"], documents=None) + assert len(output[0]["documents"]) == 0