Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify Summarizer, add Document Merger #3452

Merged
merged 30 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
556bfc7
remove generate_single_summary
anakin87 Oct 21, 2022
8197f48
update schemas
anakin87 Oct 21, 2022
2ff7c52
Merge remote-tracking branch 'upstream/main' into summarizer_refactoring
anakin87 Oct 21, 2022
9e850c2
remove unused import
anakin87 Oct 21, 2022
b69966c
fix mypy
anakin87 Oct 21, 2022
2fcf55d
fix mypy
anakin87 Oct 22, 2022
39c7a84
test: summarizer doesnt change content
anakin87 Oct 22, 2022
5614161
other test correction
anakin87 Oct 22, 2022
5301120
move test_summarizer_translation to test_extractor_translation
anakin87 Oct 23, 2022
4e0f142
fix test
anakin87 Oct 23, 2022
5a8bfc5
Merge branch 'main' into summarizer_refactoring
anakin87 Oct 25, 2022
a23385a
first try for doc merger
anakin87 Oct 27, 2022
4accd52
Merge branch 'main' into summarizer_refactoring
anakin87 Oct 28, 2022
0401461
reintroduce and deprecate generate_single_summary
anakin87 Oct 28, 2022
5a6a537
progress in document merger
anakin87 Oct 30, 2022
42ecf96
document merger!
anakin87 Oct 31, 2022
5fd70e6
Merge branch 'main' into summarizer_refactoring
anakin87 Oct 31, 2022
037cfb7
mypy, pylint fixes
anakin87 Oct 31, 2022
f45d483
use generator
anakin87 Oct 31, 2022
e25fb40
added test that will fail in 1.12
anakin87 Oct 31, 2022
c772803
Merge branch 'main' into summarizer_refactoring
anakin87 Nov 1, 2022
9558c9e
adapt to review
anakin87 Nov 1, 2022
7a66619
merge main
anakin87 Nov 1, 2022
0cc9c5a
extended deprecation docstring
anakin87 Nov 1, 2022
ca79476
Merge branch 'main' into summarizer_refactoring
anakin87 Nov 2, 2022
c9fa988
Update test/nodes/test_extractor_translation.py
ZanSara Nov 3, 2022
435c81e
Update test/nodes/test_summarizer.py
ZanSara Nov 3, 2022
78f2137
Update test/nodes/test_summarizer.py
ZanSara Nov 3, 2022
4ccf07d
black
ZanSara Nov 3, 2022
ab8e6ba
documents fixture
ZanSara Nov 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions haystack/nodes/other/document_merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
from copy import deepcopy
from typing import Optional, List, Dict, Union

from haystack.schema import Document
from haystack.nodes.base import BaseComponent

logger = logging.getLogger(__name__)


class DocumentMerger(BaseComponent):
"""
A node to merge the texts of the documents.
"""

outgoing_edges = 1

def __init__(self, separator: str = " "):
"""
:param separator: The separator that appears between subsequent merged documents.
"""
super().__init__()
self.separator = separator

def merge(self, documents: List[Document], separator: Optional[str] = None) -> List[Document]:
"""
Produce a list made up of a single document, which contains all the texts of the documents provided.

:param separator: The separator that appears between subsequent merged documents.
:return: List of Documents
"""
if len(documents) == 0:
raise AttributeError("Document Merger needs at least one document to merge.")
if not all(doc.content_type == "text" for doc in documents):
raise AttributeError(
"Some of the documents provided are non-textual. Document Merger only works on textual documents."
)
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

if separator is None:
separator = self.separator
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

merged_content = separator.join([doc.content for doc in documents])
common_meta = self._extract_common_meta_dict(documents)

merged_document = Document(content=merged_content, meta=common_meta)
return [merged_document]

def run(self, documents: List[Document], separator: Optional[str] = None): # type: ignore
results: Dict = {"documents": []}
if documents:
results["documents"] = self.merge(documents=documents, separator=separator)
return results, "output_1"

