diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d4ea306af2..17ecc734857 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,10 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +- Removed `rouge-score` as dependency for text package ([#443](https://github.com/PyTorchLightning/metrics/pull/443)) ### Fixed -- Fixed bug in the ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448)) +- Fixed ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448)) ## [0.5.0] - 2021-08-09 diff --git a/requirements/test.txt b/requirements/test.txt index 960f810cb89..7a1fbe57fdc 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -25,3 +25,6 @@ mir_eval>=0.6 #pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip #SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip + +# text +rouge-score>=0.0.4 diff --git a/requirements/text.txt b/requirements/text.txt index d850ded158a..99408171a0d 100644 --- a/requirements/text.txt +++ b/requirements/text.txt @@ -1,4 +1,3 @@ jiwer>=2.2.0 nltk>=3.6 -rouge-score>=0.0.4 bert-score==0.3.10 diff --git a/tests/text/test_rouge.py b/tests/text/test_rouge.py index 8fe1ce0903b..497ecd8467c 100644 --- a/tests/text/test_rouge.py +++ b/tests/text/test_rouge.py @@ -16,7 +16,6 @@ import pytest import torch -from torch import tensor from torchmetrics.functional.text.rouge import rouge_score from torchmetrics.text.rouge import ROUGEScore @@ -30,16 +29,13 @@ ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum") -PRECISION = 0 -RECALL = 1 -F_MEASURE = 2 - SINGLE_SENTENCE_EXAMPLE_PREDS = "The quick brown fox jumps over the lazy dog" SINGLE_SENTENCE_EXAMPLE_TARGET = "The quick brown dog jumps on the log." 
PREDS = "My name is John".split() TARGETS = "Is your name John".split() + BATCHES_RS_PREDS = [SINGLE_SENTENCE_EXAMPLE_PREDS] BATCHES_RS_PREDS.extend(PREDS) BATCHES_RS_TARGETS = [SINGLE_SENTENCE_EXAMPLE_TARGET] @@ -55,145 +51,139 @@ def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) aggregator = BootstrapAggregator() for pred, target in zip(preds, targets): - aggregator.add_scores(scorer.score(pred, target)) + aggregator.add_scores(scorer.score(target, pred)) return aggregator.aggregate() -@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score") +@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk") @pytest.mark.parametrize( - ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"], + ["pl_rouge_metric_key", "use_stemmer"], [ - pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True), - pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False), - pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True), - pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False), - pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True), - pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False), - pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True), - pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False), - pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True), - pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False), - pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True), - pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False), + pytest.param("rouge1_precision", True), + pytest.param("rouge1_recall", True), + pytest.param("rouge1_fmeasure", False), + pytest.param("rouge2_precision", False), + pytest.param("rouge2_recall", True), + pytest.param("rouge2_fmeasure", True), + pytest.param("rougeL_precision", False), + pytest.param("rougeL_recall", False), + pytest.param("rougeL_fmeasure", True), + pytest.param("rougeLsum_precision", True), + pytest.param("rougeLsum_recall", False), + pytest.param("rougeLsum_fmeasure", False), ], ) -def test_rouge_metric_functional_single_sentence( - pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep -): - scorer = RougeScorer(ROUGE_KEYS) - rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET) - rs_output = round(rs_scores[rouge_score_key][metric], decimal_places) +def test_rouge_metric_functional_single_sentence(pl_rouge_metric_key, use_stemmer): + rouge_level, metric = pl_rouge_metric_key.split("_") + + scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) + rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS) + rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32) - pl_output = rouge_score( - [SINGLE_SENTENCE_EXAMPLE_PREDS], - [SINGLE_SENTENCE_EXAMPLE_TARGET], - newline_sep=newline_sep, - use_stemmer=use_stemmer, - decimal_places=decimal_places, - ) + pl_output = rouge_score([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET], use_stemmer=use_stemmer) - assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32)) + assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result) -@pytest.mark.skipif(not (_NLTK_AVAILABLE or 
_ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score") +@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk") @pytest.mark.parametrize( - ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"], + ["pl_rouge_metric_key", "use_stemmer"], [ - pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True), - pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False), - pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True), - pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False), - pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True), - pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False), - pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True), - pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False), - pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True), - pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False), - pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True), - pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False), + pytest.param("rouge1_precision", True), + pytest.param("rouge1_recall", True), + pytest.param("rouge1_fmeasure", False), + pytest.param("rouge2_precision", False), + pytest.param("rouge2_recall", True), + pytest.param("rouge2_fmeasure", True), + pytest.param("rougeL_precision", False), + pytest.param("rougeL_recall", False), + pytest.param("rougeL_fmeasure", True), + pytest.param("rougeLsum_precision", True), + pytest.param("rougeLsum_recall", False), + pytest.param("rougeLsum_fmeasure", False), ], ) -def test_rouge_metric_functional( - pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep -): +def test_rouge_metric_functional(pl_rouge_metric_key, use_stemmer): + rouge_level, metric = pl_rouge_metric_key.split("_") + rs_scores = _compute_rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer) - rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places) + rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32) - pl_output = rouge_score( - PREDS, TARGETS, newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places - ) + pl_output = rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer) - assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32)) + assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result) -@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score") +@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk") @pytest.mark.parametrize( - ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"], + ["pl_rouge_metric_key", "use_stemmer"], [ - pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True), - pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False), - pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True), - pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False), - pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True), - pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False), - pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True), - pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False), - pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True), - 
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False), - pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True), - pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False), + pytest.param("rouge1_precision", True), + pytest.param("rouge1_recall", True), + pytest.param("rouge1_fmeasure", False), + pytest.param("rouge2_precision", False), + pytest.param("rouge2_recall", True), + pytest.param("rouge2_fmeasure", True), + pytest.param("rougeL_precision", False), + pytest.param("rougeL_recall", False), + pytest.param("rougeL_fmeasure", True), + pytest.param("rougeLsum_precision", True), + pytest.param("rougeLsum_recall", False), + pytest.param("rougeLsum_fmeasure", False), ], ) -def test_rouge_metric_class(pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep): - scorer = RougeScorer(ROUGE_KEYS) - rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET) - rs_output = round(rs_scores[rouge_score_key][metric], decimal_places) +def test_rouge_metric_class(pl_rouge_metric_key, use_stemmer): + rouge_level, metric = pl_rouge_metric_key.split("_") + + scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) + rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS) + rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32) - rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places) + rouge = ROUGEScore(use_stemmer=use_stemmer) pl_output = rouge([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET]) - assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32)) + assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result) -@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score") +@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk") @pytest.mark.parametrize( - ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"], + ["pl_rouge_metric_key", "use_stemmer"], [ - pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True), - pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False), - pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True), - pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False), - pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True), - pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False), - pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True), - pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False), - pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True), - pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False), - pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True), - pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False), + pytest.param("rouge1_precision", True), + pytest.param("rouge1_recall", True), + pytest.param("rouge1_fmeasure", False), + pytest.param("rouge2_precision", False), + pytest.param("rouge2_recall", True), + pytest.param("rouge2_fmeasure", True), + pytest.param("rougeL_precision", False), + pytest.param("rougeL_recall", False), + pytest.param("rougeL_fmeasure", True), + pytest.param("rougeLsum_precision", True), + pytest.param("rougeLsum_recall", False), + pytest.param("rougeLsum_fmeasure", False), ], ) -def 
test_rouge_metric_class_batches( - pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep -): +def test_rouge_metric_class_batches(pl_rouge_metric_key, use_stemmer): + rouge_level, metric = pl_rouge_metric_key.split("_") + rs_scores = _compute_rouge_score(BATCHES_RS_PREDS, BATCHES_RS_TARGETS, use_stemmer=use_stemmer) - rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places) + rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32) - rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places) + rouge = ROUGEScore(use_stemmer=use_stemmer) for batch in BATCHES: rouge.update(batch["preds"], batch["targets"]) pl_output = rouge.compute() - assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32)) + assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result) def test_rouge_metric_raises_errors_and_warnings(): """Test that expected warnings and errors are raised.""" - if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE): + if not _NLTK_AVAILABLE: with pytest.raises( ValueError, - match="ROUGE metric requires that both nltk and rouge-score is installed." - "Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`", + match="ROUGE metric requires that nltk is installed." + "Either as `pip install torchmetrics[text]` or `pip install nltk`", ): ROUGEScore() diff --git a/torchmetrics/functional/text/rouge.py b/torchmetrics/functional/text/rouge.py index 91e61d93e74..688175d0c81 100644 --- a/torchmetrics/functional/text/rouge.py +++ b/torchmetrics/functional/text/rouge.py @@ -12,64 +12,150 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Dict, List, Tuple, Union +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from torch import Tensor, tensor - -from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE - -if _ROUGE_SCORE_AVAILABLE: - from rouge_score.rouge_scorer import RougeScorer - from rouge_score.scoring import AggregateScore, BootstrapAggregator -else: - RougeScorer, AggregateScore, BootstrapAggregator = object, object, object - -ALLOWED_ROUGE_KEYS = ( - "rouge1", - "rouge2", - "rouge3", - "rouge4", - "rouge5", - "rouge6", - "rouge7", - "rouge8", - "rouge9", - "rougeL", - "rougeLsum", -) - - -def add_newline_to_end_of_each_sentence(x: str) -> str: +from torch import Tensor + +from torchmetrics.utilities.imports import _NLTK_AVAILABLE + +ALLOWED_ROUGE_KEYS: Dict[str, Union[int, str]] = { + "rouge1": 1, + "rouge2": 2, + "rouge3": 3, + "rouge4": 4, + "rouge5": 5, + "rouge6": 6, + "rouge7": 7, + "rouge8": 8, + "rouge9": 9, + "rougeL": "L", + "rougeLsum": "Lsum", +} + + +def _add_newline_to_end_of_each_sentence(x: str) -> str: """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" - if _NLTK_AVAILABLE: - import nltk + if not _NLTK_AVAILABLE: + raise ValueError("ROUGE-Lsum calculation requires that nltk is installed. Use `pip install nltk`.") + import nltk - nltk.download("punkt", quiet=True, force=False) + nltk.download("punkt", quiet=True, force=False) re.sub("", "", x) # remove pegasus newline char - assert nltk, "nltk must be installed to separate newlines between sentences. 
(pip install nltk)" return "\n".join(nltk.sent_tokenize(x)) -def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, Tensor]: - """Formats the computed (aggregated) rouge score to a dictionary of tensors format.""" - flattened_result = {} - for rouge_key, rouge_aggregate_score in result.items(): - for stat in ["precision", "recall", "fmeasure"]: - mid = rouge_aggregate_score.mid - score = round(getattr(mid, stat), decimal_places) - flattened_result[f"{rouge_key}_{stat}"] = tensor(score, dtype=torch.float) - return flattened_result +def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, float]: + """This computes precision, recall and F1 score based on hits/lcs, and the lengths of the tokenized + predicted and target sentences. + + Args: + hits_or_lcs: + The number of matches or the length of the longest common subsequence. + pred_len: + The length of a tokenized predicted sentence. + target_len: + The length of a tokenized target sentence. + """ + precision = hits_or_lcs / pred_len + recall = hits_or_lcs / target_len + if precision == recall == 0.0: + return dict(precision=0.0, recall=0.0, fmeasure=0.0) + + fmeasure = 2 * precision * recall / (precision + recall) + return dict(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _lcs(pred_tokens: List[str], target_tokens: List[str]) -> int: + """Common DP algorithm to compute the length of the longest common subsequence. + + Args: + pred_tokens: + A tokenized predicted sentence. + target_tokens: + A tokenized target sentence. + """ + LCS = [[0] * (len(pred_tokens) + 1) for _ in range(len(target_tokens) + 1)] + for i in range(1, len(target_tokens) + 1): + for j in range(1, len(pred_tokens) + 1): + if target_tokens[i - 1] == pred_tokens[j - 1]: + LCS[i][j] = LCS[i - 1][j - 1] + 1 + else: + LCS[i][j] = max(LCS[i - 1][j], LCS[i][j - 1]) + return LCS[-1][-1] + + +def _normalize_text(text: str, stemmer: Optional[Any] = None) -> str: + """Rouge score should be calculated only over lowercased words and digits. Optionally, Porter stemmer can be + used to strip word suffixes to improve matching. The text normalization follows the implementation from + https://github.com/google-research/google-research/blob/master/rouge/tokenize.py. + + Args: + text: + An input sentence. + stemmer: + Porter stemmer instance to strip word suffixes to improve matching. + """ + text = re.sub(r"[^a-z0-9]+", " ", text.lower()) + if stemmer: + text = " ".join(stemmer.stem(x) if len(x) > 3 else x for x in text.split()) + return text.strip() # to ensure there are no whitespaces at the end of the sentence + + +def _rouge_n_score(pred: str, target: str, n_gram: int) -> Dict[str, float]: + """This computes precision, recall and F1 score for the Rouge-N metric. + + Args: + pred: + A predicted sentence. + target: + A target sentence. + n_gram: + N-gram overlap.
+ """ + pred_tokenized, target_tokenized = _tokenize(pred, n_gram), _tokenize(target, n_gram) + pred_len, target_len = len(pred_tokenized), len(target_tokenized) + if 0 in (pred_len, target_len): + return dict(precision=0.0, recall=0.0, fmeasure=0.0) + + pred_counter: Dict[str, int] = defaultdict(int) + target_counter: Dict[str, int] = defaultdict(int) + for w in pred_tokenized: + pred_counter[w] += 1 + for w in target_tokenized: + target_counter[w] += 1 + # It is sufficient to take a set(pred_tokenized) for hits count as we consider intersenction of pred & target + hits = sum(min(pred_counter[w], target_counter[w]) for w in set(pred_tokenized)) + return _compute_metrics(hits, pred_len, target_len) + + +def _rouge_l_score(pred: str, target: str) -> Dict[str, float]: + """This computes precision, recall and F1 score for the Rouge-L or Rouge-LSum metric. + + Args: + pred: + A predicted sentence. + target: + A target sentence. + """ + pred_tokenized, target_tokenized = _tokenize(pred, 1), _tokenize(target, 1) + pred_len, target_len = len(pred_tokenized), len(target_tokenized) + if 0 in (pred_len, target_len): + return dict(precision=0.0, recall=0.0, fmeasure=0.0) + + lcs = _lcs(pred_tokenized, target_tokenized) + return _compute_metrics(lcs, pred_len, target_len) def _rouge_score_update( preds: List[str], targets: List[str], - scorer: RougeScorer, - aggregator: BootstrapAggregator, - newline_sep: bool = False, -) -> None: + rouge_keys_values: List[Union[int, str]], + results: Optional[Dict[Union[int, str], List[Dict[str, float]]]] = None, + stemmer: Optional[Any] = None, +) -> Dict[Union[int, str], List[Dict[str, float]]]: """Update the rouge score with the current set of predicted and target sentences. Args: @@ -77,49 +163,69 @@ def _rouge_score_update( An iterable of predicted sentences. targets: An iterable of target sentences. - scorer: - An instance of the ``RougeScorer`` class from the ``rouge_score`` package. - aggregator: - An instance of the ``BootstrapAggregator`` from the from the ``rouge_score`` package. - newline_sep: - New line separate the inputs. + rouge_keys_values: + List of N-grams/'L'/'Lsum' arguments. + stemmer: + Porter stemmer instance to strip word suffixes to improve matching. 
Example: >>> targets = "Is your name John".split() >>> preds = "My name is John".split() - >>> aggregator = BootstrapAggregator() - >>> scorer = RougeScorer(rouge_types=("rouge1", "rouge2", "rougeL", "rougeLsum"), use_stemmer=False) - >>> _rouge_score_update(preds, targets, scorer=scorer, aggregator=aggregator, newline_sep=False) + >>> from pprint import pprint + >>> score = _rouge_score_update(preds, targets, rouge_keys_values=[1, 2, 3, 'L']) + >>> pprint(score) # doctest: +NORMALIZE_WHITESPACE +SKIP + {'1': {'precision': 0.25, 'recall': 0.25, 'fmeasure': 0.25}, + '2': {'precision': 0.0, 'recall': 0.0, 'fmeasure': 0.0}, + '3': {'precision': 0.0, 'recall': 0.0, 'fmeasure': 0.0}, + 'L': {'precision': 0.25, 'recall': 0.25, 'fmeasure': 0.25}} """ - for pred, target in zip(preds, targets): + results = results if results is not None else {rouge_key: [] for rouge_key in rouge_keys_values} + for pred_raw, target_raw in zip(preds, targets): + pred, target = _normalize_text(pred_raw, stemmer), _normalize_text(target_raw, stemmer) # rougeLsum expects "\n" separated sentences within a summary - if newline_sep: - pred = add_newline_to_end_of_each_sentence(pred) - target = add_newline_to_end_of_each_sentence(target) - results = scorer.score(pred, target) - aggregator.add_scores(results) + if "Lsum" in rouge_keys_values: + pred_sum = _normalize_text(_add_newline_to_end_of_each_sentence(pred_raw), stemmer) + target_sum = _normalize_text(_add_newline_to_end_of_each_sentence(target_raw), stemmer) + + for rouge_key in rouge_keys_values: + if isinstance(rouge_key, int): + score = _rouge_n_score(pred, target, rouge_key) + else: + score = _rouge_l_score( + pred if rouge_key != "Lsum" else pred_sum, + target if rouge_key != "Lsum" else target_sum, + ) + results[rouge_key].append(score) + return results -def _rouge_score_compute(aggregator: BootstrapAggregator, decimal_places: int = 4) -> Dict[str, Tensor]: +def _rouge_score_compute( + sentence_results: Optional[Dict[Union[int, str], List[Dict[str, float]]]] +) -> Dict[str, Tensor]: """Compute the combined ROUGE metric for all the input set of predicted and target sentences. Args: - aggregator: - An instance of the ``BootstrapAggregator`` from the from the ``rouge_score`` package. - decimal_places: - The number of digits to round the computed the values to. + sentence_results: + Rouge-N/Rouge-L/Rouge-LSum metrics calculated for single sentence. """ - result = aggregator.aggregate() - return format_rouge_results(result, decimal_places) + results: Dict[str, Tensor] = {} + # Obtain mean scores for individual rouge metrics + if sentence_results is None: + return results + for rouge_key, scores in sentence_results.items(): + res = torch.tensor([(score["precision"], score["recall"], score["fmeasure"]) for score in scores]).mean(0) + results[f"rouge{rouge_key}_precision"] = res[0] + results[f"rouge{rouge_key}_recall"] = res[1] + results[f"rouge{rouge_key}_fmeasure"] = res[2] + + return results def rouge_score( preds: Union[str, List[str]], targets: Union[str, List[str]], - newline_sep: bool = False, use_stemmer: bool = False, rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore - decimal_places: int = 4, ) -> Dict[str, Tensor]: """Calculate `ROUGE score `_, used for automatic summarization. @@ -128,15 +234,11 @@ def rouge_score( An iterable of predicted sentences. targets: An iterable of target sentences. - newline_sep: - New line separate the inputs. 
use_stemmer: Use Porter stemmer to strip word suffixes to improve matching. rouge_keys: A list of rouge types to calculate. Keys that are allowed are ``rougeL``, ``rougeLsum``, and ``rouge1`` through ``rouge9``. - decimal_places: - The number of digits to round the computed the values to. Return: Python dictionary of rouge scores for each input rouge key. @@ -161,7 +263,7 @@ def rouge_score( Raises: ValueError: - If the python packages ``nltk`` or ``rouge-score`` are not installed. + If the python package ``nltk`` is not installed. ValueError: If any of the ``rouge_keys`` does not belong to the allowed set of keys. @@ -169,17 +271,19 @@ [1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin https://aclanthology.org/W04-1013/ """ - if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE): - raise ValueError( - "ROUGE metric requires that both nltk and rouge-score is installed." - " Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`" - ) + if use_stemmer: + if not _NLTK_AVAILABLE: + raise ValueError("Stemmer requires that nltk is installed. Use `pip install nltk`.") + import nltk + + stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None if not isinstance(rouge_keys, tuple): rouge_keys = tuple([rouge_keys]) for key in rouge_keys: - if key not in ALLOWED_ROUGE_KEYS: - raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {ALLOWED_ROUGE_KEYS}") + if key not in ALLOWED_ROUGE_KEYS.keys(): + raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}") + rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys] if isinstance(preds, str): preds = [preds] @@ -187,8 +291,19 @@ if isinstance(targets, str): targets = [targets] - aggregator = BootstrapAggregator() - scorer = RougeScorer(rouge_keys, use_stemmer=use_stemmer) + sentence_results = _rouge_score_update(preds, targets, rouge_keys_values, stemmer=stemmer) + return _rouge_score_compute(sentence_results) + + +def _tokenize(text: str, n_gram: int) -> List[str]: + """Return the tokenized input text as a list, where tokens are represented by N-grams. - _rouge_score_update(preds, targets, scorer=scorer, aggregator=aggregator, newline_sep=newline_sep) - return _rouge_score_compute(aggregator=aggregator, decimal_places=decimal_places) + Args: + text: + An input sentence. + n_gram: + N-gram size to return. + """ + tokens = re.split(r"\s+", text) + n_grams_list = [" ".join(tokens[i : i + n_gram]) for i in range(len(tokens) - n_gram + 1)] + return n_grams_list diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py index 495f1232bda..b58793cfdff 100644 --- a/torchmetrics/text/rouge.py +++ b/torchmetrics/text/rouge.py @@ -11,27 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
+import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union from torch import Tensor from torchmetrics import Metric from torchmetrics.functional.text.rouge import ALLOWED_ROUGE_KEYS, _rouge_score_compute, _rouge_score_update -from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE - -if _ROUGE_SCORE_AVAILABLE: - from rouge_score.rouge_scorer import RougeScorer - from rouge_score.scoring import BootstrapAggregator -else: - RougeScorer, BootstrapAggregator = object, object +from torchmetrics.utilities.imports import _NLTK_AVAILABLE class ROUGEScore(Metric): """Calculate `ROUGE score `_, used for automatic summarization. + This implementation should imitate the behaviour of the `rouge-score` package + https://pypi.org/project/rouge-score/. Args: newline_sep: New line separate the inputs. + This argument is no longer in use. It is deprecated in v0.6 and will be removed in v0.7. use_stemmer: Use Porter stemmer to strip word suffixes to improve matching. rouge_keys: @@ -39,6 +37,7 @@ class ROUGEScore(Metric): Keys that are allowed are ``rougeL``, ``rougeLsum``, and ``rouge1`` through ``rouge9``. decimal_places: The number of digits to round the computed the values to. + This argument is no longer in use. It is deprecated in v0.6 and will be removed in v0.7. compute_on_step: Forward only calls ``update()`` and returns None if this is set to False. default: True dist_sync_on_step: @@ -72,7 +71,7 @@ class ROUGEScore(Metric): Raises: ValueError: - If the python packages ``nltk`` or ``rouge-score`` are not installed. + If the python package ``nltk`` is not installed. ValueError: If any of the ``rouge_keys`` does not belong to the allowed set of keys. @@ -82,10 +81,10 @@ def __init__( self, - newline_sep: bool = False, + newline_sep: Optional[bool] = None, # remove in v0.7 use_stemmer: bool = False, rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore - decimal_places: int = 4, + decimal_places: Optional[int] = None, # remove in v0.7 compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -97,12 +96,15 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) + if newline_sep is not None: + warnings.warn("Argument `newline_sep` is deprecated in v0.6 and will be removed in v0.7") + if decimal_places is not None: + warnings.warn("Argument `decimal_places` is deprecated in v0.6 and will be removed in v0.7") - if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE): - raise ValueError( - "ROUGE metric requires that both nltk and rouge-score is installed." - " Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`" - ) + if use_stemmer or "rougeLsum" in rouge_keys: + if not _NLTK_AVAILABLE: + raise ValueError("Stemmer and/or `rougeLsum` requires that nltk is installed. Use `pip install nltk`.") + import nltk if not isinstance(rouge_keys, tuple): rouge_keys = tuple([rouge_keys]) @@ -111,11 +113,9 @@ raise ValueError(f"Got unknown rouge key {key}.
Expected to be one of {ALLOWED_ROUGE_KEYS}") self.rouge_keys = rouge_keys - self.newline_sep = newline_sep - self.use_stemmer = use_stemmer - self.aggregator = BootstrapAggregator() - self.scorer = RougeScorer(rouge_keys, use_stemmer=self.use_stemmer) - self.decimal_places = decimal_places + self.rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys] + self.stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None + self.sentence_results: Optional[Dict[Union[int, str], List[Dict[str, float]]]] = None def update(self, preds: Union[str, List[str]], targets: Union[str, List[str]]) -> None: # type: ignore """Compute rouge scores. @@ -131,8 +131,8 @@ def update(self, preds: Union[str, List[str]], targets: Union[str, List[str]]) - if isinstance(targets, str): targets = [targets] - _rouge_score_update( - preds, targets, scorer=self.scorer, aggregator=self.aggregator, newline_sep=self.newline_sep + self.sentence_results = _rouge_score_update( + preds, targets, self.rouge_keys_values, self.sentence_results, self.stemmer ) def compute(self) -> Dict[str, Tensor]: @@ -141,7 +141,7 @@ def compute(self) -> Dict[str, Tensor]: Return: Python dictionary of rouge scores for each input rouge key. """ - return _rouge_score_compute(aggregator=self.aggregator, decimal_places=self.decimal_places) + return _rouge_score_compute(self.sentence_results) def __hash__(self) -> int: # override to hash list objects.
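A minimal doctest-style sketch of the pure-Python scoring path introduced above, kept in the same style as the docstring examples in the diff. It is not part of the patch itself: the sentence pair, the choice of rouge_keys, and the final allclose expectation are illustrative assumptions, and the last line is marked SKIP because it is an expectation rather than a verified output.

>>> import torch
>>> from torchmetrics.functional.text.rouge import _lcs, _tokenize, rouge_score
>>> from torchmetrics.text.rouge import ROUGEScore
>>> _tokenize("the quick brown fox", n_gram=2)  # N-gram tokens counted by _rouge_n_score
['the quick', 'quick brown', 'brown fox']
>>> _lcs(["the", "quick", "fox"], ["the", "lazy", "fox"])  # LCS length used by _rouge_l_score
2
>>> preds, targets = ["My name is John"], ["Is your name John"]
>>> scores = rouge_score(preds, targets, use_stemmer=False, rouge_keys=("rouge1", "rougeL"))
>>> sorted(scores.keys())
['rouge1_fmeasure', 'rouge1_precision', 'rouge1_recall', 'rougeL_fmeasure', 'rougeL_precision', 'rougeL_recall']
>>> rouge = ROUGEScore(use_stemmer=False, rouge_keys=("rouge1", "rougeL"))
>>> rouge.update(preds, targets)
>>> torch.allclose(rouge.compute()["rouge1_fmeasure"], scores["rouge1_fmeasure"])  # doctest: +SKIP
True

With use_stemmer=False and no "rougeLsum" key, neither path needs nltk: per the diff, nltk is only imported for the Porter stemmer and for the sentence splitting in _add_newline_to_end_of_each_sentence. Agreement with the rouge-score reference is asserted in the updated tests via torch.allclose rather than by rounding to a fixed number of digits, in line with the deprecation of the decimal_places and newline_sep arguments above.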