Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add normalizer, tokenizer to rouge #838

Merged
merged 19 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions torchmetrics/functional/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
return LCS[-1][-1]


def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> Sequence[str]:
def _normalize_and_tokenize_text(
text: str,
stemmer: Optional[Any] = None,
normalizer: Optional[Any] = None,
tokenizer: Optional[Any] = None,
) -> Sequence[str]:
"""Rouge score should be calculated only over lowercased words and digits. Optionally, Porter stemmer can be
used to strip word suffixes to improve matching. The text normalization follows the implemantion from `Rouge
score_Text Normalizition`_
Expand All @@ -101,17 +106,31 @@ def _normalize_and_tokenize_text(text: str, stemmer: Optional[Any] = None) -> Se
An input sentence.
stemmer:
Porter stemmer instance to strip word suffixes to improve matching.
normalizer:
A user's own normalizer instance.
If this is none, `replacing any non-alpha-numeric characters with spaces` is default.
This instance must have method named ``normalize``. This method must take a string and return a string.
tokenizer:
A user's own tokenizer instance. If this is none, `spliting by spaces` is default
This instance must have method named ``tokenize``. This method must take a string and return `List[str]`
hookSSi marked this conversation as resolved.
Show resolved Hide resolved
"""
# Replace any non-alpha-numeric characters with spaces.
text = re.sub(r"[^a-z0-9]+", " ", text.lower())
if normalizer:
text = normalizer.normalize(text)
else:
# Replace any non-alpha-numeric characters with spaces.
text = re.sub(r"[^a-z0-9]+", " ", text.lower())

if tokenizer:
tokens = tokenizer.tokenize(text)
else:
tokens = re.split(r"\s+", text)
Borda marked this conversation as resolved.
Show resolved Hide resolved
hookSSi marked this conversation as resolved.
Show resolved Hide resolved

tokens = re.split(r"\s+", text)
if stemmer:
# Only stem words more than 3 characters long.
tokens = [stemmer.stem(x) if len(x) > 3 else x for x in tokens]

# One final check to drop any empty or invalid tokens.
tokens = [x for x in tokens if (isinstance(x, str) and re.match(r"^[a-z0-9]+$", x))]
tokens = [x for x in tokens if (isinstance(x, str) and len(x) > 0)]

return tokens

Expand Down Expand Up @@ -167,6 +186,8 @@ def _rouge_score_update(
rouge_keys_values: List[Union[int, str]],
accumulate: str,
stemmer: Optional[Any] = None,
normalizer: Optional[Any] = None,
tokenizer: Optional[Any] = None,
) -> Dict[Union[int, str], List[Dict[str, Tensor]]]:
"""Update the rouge score with the current set of predicted and target sentences.

Expand All @@ -184,6 +205,13 @@ def _rouge_score_update(
Allowed values are ``avg`` and ``best``.
stemmer:
Porter stemmer instance to strip word suffixes to improve matching.
normalizer:
A user's own normalizer instance.
If this is none, `replacing any non-alpha-numeric characters with spaces` is default.
This instance must have method named ``normalize``. This method must take a string and return a string.
tokenizer:
A user's own tokenizer instance. If this is none, `spliting by spaces` is default
This instance must have method named ``tokenize``. This method must take a string and return `List[str]`

Example:
>>> preds = "My name is John".split()
Expand Down Expand Up @@ -214,16 +242,18 @@ def _rouge_score_update(
result_inner: Dict[Union[int, str], Dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values}
result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}
list_results = []
pred = _normalize_and_tokenize_text(pred_raw, stemmer)
pred_Lsum = _normalize_and_tokenize_text(_add_newline_to_end_of_each_sentence(pred_raw), stemmer)
pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer)
pred_Lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(pred_raw), stemmer, normalizer, tokenizer
)

for target_raw_inner in target_raw:
tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer)
tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer, normalizer, tokenizer)

if "Lsum" in rouge_keys_values:
# rougeLsum expects "\n" separated sentences within a summary
target_Lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer, normalizer, tokenizer
)

for rouge_key in rouge_keys_values:
Expand Down Expand Up @@ -291,6 +321,8 @@ def rouge_score(
target: Union[str, Sequence[str], Sequence[Sequence[str]]],
accumulate: Literal["avg", "best"] = "best",
use_stemmer: bool = False,
normalizer: Optional[Any] = None,
tokenizer: Optional[Any] = None,
hookSSi marked this conversation as resolved.
Show resolved Hide resolved
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore
) -> Dict[str, Tensor]:
"""Calculate `Calculate Rouge Score`_ , used for automatic summarization.
Expand All @@ -306,6 +338,13 @@ def rouge_score(
- ``best`` takes the best fmeasure score obtained between prediction and multiple corresponding references.
use_stemmer:
Use Porter stemmer to strip word suffixes to improve matching.
normalizer:
A user's own normalizer instance.
If this is none, `replacing any non-alpha-numeric characters with spaces` is default.
This instance must have method named ``normalize``. This method must take a string and return a string.
tokenizer:
A user's own tokenizer instance. If this is none, `spliting by spaces` is default
This instance must have method named ``tokenize``. This method must take a string and return `List[str]`
hookSSi marked this conversation as resolved.
Show resolved Hide resolved
rouge_keys:
A list of rouge types to calculate.
Keys that are allowed are ``rougeL``, ``rougeLsum``, and ``rouge1`` through ``rouge9``.
Expand Down Expand Up @@ -367,7 +406,13 @@ def rouge_score(
target = [[target]]

sentence_results: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update(
preds, target, rouge_keys_values, stemmer=stemmer, accumulate=accumulate
preds,
target,
rouge_keys_values,
stemmer=stemmer,
normalizer=normalizer,
tokenizer=tokenizer,
accumulate=accumulate,
)

output: Dict[str, List[Tensor]] = {}
Expand Down
19 changes: 18 additions & 1 deletion torchmetrics/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ class ROUGEScore(Metric):
Args:
use_stemmer:
Use Porter stemmer to strip word suffixes to improve matching.
normalizer:
A user's own normalizer instance.
If this is none, `replacing any non-alpha-numeric characters with spaces` is default.
This instance must have method named ``normalize``. This method must take a string and return a string.
tokenizer:
A user's own tokenizer instance. If this is none, `spliting by spaces` is default
This instance must have method named ``tokenize``. This method must take a string and return `List[str]`
accumulate:
Useful incase of multi-reference rouge score.
- ``avg`` takes the avg of all references with respect to predictions
Expand Down Expand Up @@ -89,6 +96,8 @@ class ROUGEScore(Metric):
def __init__(
self,
use_stemmer: bool = False,
normalizer: Optional[Any] = None,
tokenizer: Optional[Any] = None,
accumulate: Literal["avg", "best"] = "best",
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore
compute_on_step: bool = True,
Expand Down Expand Up @@ -123,6 +132,8 @@ def __init__(
self.rouge_keys = rouge_keys
self.rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys]
self.stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None
self.normalizer = normalizer
self.tokenizer = tokenizer
self.accumulate = accumulate

# Adding stated dynamically to prevent IndexError during sync function as some lists can be empty.
Expand Down Expand Up @@ -152,7 +163,13 @@ def update( # type: ignore
target = [[target]]

output: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update(
preds, target, self.rouge_keys_values, stemmer=self.stemmer, accumulate=self.accumulate
preds,
target,
self.rouge_keys_values,
stemmer=self.stemmer,
normalizer=self.normalizer,
tokenizer=self.tokenizer,
accumulate=self.accumulate,
)
for rouge_key, metrics in output.items():
for metric in metrics:
Expand Down