add normalizer, tokenizer to rouge #838

Merged 19 commits on Feb 23, 2022
Changes from 15 commits
48 changes: 47 additions & 1 deletion tests/text/test_rouge.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from functools import partial
from typing import Sequence
from typing import Callable, Sequence

import numpy as np
import pytest
import torch

@@ -164,3 +166,47 @@ def test_rouge_metric_wrong_key_value_error():
rouge_keys=key,
accumulate="best",
)


@pytest.mark.parametrize(
"pl_rouge_metric_key",
[
("rouge1_precision"),
("rouge1_recall"),
("rouge1_fmeasure"),
("rouge2_precision"),
("rouge2_recall"),
("rouge2_fmeasure"),
("rougeL_precision"),
("rougeL_recall"),
("rougeL_fmeasure"),
("rougeLsum_precision"),
("rougeLsum_recall"),
("rougeLsum_fmeasure"),
],
)
def test_rouge_metric_normalizer_tokenizer(pl_rouge_metric_key):
normalizer: Callable[[str], str] = lambda text: re.sub(r"[^a-z0-9]+", " ", text.lower())
tokenizer: Callable[[str], Sequence[str]] = lambda text: re.split(r"\s+", text)

rouge_level, metric = pl_rouge_metric_key.split("_")
original_score = _compute_rouge_score(
preds=_inputs_single_sentence_single_reference.preds,
target=_inputs_single_sentence_single_reference.targets,
rouge_level=rouge_level,
metric=metric,
accumulate="best",
use_stemmer=False,
)

scorer = ROUGEScore(
normalizer=normalizer, tokenizer=tokenizer, rouge_keys=rouge_level, accumulate="best", use_stemmer=False
)
scorer.update(
_inputs_single_sentence_single_reference.preds,
_inputs_single_sentence_single_reference.targets,
)
metrics_score = scorer.compute()

threshold = 1e-08
assert np.isclose(metrics_score[rouge_level + "_" + metric], original_score, atol=threshold, equal_nan=True)
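
For reference, a quick sanity check of what these helper callables produce (illustrative, not part of the diff):

normalizer("Hello, World!")  # -> "hello world " (lowercased, non-alphanumeric runs replaced by a space)
tokenizer("hello world")  # -> ["hello", "world"]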
88 changes: 88 additions & 0 deletions tm_examples/rouge_score-own_normalizer_and_tokenizer.py
@@ -0,0 +1,88 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of how to use ROUGEScore with a user's defined/own normalizer and tokenizer.

To run: python rouge_score-own_normalizer_and_tokenizer.py
"""

import re
from pprint import pprint
from typing import Sequence

from torchmetrics.text.rouge import ROUGEScore


class UserNormalizer:
"""The `UserNormalizer` class is required to normalize a non-alphabet language text input.

The user's defined normalizer is expected to return string that are fed into the tokenizer.
"""

def __init__(self) -> None:
self.pattern = r"[^a-z0-9]+"

def __call__(self, text: str) -> str:
"""The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
should obey the input/output arguments structure described below.

Args:
text:
Input text. `str`

Return:
Normalized Python string.
"""
output_text = re.sub(self.pattern, " ", text.lower())

return output_text


class UserTokenizer:
"""The `UserNormalizer` class is required to tokenize a non-alphabet language text input.

The user's defined tokenizer is expected to return `Sequence[str]` that are fed into the rouge score.
"""

def __init__(self) -> None:
self.pattern = r"\s+"

def __call__(self, text: str) -> Sequence[str]:
"""The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
should obey the input/output arguments structure described below.

Args:
text:
Input text. `str`

Return:
Tokenized sentence as a `Sequence[str]`.
"""
output_tokens = re.split(self.pattern, text)

return output_tokens


_PREDS = ["hello", "hello world", "world world world"]
_REFS = ["hello", "hello hello", "hello world hello"]


if __name__ == "__main__":
normalizer = UserNormalizer()
tokenizer = UserTokenizer()

rouge_score = ROUGEScore(normalizer=normalizer, tokenizer=tokenizer)

rouge_score.update(_PREDS, _REFS)

pprint(rouge_score.compute())
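
For reference, with the inputs above the computed result is a dict with one tensor per rouge key and statistic, along the lines of (values elided, since they depend on the inputs):

# {'rouge1_fmeasure': tensor(...), 'rouge1_precision': tensor(...), 'rouge1_recall': tensor(...),
#  ..., 'rougeLsum_precision': tensor(...), 'rougeLsum_recall': tensor(...)}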
63 changes: 52 additions & 11 deletions torchmetrics/functional/text/rouge.py
@@ -13,7 +13,7 @@
# limitations under the License.
import re
from collections import Counter
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, tensor
@@ -91,7 +91,12 @@ def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
return LCS[-1][-1]


def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> Sequence[str]:
def _normalize_and_tokenize_text(
text: str,
stemmer: Optional[Any] = None,
normalizer: Optional[Callable[[str], str]] = None,
tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
) -> Sequence[str]:
"""Rouge score should be calculated only over lowercased words and digits. Optionally, Porter stemmer can be
used to strip word suffixes to improve matching. The text normalization follows the implemantion from `Rouge
score_Text Normalizition`_
@@ -101,17 +106,27 @@ def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> Sequence[str]:
An input sentence.
stemmer:
Porter stemmer instance to strip word suffixes to improve matching.
normalizer:
A user's own normalizer function.
If this is `None`, replacing any non-alphanumeric characters with spaces is the default.
This function must take a `str` and return a `str`.
tokenizer:
A user's own tokenizer function. If this is `None`, splitting by spaces is the default.
This function must take a `str` and return a `Sequence[str]`.
"""
# Replace any non-alpha-numeric characters with spaces.
text = re.sub(r"[^a-z0-9]+", " ", text.lower())

tokens = re.split(r"\s+", text)
# If normalizer is None, replace any non-alphanumeric characters with spaces.
text = normalizer(text) if callable(normalizer) else re.sub(r"[^a-z0-9]+", " ", text.lower())

# If tokenizer is None, split by spaces.
tokens = tokenizer(text) if callable(tokenizer) else re.split(r"\s+", text)

if stemmer:
# Only stem words more than 3 characters long.
tokens = [stemmer.stem(x) if len(x) > 3 else x for x in tokens]

# One final check to drop any empty or invalid tokens.
tokens = [x for x in tokens if (isinstance(x, str) and re.match(r"^[a-z0-9]+$", x))]
tokens = [x for x in tokens if (isinstance(x, str) and len(x) > 0)]

return tokens
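# Illustration (not part of the diff): with no stemmer, normalizer, or tokenizer supplied,
# the default path above gives, for example:
#   _normalize_and_tokenize_text("Hello, World!") -> ["hello", "world"]
# ("Hello, World!" normalizes to "hello world "; the trailing empty token from re.split
# is dropped by the final check.)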

@@ -167,6 +182,8 @@ def _rouge_score_update(
rouge_keys_values: List[Union[int, str]],
accumulate: str,
stemmer: Optional[Any] = None,
normalizer: Optional[Callable[[str], str]] = None,
tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
) -> Dict[Union[int, str], List[Dict[str, Tensor]]]:
"""Update the rouge score with the current set of predicted and target sentences.

@@ -184,6 +201,13 @@
Allowed values are ``avg`` and ``best``.
stemmer:
Porter stemmer instance to strip word suffixes to improve matching.
normalizer:
A user's own normalizer function.
If this is `None`, replacing any non-alphanumeric characters with spaces is the default.
This function must take a `str` and return a `str`.
tokenizer:
A user's own tokenizer function. If this is `None`, splitting by spaces is the default.
This function must take a `str` and return a `Sequence[str]`.

Example:
>>> preds = "My name is John".split()
@@ -214,16 +238,18 @@
result_inner: Dict[Union[int, str], Dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values}
result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}
list_results = []
pred = _normalize_and_tokenize_text(pred_raw, stemmer)
pred_Lsum = _normalize_and_tokenize_text(_add_newline_to_end_of_each_sentence(pred_raw), stemmer)
pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer)
pred_Lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(pred_raw), stemmer, normalizer, tokenizer
)

for target_raw_inner in target_raw:
tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer)
tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer, normalizer, tokenizer)

if "Lsum" in rouge_keys_values:
# rougeLsum expects "\n" separated sentences within a summary
target_Lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer, normalizer, tokenizer
)

for rouge_key in rouge_keys_values:
@@ -291,6 +317,8 @@ def rouge_score(
target: Union[str, Sequence[str], Sequence[Sequence[str]]],
accumulate: Literal["avg", "best"] = "best",
use_stemmer: bool = False,
normalizer: Optional[Callable[[str], str]] = None,
tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore
) -> Dict[str, Tensor]:
"""Calculate `Calculate Rouge Score`_ , used for automatic summarization.
@@ -306,6 +334,13 @@
- ``best`` takes the best fmeasure score obtained between prediction and multiple corresponding references.
use_stemmer:
Use Porter stemmer to strip word suffixes to improve matching.
normalizer:
A user's own normalizer function.
If this is `None`, replacing any non-alphanumeric characters with spaces is the default.
This function must take a `str` and return a `str`.
tokenizer:
A user's own tokenizer function. If this is `None`, splitting by spaces is the default.
This function must take a `str` and return a `Sequence[str]`.
rouge_keys:
A list of rouge types to calculate.
Keys that are allowed are ``rougeL``, ``rougeLsum``, and ``rouge1`` through ``rouge9``.
@@ -367,7 +402,13 @@ def rouge_score(
target = [[target]]

sentence_results: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update(
preds, target, rouge_keys_values, stemmer=stemmer, accumulate=accumulate
preds,
target,
rouge_keys_values,
stemmer=stemmer,
normalizer=normalizer,
tokenizer=tokenizer,
accumulate=accumulate,
)

output: Dict[str, List[Tensor]] = {}
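A minimal sketch of the extended functional API (illustrative only; the argument names follow the signature above, and the input strings are made up):

from torchmetrics.functional.text.rouge import rouge_score

score = rouge_score(
    "hello world",
    "hello world hello",
    rouge_keys="rouge1",
    normalizer=lambda text: text.lower(),
    tokenizer=lambda text: text.split(),
)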
21 changes: 19 additions & 2 deletions torchmetrics/text/rouge.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from torch import Tensor
from typing_extensions import Literal
@@ -35,6 +35,13 @@ class ROUGEScore(Metric):
Args:
use_stemmer:
Use Porter stemmer to strip word suffixes to improve matching.
normalizer:
A user's own normalizer function.
If this is `None`, replacing any non-alphanumeric characters with spaces is the default.
This function must take a `str` and return a `str`.
tokenizer:
A user's own tokenizer function. If this is `None`, splitting by spaces is the default.
This function must take a `str` and return a `Sequence[str]`.
accumulate:
Useful in case of multi-reference rouge score.
- ``avg`` takes the avg of all references with respect to predictions
@@ -87,6 +94,8 @@ class ROUGEScore(Metric):
def __init__(
self,
use_stemmer: bool = False,
normalizer: Optional[Callable[[str], str]] = None,
tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
accumulate: Literal["avg", "best"] = "best",
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore
compute_on_step: Optional[bool] = None,
@@ -114,6 +123,8 @@ def __init__(
self.rouge_keys = rouge_keys
self.rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys]
self.stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None
self.normalizer = normalizer
self.tokenizer = tokenizer
self.accumulate = accumulate

# Adding stated dynamically to prevent IndexError during sync function as some lists can be empty.
@@ -143,7 +154,13 @@ def update( # type: ignore
target = [[target]]

output: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update(
preds, target, self.rouge_keys_values, stemmer=self.stemmer, accumulate=self.accumulate
preds,
target,
self.rouge_keys_values,
stemmer=self.stemmer,
normalizer=self.normalizer,
tokenizer=self.tokenizer,
accumulate=self.accumulate,
)
for rouge_key, metrics in output.items():
for metric in metrics:
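For completeness, a minimal sketch of the module API with the new arguments (illustrative; see tm_examples/rouge_score-own_normalizer_and_tokenizer.py above for a fuller, class-based example):

from torchmetrics.text.rouge import ROUGEScore

rouge = ROUGEScore(
    normalizer=lambda text: text.lower(),
    tokenizer=lambda text: text.split(),
    accumulate="avg",
)
rouge.update(["hello world"], [["hello world hello", "hello world"]])
print(rouge.compute())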