Skip to content

Commit

Permalink
Make join nodes work when no docs are provided
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl committed Apr 10, 2024
1 parent d8eca03 commit ec16a14
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
9 changes: 8 additions & 1 deletion haystack/nodes/other/join_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def __init__(

def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
results = [inp["documents"] for inp in inputs]

# Return early with no documents when every result is None or every result is an empty list
if all(res is None for res in results) or all(res == [] for res in results):
return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1"

document_map = {doc.id: doc for result in results for doc in result}

if self.join_mode == "concatenate":
Expand Down Expand Up @@ -100,7 +105,9 @@ def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None)

def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
# Join single document lists
if isinstance(inputs[0]["documents"][0], Document):
if inputs[0]["documents"] is None or inputs[0]["documents"] == []:
return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1"
elif isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join)
# Join lists of document lists
else:
Expand Down
32 changes: 31 additions & 1 deletion test/nodes/test_join_documents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


from haystack import Document
from haystack import Document, Pipeline
from haystack.nodes.other.join_docs import JoinDocuments
from copy import deepcopy

Expand Down Expand Up @@ -149,3 +149,33 @@ def test_joindocuments_rrf_weights():
assert result_none["documents"] == result_even["documents"]
assert result_uneven["documents"] != result_none["documents"]
assert result_uneven["documents"][0].score > result_none["documents"][0].score


@pytest.mark.unit
def test_join_node_empty_documents():
    """JoinDocuments must yield zero documents when handed an empty document list.

    The node is exercised twice: once through ``Pipeline.run`` and once by
    calling ``run_batch`` on the node itself.
    """
    joiner = JoinDocuments(join_mode="concatenate")
    pipeline = Pipeline()
    pipeline.add_node(component=joiner, name="Join", inputs=["Query"])

    # Single-query path through the pipeline
    single_result = pipeline.run(query="test", documents=[])
    joined_docs = single_result["documents"]
    assert len(joined_docs) == 0

    # Batch path, called directly on the node
    batch_result = joiner.run_batch(queries=["test"], documents=[])
    first_output = batch_result[0]
    assert len(first_output["documents"]) == 0


@pytest.mark.unit
def test_join_node_none_documents():
    """JoinDocuments must yield zero documents when ``documents`` is ``None``.

    The node is exercised twice: once through ``Pipeline.run`` and once by
    calling ``run_batch`` on the node itself.
    """
    joiner = JoinDocuments(join_mode="concatenate")
    pipeline = Pipeline()
    pipeline.add_node(component=joiner, name="Join", inputs=["Query"])

    # Single-query path through the pipeline
    single_result = pipeline.run(query="test", documents=None)
    joined_docs = single_result["documents"]
    assert len(joined_docs) == 0

    # Batch path, called directly on the node
    batch_result = joiner.run_batch(queries=["test"], documents=None)
    first_output = batch_result[0]
    assert len(first_output["documents"]) == 0

0 comments on commit ec16a14

Please sign in to comment.