def run_batch( # type: ignore
self, documents: Union[List[Document], List[List[Document]]], separator: Optional[str] = None
):
is_doclist_flat = isinstance(documents[0], Document)
if is_doclist_flat:
flat_result: List[Document] = self.merge(
documents=[doc for doc in documents if isinstance(doc, Document)], separator=separator
)
return {"documents": flat_result}, "output_1"
else:
nested_result: List[List[Document]] = [
self.merge(documents=docs_lst, separator=separator)
for docs_lst in documents
if isinstance(docs_lst, list)
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
]
return {"documents": nested_result}, "output_1"

def _extract_common_meta_dict(self, documents: List[Document]) -> dict:
"""
Given a list of documents, extract a dictionary containing the meta fields
that are common to all the documents
"""
flattened_meta = [self._flatten_dict(d.meta) for d in documents]
common_meta_flat_dict = deepcopy(flattened_meta[0])
for doc in flattened_meta[1:]:
if len(common_meta_flat_dict) == 0:
break
for k, v in doc.items():
if k in common_meta_flat_dict:
if common_meta_flat_dict[k] != v:
del common_meta_flat_dict[k]
common_meta_nested_dict = self._nest_dict(common_meta_flat_dict)
return common_meta_nested_dict

def _flatten_dict(self, d: dict, parent_key="") -> dict:
items: List = []
for k, v in d.items():
new_key = (parent_key, k) if parent_key else k
if isinstance(v, dict):
items.extend(self._flatten_dict(v, new_key).items())
else:
items.append((new_key, v))
return dict(items)

