-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add SplitDocumentList and JoinAnswer nodes * Update Documentation & Code Style * Add tests + adapt tutorial * Update Documentation & Code Style * Remove branch from installation path in Tutorial * Update Documentation & Code Style * Fix typing * Update Documentation & Code Style * Change name of SplitDocumentList to RouteDocuments * Update Documentation & Code Style * Adapt tutorials to new name * Add test for JoinAnswers * Update Documentation & Code Style * Adapt name of test for JoinAnswers node Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
11eebf8
commit c5542bd
Showing
12 changed files
with
1,160 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from haystack.nodes.other.docs2answers import Docs2Answers | ||
from haystack.nodes.other.join_docs import JoinDocuments | ||
from haystack.nodes.other.route_documents import RouteDocuments | ||
from haystack.nodes.other.join_answers import JoinAnswers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import Optional, List, Dict, Tuple | ||
|
||
from haystack.schema import Answer | ||
from haystack.nodes import BaseComponent | ||
|
||
|
||
class JoinAnswers(BaseComponent):
    """
    A node to join `Answer`s produced by multiple `Reader` nodes.

    In `"concatenate"` mode, the `Answer`s of all input nodes are pooled and sorted in descending order (using the
    ordering defined by `Answer`). In `"merge"` mode, the score of every `Answer` is additionally multiplied by the
    weight of the input node that produced it before sorting.
    """

    def __init__(
        self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
    ):
        """
        :param join_mode: `"concatenate"` to combine `Answer`s from multiple `Reader`s. `"merge"` to aggregate scores
                          of individual `Answer`s.
        :param weights: A node-wise list (length of list must be equal to the number of input nodes) of weights for
                        adjusting `Answer` scores when using the `"merge"` join_mode. The weights are normalized so
                        that they sum to 1. By default, equal weight is assigned to each `Reader` score. This
                        parameter is not compatible with the `"concatenate"` join_mode.
        :param top_k_join: Limit `Answer`s to top_k based on the resulting scores of the join.
        :raises AssertionError: If `join_mode` is unknown or `weights` is combined with `"concatenate"`.
        :raises ValueError: If the provided `weights` sum to zero and can therefore not be normalized.
        """

        assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode."
        assert not (
            weights is not None and join_mode == "concatenate"
        ), "Weights are not compatible with 'concatenate' join_mode"

        # Save init parameters to enable export of component config as YAML
        self.set_config(join_mode=join_mode, weights=weights, top_k_join=top_k_join)

        self.join_mode = join_mode
        self.weights = self._normalize_weights(weights)
        self.top_k_join = top_k_join

    @staticmethod
    def _normalize_weights(weights: Optional[List[float]]) -> Optional[List[float]]:
        """
        Scale `weights` so that they sum to 1. Returns `None` if no (or an empty list of) weights were given.
        """
        if not weights:
            return None

        # Compute the total once instead of once per weight.
        total = sum(weights)
        if total == 0:
            raise ValueError("The weights of the JoinAnswers node must not sum to zero.")
        return [float(weight) / total for weight in weights]

    def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:  # type: ignore
        """
        Join the `"answers"` entries of all `inputs`.

        :param inputs: One dictionary per input node. Each dictionary must contain an `"answers"` entry. The
                       `"labels"` entry of the first dictionary (if any) is passed through to the output.
        :param top_k_join: Limit the joined `Answer`s to this number. A falsy value (`None`, but also `0`) falls back
                           to the `top_k_join` value provided at init time.
        :return: A tuple of the output dictionary (keys `"answers"` and `"labels"`) and the string `"output_1"`.
        :raises ValueError: If the node's `join_mode` is neither `"concatenate"` nor `"merge"`.
        """
        reader_results = [inp["answers"] for inp in inputs]

        if not top_k_join:
            top_k_join = self.top_k_join

        if self.join_mode == "concatenate":
            joined_answers = sorted(
                [answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True
            )
        elif self.join_mode == "merge":
            joined_answers = self._merge_answers(reader_results)
        else:
            raise ValueError(f"Invalid join_mode: {self.join_mode}")

        # Without any input there is nothing to take labels from.
        labels = inputs[0].get("labels", None) if inputs else None
        return {"answers": joined_answers[:top_k_join], "labels": labels}, "output_1"

    def _merge_answers(self, reader_results: List[List[Answer]]) -> List[Answer]:
        """
        Weight the score of every `Answer` by the weight of its input node and return all `Answer`s sorted in
        descending order.

        `Answer`s whose score is rescaled are shallow-copied first, so the objects passed in by the caller keep
        their original scores. `Answer`s without a float score are passed through unchanged.

        :raises ValueError: If explicit weights were configured and their number differs from the number of inputs.
        """
        import copy

        if not reader_results:
            return []

        if self.weights:
            if len(self.weights) != len(reader_results):
                raise ValueError(
                    f"JoinAnswers received {len(reader_results)} inputs but was configured with "
                    f"{len(self.weights)} weights. The number of weights must equal the number of input nodes."
                )
            weights = self.weights
        else:
            weights = [1 / len(reader_results)] * len(reader_results)

        merged_answers = []
        for result, weight in zip(reader_results, weights):
            for answer in result:
                if isinstance(answer.score, float):
                    # Work on a copy so that running the node does not alter the caller's Answer objects.
                    answer = copy.copy(answer)
                    answer.score *= weight
                merged_answers.append(answer)

        return sorted(merged_answers, reverse=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import List, Tuple, Dict, Optional | ||
|
||
from haystack.nodes.base import BaseComponent | ||
from haystack.schema import Document | ||
|
||
|
||
class RouteDocuments(BaseComponent):
    """
    A node to split a list of `Document`s by `content_type` or by the values of a metadata field and route them to
    different nodes.
    """

    # By default (split_by == "content_type"), the node has two outgoing edges.
    outgoing_edges = 2

    def __init__(self, split_by: str = "content_type", metadata_values: Optional[List[str]] = None):
        """
        :param split_by: Field to split the documents by, either `"content_type"` or a metadata field name.
            If this parameter is set to `"content_type"`, the list of `Document`s will be split into a list containing
            only `Document`s of type `"text"` (will be routed to `"output_1"`) and a list containing only `Document`s of
            type `"table"` (will be routed to `"output_2"`).
            If this parameter is set to a metadata field name, you need to specify the parameter `metadata_values` as
            well.
        :param metadata_values: If the parameter `split_by` is set to a metadata field name, you need to provide a list
            of values to group the `Document`s to. `Document`s whose metadata field is equal to the first value of the
            provided list will be routed to `"output_1"`, `Document`s whose metadata field is equal to the second
            value of the provided list will be routed to `"output_2"`, etc.
        :raises AssertionError: If `split_by` is a metadata field name but no `metadata_values` are provided.
        """

        assert split_by == "content_type" or metadata_values is not None, (
            "If split_by is set to the name of a metadata field, you must provide metadata_values "
            "to group the documents to."
        )

        # Save init parameters to enable export of component config as YAML
        self.set_config(split_by=split_by, metadata_values=metadata_values)

        self.split_by = split_by
        self.metadata_values = metadata_values

        # If we split list of Documents by a metadata field, number of outgoing edges might change
        if split_by != "content_type" and metadata_values is not None:
            self.outgoing_edges = len(metadata_values)

    def run(self, documents: List[Document]) -> Tuple[Dict, str]:  # type: ignore
        """
        Group `documents` by content type or by the value of the configured metadata field.

        `Document`s that match none of the groups are dropped: in `"content_type"` mode these are `Document`s whose
        type is neither `"text"` nor `"table"`; in metadata mode these are `Document`s that lack the metadata field
        (or have it set to `None`) or whose value is not contained in `metadata_values`.

        :param documents: The `Document`s to split.
        :return: A tuple of a dictionary mapping `"output_1"`, `"output_2"`, ... to the respective list of
                 `Document`s and the string `"split_documents"`.
        """
        if self.split_by == "content_type":
            split_documents: Dict[str, List[Document]] = {"output_1": [], "output_2": []}

            for doc in documents:
                if doc.content_type == "text":
                    split_documents["output_1"].append(doc)
                elif doc.content_type == "table":
                    split_documents["output_2"].append(doc)

        else:
            assert isinstance(self.metadata_values, list), (
                "You need to provide metadata_values if you want to split" " a list of Documents by a metadata field."
            )
            split_documents = {f"output_{i+1}": [] for i in range(len(self.metadata_values))}

            for doc in documents:
                current_metadata_value = doc.meta.get(self.split_by, None)
                # Disregard current document if it does not contain the provided metadata field
                if current_metadata_value is None:
                    continue

                try:
                    index = self.metadata_values.index(current_metadata_value)
                except ValueError:
                    # Disregard current document if current_metadata_value is not in the provided metadata_values
                    continue

                split_documents[f"output_{index+1}"].append(doc)

        return split_documents, "split_documents"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.