Multi Reference ROUGEScore #680

Merged: 30 commits, Dec 17, 2021 (changes shown from all commits)

Commits (30)
5d35382
Multiple references - initial commit
ashutoshml Dec 10, 2021
cf8d420
Added support for multi-reference rouge scorer based on maximum pairw…
ashutoshml Dec 12, 2021
2b1673f
Update format for compliance
ashutoshml Dec 12, 2021
2228f35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2021
254b58e
Merge branch 'multirouge' of github.com:ashutoshml/metrics into multi…
ashutoshml Dec 12, 2021
8b4b3cf
Merge collapsible 'if' statements
ashutoshml Dec 12, 2021
32c1f12
Merge branch 'master' into multirouge
ashutoshml Dec 12, 2021
557e063
Add average strategy to multi-reference rouge
ashutoshml Dec 12, 2021
1a85c9d
Merge branch 'multirouge' of github.com:ashutoshml/metrics into multi…
ashutoshml Dec 12, 2021
edddd8d
Fix issue with docstring
ashutoshml Dec 12, 2021
61b4090
Fix type errors in methods
ashutoshml Dec 13, 2021
9c46488
Merge branch 'master' into multirouge
ashutoshml Dec 14, 2021
e604ac9
changelog
SkafteNicki Dec 15, 2021
ae1a20b
Review comments incorporated; Average accumulate pending
ashutoshml Dec 15, 2021
0b591b5
Merge branch 'master' into multirouge
ashutoshml Dec 15, 2021
ac3019c
Add support for avg. accumulate
ashutoshml Dec 15, 2021
3ef233a
Merge branch 'multirouge' of github.com:ashutoshml/metrics into multi…
ashutoshml Dec 15, 2021
596d431
Fix naming conventions
ashutoshml Dec 15, 2021
74202eb
fix badge
Borda Dec 15, 2021
5a59e49
Merge branch 'master' into multirouge
ashutoshml Dec 16, 2021
c19b7c9
Merge branch 'master' into multirouge
ashutoshml Dec 16, 2021
f867c82
Merge branch 'master' into multirouge
Borda Dec 16, 2021
2a87d51
Update importing of Literal from typing to typing_extensions
ashutoshml Dec 16, 2021
a45f16a
Merge branch 'multirouge' of github.com:ashutoshml/metrics into multi…
ashutoshml Dec 16, 2021
2ab4bc2
Move test cases to common file
ashutoshml Dec 16, 2021
93cf0e6
Apply suggestions from code review
Borda Dec 16, 2021
67d0335
Fix typing issues with variable for mypy
ashutoshml Dec 17, 2021
da9dfff
Remove redundant clause in if statement
ashutoshml Dec 17, 2021
d0290ca
Merge branch 'master' into multirouge
mergify[bot] Dec 17, 2021
f6b6ff9
Apply suggestions from code review - Remove redundant information + F…
ashutoshml Dec 17, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))


- Added support for multi references in `ROUGEScore` ([#680](https://github.com/PyTorchLightning/metrics/pull/680))


### Changed

- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))
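To make the changelog entry above concrete: the PR lets each prediction be scored against several references. Below is a minimal, hedged sketch of the new functional call, based on the rouge_score signature shown in the torchmetrics/functional/text/rouge.py diff further down; the sentences and the printed key are illustrative, not taken from the PR.

# Hedged sketch of multi-reference scoring; example sentences are invented.
from torchmetrics.functional.text.rouge import rouge_score

preds = ["the cat was under the bed"]
# one inner sequence of references per prediction
targets = [["the cat was found under the bed", "a cat was lying under the big bed"]]

# accumulate="best" keeps, per prediction, the reference with the highest fmeasure;
# accumulate="avg" averages the scores over all references instead.
scores = rouge_score(preds, targets, accumulate="best", use_stemmer=True)
print(scores["rouge1_fmeasure"])  # output keys are expected to follow the "<rouge_key>_<metric>" pattern
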
2 changes: 2 additions & 0 deletions tests/text/inputs.py
@@ -42,6 +42,8 @@

_inputs_multiple_references = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_REFERENCES)

_inputs_single_sentence_single_reference = Input(preds=HYPOTHESIS_B, targets=REFERENCE_1B)

ERROR_RATES_BATCHES_1 = {
"preds": [["hello world"], ["what a day"]],
"targets": [["hello world"], ["what a wonderful day"]],
89 changes: 55 additions & 34 deletions tests/text/test_rouge.py
@@ -13,11 +13,13 @@
# limitations under the License.

from functools import partial
from typing import List
from typing import Sequence

import pytest
import torch

from tests.text.helpers import INPUT_ORDER, TextTester
from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_single_reference
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
@@ -30,34 +32,45 @@

ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")

SINGLE_SENTENCE_EXAMPLE_PREDS = "The quick brown fox jumps over the lazy dog"
SINGLE_SENTENCE_EXAMPLE_TARGET = "The quick brown dog jumps on the log."

PREDS = "My name is John"
TARGETS = "Is your name John"
def _compute_rouge_score(
preds: Sequence[str],
targets: Sequence[Sequence[str]],
use_stemmer: bool,
rouge_level: str,
metric: str,
accumulate: str,
):
"""Evaluates rouge scores from rouge-score package for baseline evaluation."""
if isinstance(targets, list) and all(isinstance(target, str) for target in targets):
targets = [targets] if isinstance(preds, str) else [[target] for target in targets]


BATCHES_1 = {
"preds": [["the cat was under the bed"], ["the cat was found under the bed"]],
"targets": [["the cat was found under the bed"], ["the tiny little cat was found under the big funny bed "]],
}


BATCHES_2 = {
"preds": [["The quick brown fox jumps over the lazy dog"], ["My name is John"]],
"targets": [["The quick brown dog jumps on the log."], ["Is your name John"]],
}


def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool, rouge_level: str, metric: str):
if isinstance(preds, str):
preds = [preds]

if isinstance(targets, str):
targets = [targets]
targets = [[targets]]

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = BootstrapAggregator()
for pred, target in zip(preds, targets):
aggregator.add_scores(scorer.score(target, pred))

for target_raw, pred_raw in zip(targets, preds):
list_results = [scorer.score(target, pred_raw) for target in target_raw]
aggregator_avg = BootstrapAggregator()

if accumulate == "best":
key_curr = list(list_results[0].keys())[0]
all_fmeasure = torch.tensor([v[key_curr].fmeasure for v in list_results])
highest_idx = torch.argmax(all_fmeasure).item()
aggregator.add_scores(list_results[highest_idx])
elif accumulate == "avg":
for _score in list_results:
aggregator_avg.add_scores(_score)
_score = {rouge_key: scores.mid for rouge_key, scores in aggregator_avg.aggregate().items()}
aggregator.add_scores(_score)
else:
raise ValueError(f"Got unknown accumulate value {accumulate}. Expected to be one of ['best', 'avg']")

rs_scores = aggregator.aggregate()
rs_result = getattr(rs_scores[rouge_level].mid, metric)
return rs_result
@@ -84,19 +97,21 @@ def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool
@pytest.mark.parametrize(
["preds", "targets"],
[
(BATCHES_1["preds"], BATCHES_1["targets"]),
(BATCHES_2["preds"], BATCHES_2["targets"]),
(_inputs_multiple_references.preds, _inputs_multiple_references.targets),
],
)
@pytest.mark.parametrize("accumulate", ["avg", "best"])
class TestROUGEScore(TextTester):
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_rouge_score_class(self, ddp, dist_sync_on_step, preds, targets, pl_rouge_metric_key, use_stemmer):
metric_args = {"use_stemmer": use_stemmer}

def test_rouge_score_class(
self, ddp, dist_sync_on_step, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate
):
metric_args = {"use_stemmer": use_stemmer, "accumulate": accumulate}
rouge_level, metric = pl_rouge_metric_key.split("_")
rouge_metric = partial(_compute_rouge_score, use_stemmer=use_stemmer, rouge_level=rouge_level, metric=metric)

rouge_metric = partial(
_compute_rouge_score, use_stemmer=use_stemmer, rouge_level=rouge_level, metric=metric, accumulate=accumulate
)
self.run_class_metric_test(
ddp=ddp,
preds=preds,
@@ -109,12 +124,13 @@ def test_rouge_score_class(self, ddp, dist_sync_on_step, preds, targets, pl_roug
key=pl_rouge_metric_key,
)

def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_stemmer):
metric_args = {"use_stemmer": use_stemmer}
def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate):
metric_args = {"use_stemmer": use_stemmer, "accumulate": accumulate}

rouge_level, metric = pl_rouge_metric_key.split("_")
rouge_metric = partial(_compute_rouge_score, use_stemmer=use_stemmer, rouge_level=rouge_level, metric=metric)

rouge_metric = partial(
_compute_rouge_score, use_stemmer=use_stemmer, rouge_level=rouge_level, metric=metric, accumulate=accumulate
)
self.run_functional_metric_test(
preds,
targets,
@@ -144,4 +160,9 @@ def test_rouge_metric_wrong_key_value_error():
ROUGEScore(rouge_keys=key)

with pytest.raises(ValueError):
rouge_score(PREDS, TARGETS, rouge_keys=key)
rouge_score(
_inputs_single_sentence_single_reference.preds,
_inputs_single_sentence_single_reference.targets,
rouge_keys=key,
accumulate="best",
)
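As a standalone illustration of what the updated _compute_rouge_score baseline above does for accumulate="best", here is a hedged sketch against the rouge-score package directly; the module paths rouge_score.rouge_scorer.RougeScorer and rouge_score.scoring.BootstrapAggregator are assumed to match what the test module imports, and the sentences are invented.

# Hedged sketch of the "best" accumulation baseline used for comparison in the tests.
import torch
from rouge_score.rouge_scorer import RougeScorer
from rouge_score.scoring import BootstrapAggregator

preds = ["the cat was under the bed"]
targets = [["the cat was found under the bed", "a tiny cat was under the big bed"]]

scorer = RougeScorer(["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=True)
aggregator = BootstrapAggregator()

for target_group, pred in zip(targets, preds):
    # score the prediction against every reference in its group
    results = [scorer.score(target, pred) for target in target_group]
    # keep the reference whose first rouge key yields the highest fmeasure
    first_key = list(results[0].keys())[0]
    fmeasures = torch.tensor([r[first_key].fmeasure for r in results])
    aggregator.add_scores(results[int(torch.argmax(fmeasures))])

best_rouge1 = aggregator.aggregate()["rouge1"].mid  # namedtuple with precision, recall, fmeasure
print(best_rouge1.fmeasure)
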
110 changes: 82 additions & 28 deletions torchmetrics/functional/text/rouge.py
@@ -13,10 +13,11 @@
# limitations under the License.
import re
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.utilities.imports import _NLTK_AVAILABLE

@@ -33,6 +34,7 @@
"rougeL": "L",
"rougeLsum": "Lsum",
}
ALLOWED_ACCUMULATE_VALUES = ("avg", "best")


def _add_newline_to_end_of_each_sentence(x: str) -> str:
@@ -68,7 +70,7 @@ def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[s
return dict(precision=tensor(precision), recall=tensor(recall), fmeasure=tensor(fmeasure))


def _lcs(pred_tokens: List[str], target_tokens: List[str]) -> int:
def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
"""Common DP algorithm to compute the length of the longest common subsequence.

Args:
@@ -87,7 +89,7 @@ def _lcs(pred_tokens: List[str], target_tokens: List[str]) -> int:
return LCS[-1][-1]


def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> List[str]:
def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> Sequence[str]:
"""Rouge score should be calculated only over lowercased words and digits. Optionally, Porter stemmer can be
used to strip word suffixes to improve matching. The text normalization follows the implementation from `Rouge
score_Text Normalizition`_
@@ -112,7 +114,7 @@ def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> Li
return tokens


def _rouge_n_score(pred: List[str], target: List[str], n_gram: int) -> Dict[str, Tensor]:
def _rouge_n_score(pred: Sequence[str], target: Sequence[str], n_gram: int) -> Dict[str, Tensor]:
"""This computes precision, recall and F1 score for the Rouge-N metric.

Args:
@@ -124,7 +126,7 @@ def _rouge_n_score(pred: List[str], target: List[str], n_gram: int) -> Dict[str,
N-gram overlap.
"""

def _create_ngrams(tokens: List[str], n: int) -> Counter:
def _create_ngrams(tokens: Sequence[str], n: int) -> Counter:
ngrams: Counter = Counter()
for ngram in (tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)):
ngrams[ngram] += 1
@@ -140,7 +142,7 @@ def _create_ngrams(tokens: List[str], n: int) -> Counter:
return _compute_metrics(hits, max(pred_len, 1), max(target_len, 1))


def _rouge_l_score(pred: List[str], target: List[str]) -> Dict[str, Tensor]:
def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]:
"""This computes precision, recall and F1 score for the Rouge-L or Rouge-LSum metric.

Args:
@@ -158,9 +160,10 @@ def _rouge_l_score(pred: List[str], target: List[str]) -> Dict[str, Tensor]:


def _rouge_score_update(
preds: List[str],
targets: List[str],
preds: Sequence[str],
targets: Sequence[Sequence[str]],
rouge_keys_values: List[Union[int, str]],
accumulate: str,
stemmer: Optional[Any] = None,
) -> Dict[Union[int, str], List[Dict[str, Tensor]]]:
"""Update the rouge score with the current set of predicted and target sentences.
@@ -169,17 +172,22 @@
preds:
An iterable of predicted sentences.
targets:
An iterable of target sentences.
An iterable of iterable of target sentences.
rouge_keys_values:
List of N-grams/'L'/'Lsum' arguments.
accumulate:
Useful in case of multi-reference rouge score.
``avg`` takes the avg of all references with respect to predictions.
``best`` takes the best fmeasure score obtained between prediction and multiple corresponding references.
Allowed values are ``avg`` and ``best``.
stemmer:
Porter stemmer instance to strip word suffixes to improve matching.

Example:
>>> targets = "Is your name John".split()
>>> preds = "My name is John".split()
>>> from pprint import pprint
>>> score = _rouge_score_update(preds, targets, rouge_keys_values=[1, 2, 3, 'L'])
>>> score = _rouge_score_update(preds, targets, rouge_keys_values=[1, 2, 3, 'L'], accumulate='best')
>>> pprint(score) # doctest: +SKIP
{1: [{'fmeasure': tensor(0.), 'precision': tensor(0.), 'recall': tensor(0.)},
{'fmeasure': tensor(0.), 'precision': tensor(0.), 'recall': tensor(0.)},
@@ -199,24 +207,62 @@
{'fmeasure': tensor(1.), 'precision': tensor(1.), 'recall': tensor(1.)}]}
"""
results: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}

for pred_raw, target_raw in zip(preds, targets):
result_inner: Dict[Union[int, str], Dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values}
result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}
list_results = []
pred = _normalize_and_tokenize_text(pred_raw, stemmer)
target = _normalize_and_tokenize_text(target_raw, stemmer)
pred_Lsum = _normalize_and_tokenize_text(_add_newline_to_end_of_each_sentence(pred_raw), stemmer)

if "Lsum" in rouge_keys_values:
# rougeLsum expects "\n" separated sentences within a summary
pred_Lsum = _normalize_and_tokenize_text(_add_newline_to_end_of_each_sentence(pred_raw), stemmer)
target_Lsum = _normalize_and_tokenize_text(_add_newline_to_end_of_each_sentence(target_raw), stemmer)
for target_raw_inner in target_raw:
target = _normalize_and_tokenize_text(target_raw_inner, stemmer)

for rouge_key in rouge_keys_values:
if isinstance(rouge_key, int):
score = _rouge_n_score(pred, target, rouge_key)
else:
score = _rouge_l_score(
pred if rouge_key != "Lsum" else pred_Lsum,
target if rouge_key != "Lsum" else target_Lsum,
if "Lsum" in rouge_keys_values:
# rougeLsum expects "\n" separated sentences within a summary
target_Lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer
)
results[rouge_key].append(score)

for rouge_key in rouge_keys_values:
if isinstance(rouge_key, int):
score = _rouge_n_score(pred, target, rouge_key)
else:
score = _rouge_l_score(
pred if rouge_key != "Lsum" else pred_Lsum,
target if rouge_key != "Lsum" else target_Lsum,
)
result_inner[rouge_key] = score
result_avg[rouge_key].append(score)
list_results.append(result_inner.copy())

if accumulate == "best":
key_curr = rouge_keys_values[0]
all_fmeasure = torch.tensor([v[key_curr]["fmeasure"] for v in list_results])
highest_idx = int(torch.argmax(all_fmeasure).item())

for rouge_key in rouge_keys_values:
results[rouge_key].append(list_results[highest_idx][rouge_key])

elif accumulate == "avg":
new_result_avg: Dict[Union[int, str], Dict[str, Tensor]] = {
rouge_key: {} for rouge_key in rouge_keys_values
}
for rouge_key, metrics in result_avg.items():
_dict_metric_score_batch: Dict[str, List[Tensor]] = {}
for metric in metrics:
for _type, value in metric.items():
if _type not in _dict_metric_score_batch:
_dict_metric_score_batch[_type] = []
_dict_metric_score_batch[_type].append(value)

new_result_avg[rouge_key] = {
_type: torch.tensor(_dict_metric_score_batch[_type]).mean() for _type in _dict_metric_score_batch
}

for rouge_key in rouge_keys_values:
results[rouge_key].append(new_result_avg[rouge_key])

return results


@@ -239,8 +285,9 @@ def _rouge_score_compute(sentence_results: Dict[str, List[Tensor]]) -> Dict[str,


def rouge_score(
preds: Union[str, List[str]],
targets: Union[str, List[str]],
preds: Union[str, Sequence[str]],
targets: Union[str, Sequence[str], Sequence[Sequence[str]]],
accumulate: Literal["avg", "best"] = "best",
use_stemmer: bool = False,
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore
) -> Dict[str, Tensor]:
@@ -250,7 +297,11 @@
preds:
An iterable of predicted sentences or a single predicted sentence.
targets:
An iterable of target sentences or a single target sentence.
An iterable of iterables of target sentences or an iterable of target sentences or a single target sentence.
accumulate:
Useful in case of multi-reference rouge score.
- ``avg`` takes the avg of all references with respect to predictions.
- ``best`` takes the best fmeasure score obtained between prediction and multiple corresponding references.
use_stemmer:
Use Porter stemmer to strip word suffixes to improve matching.
rouge_keys:
@@ -303,14 +354,17 @@ def rouge_score(
raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}")
rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys]

if isinstance(targets, list) and all(isinstance(target, str) for target in targets):
targets = [targets] if isinstance(preds, str) else [[target] for target in targets]

if isinstance(preds, str):
preds = [preds]

if isinstance(targets, str):
targets = [targets]
targets = [[targets]]

sentence_results: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update(
preds, targets, rouge_keys_values, stemmer=stemmer
preds, targets, rouge_keys_values, stemmer=stemmer, accumulate=accumulate
)

output: Dict[str, List[Tensor]] = {}
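Finally, a hedged end-to-end sketch of the class-based metric as exercised by the tests above, assuming ROUGEScore accepts use_stemmer and accumulate as constructor arguments; the data and the printed key are illustrative only.

# Hedged sketch of the module interface with accumulate="avg"; data are invented.
from torchmetrics.text.rouge import ROUGEScore

rouge = ROUGEScore(use_stemmer=True, accumulate="avg")

preds = ["the cat was under the bed", "the quick brown fox jumps over the lazy dog"]
targets = [
    ["the cat was found under the bed"],  # a single reference still works
    ["the quick brown dog jumps on the log.", "a fast fox jumped over a lazy dog"],  # two references
]

rouge.update(preds, targets)  # can be called batch by batch before a single compute()
scores = rouge.compute()      # dict of tensors, e.g. scores["rougeLsum_fmeasure"]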