def _nest_dict(self, d: dict) -> dict:
nested_dict: dict = {}
for key, value in d.items():
target = nested_dict
if isinstance(key, tuple):
for k in key[:-1]: # traverse all keys but the last
target = target.setdefault(k, {})
target[key[-1]] = value
else:
target[key] = value
while any(isinstance(k, tuple) for k in nested_dict.keys()):
nested_dict = self._nest_dict(nested_dict)
return nested_dict
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 3 additions & 7 deletions haystack/nodes/summarizer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ def predict(self, documents: List[Document], generate_single_summary: Optional[b
Abstract method for creating a summary.

:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
:param generate_single_summary: Whether to generate a single summary for all documents or one summary per document.
If set to "True", all docs will be joined to a single string that will then
be summarized.
Important: The summary will depend on the order of the supplied documents!
:return: List of Documents, where Document.content contains the summarization and Document.meta["context"]
the original, not summarized text
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
:return: List of Documents, where Document.meta["summary"] contains the summarization
"""
pass

Expand Down Expand Up @@ -54,7 +50,7 @@ def run_batch( # type: ignore
):

results = self.predict_batch(
documents=documents, generate_single_summary=generate_single_summary, batch_size=batch_size
documents=documents, batch_size=batch_size, generate_single_summary=generate_single_summary
)

return {"documents": results}, "output_1"
159 changes: 67 additions & 92 deletions haystack/nodes/summarizer/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
from tqdm.auto import tqdm
from transformers import pipeline
from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM

from haystack.schema import Document
from haystack.nodes.summarizer.base import BaseSummarizer
Expand Down Expand Up @@ -34,19 +33,18 @@ class TransformersSummarizer(BaseSummarizer):
|
| # Summarize
| summary = summarizer.predict(
| documents=docs,
| generate_single_summary=True
| )
| documents=docs)
|
| # Show results (List of Documents, containing summary and original text)
| # Show results (List of Documents, containing summary and original content)
| print(summary)
|
| [
| {
| "text": "California's largest electricity provider has turned off power to hundreds of thousands of customers.",
| "content": "PGE stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. ...",
| ...
| "meta": {
| "context": "PGE stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. ..."
| "summary": "California's largest electricity provider has turned off power to hundreds of thousands of customers.",
| ...
| },
| ...
| },
Expand Down Expand Up @@ -83,12 +81,8 @@ def __init__(
:param min_length: Minimum length of summarized text
:param use_gpu: Whether to use GPU (if available).
:param clean_up_tokenization_spaces: Whether or not to clean up the potential extra spaces in the text output
:param separator_for_single_summary: If `generate_single_summary=True` in `predict()`, we need to join all docs
into a single text. This separator appears between those subsequent docs.
:param generate_single_summary: Whether to generate a single summary for all documents or one summary per document.
If set to "True", all docs will be joined to a single string that will then
be summarized.
Important: The summary will depend on the order of the supplied documents!
:param separator_for_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
:param batch_size: Number of documents to process at a time.
:param progress_bar: Whether to show a progress bar.
:param use_auth_token: The API token used to download private models from Huggingface.
Expand All @@ -103,27 +97,34 @@ def __init__(
"""
super().__init__()

if generate_single_summary is True:
raise ValueError(
"'generate_single_summary' has been removed. Instead, you can use the Document Merger to merge documents before applying the Summarizer."
)
self.separator_for_single_summary = separator_for_single_summary
self.generate_single_summary = generate_single_summary
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
if len(self.devices) > 1:
logger.warning(
f"Multiple devices are not supported in {self.__class__.__name__} inference, "
f"using the first device {self.devices[0]}."
)

# TODO AutoModelForSeq2SeqLM is only necessary with transformers==4.1.1, with newer versions use the pipeline directly
if tokenizer is None:
tokenizer = model_name_or_path
model = AutoModelForSeq2SeqLM.from_pretrained(
pretrained_model_name_or_path=model_name_or_path, revision=model_version, use_auth_token=use_auth_token
)
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

self.summarizer = pipeline(
"summarization", model=model, tokenizer=tokenizer, device=self.devices[0], use_auth_token=use_auth_token
task="summarization",
model=model_name_or_path,
tokenizer=tokenizer,
revision=model_version,
device=self.devices[0],
use_auth_token=use_auth_token,
)
self.max_length = max_length
self.min_length = min_length
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
self.separator_for_single_summary = separator_for_single_summary
self.generate_single_summary = generate_single_summary
self.print_log: Set[str] = set()
self.batch_size = batch_size
self.progress_bar = progress_bar
Expand All @@ -134,29 +135,22 @@ def predict(self, documents: List[Document], generate_single_summary: Optional[b
These document can for example be retrieved via the Retriever.

:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
:param generate_single_summary: Whether to generate a single summary for all documents or one summary per document.
If set to "True", all docs will be joined to a single string that will then
be summarized.
Important: The summary will depend on the order of the supplied documents!
:return: List of Documents, where Document.text contains the summarization and Document.meta["context"]
the original, not summarized text
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
:return: List of Documents, where Document.meta["summary"] contains the summarization
"""
if generate_single_summary is True:
raise ValueError(
"'generate_single_summary' has been removed. Instead, you can use the Document Merger to merge documents before applying the Summarizer."
)

if self.min_length > self.max_length:
raise AttributeError("min_length cannot be greater than max_length")

if len(documents) == 0:
raise AttributeError("Summarizer needs at least one document to produce a summary.")

if generate_single_summary is None:
generate_single_summary = self.generate_single_summary

contexts: List[str] = [doc.content for doc in documents]

if generate_single_summary:
# Documents order is very important to produce summary.
# Different order of same documents produce different summary.
contexts = [self.separator_for_single_summary.join(contexts)]

encoded_input = self.summarizer.tokenizer(contexts, verbose=False)
for input_id in encoded_input["input_ids"]:
tokens_count: int = len(input_id)
Expand All @@ -182,15 +176,9 @@ def predict(self, documents: List[Document], generate_single_summary: Optional[b

result: List[Document] = []

if generate_single_summary:
for context, summarized_answer in zip(contexts, summaries):
cur_doc = Document(content=summarized_answer["summary_text"], meta={"context": context})
result.append(cur_doc)
else:
for context, summarized_answer, document in zip(contexts, summaries, documents):
cur_doc = Document(content=summarized_answer["summary_text"], meta=document.meta)
cur_doc.meta.update({"context": context})
result.append(cur_doc)
for summary, document in zip(summaries, documents):
document.meta.update({"summary": summary["summary_text"]})
result.append(document)

return result

Expand All @@ -206,13 +194,13 @@ def predict_batch(

:param documents: Single list of related documents or list of lists of related documents
(e.g. coming from a retriever) that the answer shall be conditioned on.
:param generate_single_summary: Whether to generate a single summary for each provided document list or
one summary per document.
If set to "True", all docs of a document list will be joined to a single string
that will then be summarized.
Important: The summary will depend on the order of the supplied documents!
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
:param batch_size: Number of Documents to process at a time.
"""
if generate_single_summary is True:
raise ValueError(
"'generate_single_summary' has been removed. Instead, you can use the Document Merger to merge documents before applying the Summarizer."
)

if self.min_length > self.max_length:
raise AttributeError("min_length cannot be greater than max_length")
Expand All @@ -225,34 +213,17 @@ def predict_batch(
if batch_size is None:
batch_size = self.batch_size

if generate_single_summary is None:
generate_single_summary = self.generate_single_summary

single_doc_list = False
if isinstance(documents[0], Document):
single_doc_list = True

if single_doc_list:
is_doclist_flat = isinstance(documents[0], Document)
if is_doclist_flat:
contexts = [doc.content for doc in documents if isinstance(doc, Document)]
else:
contexts = [
[doc.content for doc in docs if isinstance(doc, Document)]
for docs in documents
if isinstance(docs, list)
]

if generate_single_summary:
if single_doc_list:
contexts = [self.separator_for_single_summary.join(contexts)]
else:
contexts = [self.separator_for_single_summary.join(context_group) for context_group in contexts]
number_of_docs = [1 for _ in contexts]
else:
if single_doc_list:
number_of_docs = [1 for _ in contexts]
else:
number_of_docs = [len(context_group) for context_group in contexts]
contexts = list(itertools.chain.from_iterable(contexts))
number_of_docs = [len(context_group) for context_group in contexts]
contexts = list(itertools.chain.from_iterable(contexts))

encoded_input = self.summarizer.tokenizer(contexts, verbose=False)
for input_id in encoded_input["input_ids"]:
Expand Down Expand Up @@ -286,26 +257,30 @@ def predict_batch(
):
summaries.extend(summary_batch)

# Group summaries together
grouped_summaries = []
grouped_contexts = []
left_idx = 0
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
grouped_summaries.append(summaries[left_idx:right_idx])
grouped_contexts.append(contexts[left_idx:right_idx])
left_idx = right_idx

result = []
for summary_group, context_group in zip(grouped_summaries, grouped_contexts):
cur_summaries = [
Document(content=summary["summary_text"], meta={"context": context})
for summary, context in zip(summary_group, context_group)
]
if single_doc_list:
result.append(cur_summaries[0])
else:
result.append(cur_summaries) # type: ignore

return result
if is_doclist_flat:
flat_result: List[Document] = []
flat_doc_list: List[Document] = [doc for doc in documents if isinstance(doc, Document)]
for summary, document in zip(summaries, flat_doc_list):
document.meta.update({"summary": summary["summary_text"]})
flat_result.append(document)
return flat_result
else:
nested_result: List[List[Document]] = []
nested_doc_list: List[List[Document]] = [lst for lst in documents if isinstance(lst, list)]

# Group summaries together
grouped_summaries = []
left_idx = 0
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
grouped_summaries.append(summaries[left_idx:right_idx])
left_idx = right_idx

for summary_group, docs_group in zip(grouped_summaries, nested_doc_list):
cur_summaries = []
for summary, document in zip(summary_group, docs_group):
document.meta.update({"summary": summary["summary_text"]})
cur_summaries.append(document)
nested_result.append(cur_summaries)
return nested_result
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
Loading