Untokenized BLEU score to stay consistent with all the other text metrics (#640)

* update BLEU to be consistent with the other text metrics
* fix examples
* fix the additional parenthesis
* update black formatting
* add tokenizer func
* update name for readability

Co-authored-by: Alex Leu <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Daniel Stancl <[email protected]>
5 people authored Dec 6, 2021
1 parent eccfe83 commit 8d5b8ba
Showing 6 changed files with 87 additions and 54 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))


- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))


### Deprecated


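In practice, the CHANGELOG entry above means callers no longer pre-tokenize their corpora. A minimal before/after sketch of the functional API, based on the docstring examples updated in this commit (targets-first argument order, as in this version of the library):

```python
from torchmetrics.functional import bleu_score

# Before this commit: inputs had to be pre-tokenized into lists of words.
# translate_corpus = ['the cat is on the mat'.split()]
# reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

# After this commit: plain strings; whitespace tokenization happens internally.
translate_corpus = ['the cat is on the mat']
reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]

print(bleu_score(reference_corpus, translate_corpus))  # tensor(0.7598)
```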
54 changes: 31 additions & 23 deletions tests/text/test_bleu.py
@@ -25,36 +25,45 @@
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
# EXAMPLE 1
HYPOTHESIS_A = tuple(
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE_1A = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE_2A = tuple(
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE_3A = tuple("It is the practical guide for the army always to heed the directions of the party".split())
HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"
REFERENCE_1A = "It is a guide to action that ensures that the military will forever heed Party commands"
REFERENCE_2A = "It is a guiding principle which makes the military forces always being under the command of the Party"
REFERENCE_3A = "It is the practical guide for the army always to heed the directions of the party"

# EXAMPLE 2
HYPOTHESIS_B = tuple("he read the book because he was interested in world history".split())
REFERENCE_1B = tuple("he was interested in world history because he read the book".split())
HYPOTHESIS_B = "he read the book because he was interested in world history"
REFERENCE_1B = "he was interested in world history because he read the book"

# EXAMPLE 3
HYPOTHESIS_C = tuple("the cat the cat on the mat".split())
REFERENCE_1C = tuple("the cat is on the mat".split())
REFERENCE_2C = tuple("there is a cat on the mat".split())
HYPOTHESIS_C = "the cat the cat on the mat"
REFERENCE_1C = "the cat is on the mat"
REFERENCE_2C = "there is a cat on the mat"

TUPLE_OF_REFERENCES = (
((REFERENCE_1A, REFERENCE_2A, REFERENCE_3A), tuple([REFERENCE_1B])),
(tuple([REFERENCE_1B]), (REFERENCE_1C, REFERENCE_2C)),
(REFERENCE_1B),
)
TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C))
TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C), (HYPOTHESIS_B))

BATCHES = {"preds": TUPLE_OF_HYPOTHESES, "targets": TUPLE_OF_REFERENCES}

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2


def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing_function, **kwargs):
hypotheses_ = [hypothesis.split() for hypothesis in hypotheses]
list_of_references_ = [[line.split() for line in ref] for ref in list_of_references]
return corpus_bleu(
list_of_references=list_of_references_,
hypotheses=hypotheses_,
weights=weights,
smoothing_function=smoothing_function,
**kwargs
)


@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
@@ -73,29 +82,28 @@ class TestBLEUScore(TextTester):
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, n_gram, smooth_func, smooth):
metric_args = {"n_gram": n_gram, "smooth": smooth}

nltk_metric = partial(corpus_bleu, weights=weights, smoothing_function=smooth_func)
compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func)

self.run_class_metric_test(
ddp=ddp,
preds=preds,
targets=targets,
metric_class=BLEUScore,
sk_metric=nltk_metric,
sk_metric=compute_bleu_metric_nltk,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
)

def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth):
metric_args = {"n_gram": n_gram, "smooth": smooth}
nltk_metric = partial(corpus_bleu, weights=weights, smoothing_function=smooth_func)
compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func)

self.run_functional_metric_test(
preds,
targets,
metric_functional=bleu_score,
sk_metric=nltk_metric,
sk_metric=compute_bleu_metric_nltk,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
)
@@ -120,8 +128,8 @@ def test_bleu_empty_functional():


def test_no_4_gram_functional():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu_score(refs, hyps) == tensor(0.0)


@@ -134,6 +142,6 @@ def test_bleu_empty_class():

def test_no_4_gram_class():
bleu = BLEUScore()
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu(refs, hyps) == tensor(0.0)
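The new `_compute_bleu_metric_nltk` helper exists because NLTK's `corpus_bleu` still expects tokenized input while the metric under test now takes raw strings. A standalone sketch of the same equivalence check (variable names are illustrative; `corpus_bleu` and `bleu_score` are the real APIs being compared):

```python
from nltk.translate.bleu_score import corpus_bleu
from torchmetrics.functional import bleu_score

hypotheses = ["the cat is on the mat"]
references = [["there is a cat on the mat", "a cat is on the mat"]]

# NLTK side: split on whitespace first, exactly as the helper above does.
nltk_value = corpus_bleu(
    list_of_references=[[ref.split() for ref in refs] for refs in references],
    hypotheses=[hyp.split() for hyp in hypotheses],
    weights=(0.25, 0.25, 0.25, 0.25),
)

# torchmetrics side: raw strings, tokenized internally by _tokenize_fn.
tm_value = bleu_score(references, hypotheses, n_gram=4)

# The two should agree (about 0.7598 for this docstring example).
print(nltk_value, tm_value)
```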
45 changes: 34 additions & 11 deletions torchmetrics/functional/text/bleu.py
@@ -17,7 +17,7 @@
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Sequence, Tuple
from typing import Callable, Sequence, Tuple, Union

import torch
from torch import Tensor, tensor
@@ -44,14 +44,27 @@ def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
return ngram_counter


def _tokenize_fn(sentence: str) -> Sequence[str]:
"""Tokenizes sentence into list of words.
Args:
sentence: A sentence separated by white space.
Return:
List of words
"""
return sentence.split()


def _bleu_score_update(
reference_corpus: Sequence[Sequence[Sequence[str]]],
translate_corpus: Sequence[Sequence[str]],
reference_corpus: Sequence[Sequence[str]],
translate_corpus: Sequence[str],
numerator: Tensor,
denominator: Tensor,
trans_len: Tensor,
ref_len: Tensor,
n_gram: int = 4,
tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn,
) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute the BLEU score.
@@ -63,9 +76,14 @@ def _bleu_score_update(
trans_len: count of words in a candidate translation
ref_len: count of words in a reference translation
n_gram: gram value ranged 1 to 4
tokenizer: A function that turns sentence into list of words
"""
reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
[tokenizer(line) if line else [] for line in reference] for reference in reference_corpus
]
translate_corpus_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in translate_corpus]

for (translation, references) in zip(translate_corpus, reference_corpus):
for (translation, references) in zip(translate_corpus_, reference_corpus_):
trans_len += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
@@ -122,8 +140,8 @@ def _bleu_score_compute(


def bleu_score(
reference_corpus: Sequence[Sequence[Sequence[str]]],
translate_corpus: Sequence[Sequence[str]],
reference_corpus: Sequence[Union[str, Sequence[str]]],
translate_corpus: Union[str, Sequence[str]],
n_gram: int = 4,
smooth: bool = False,
) -> Tensor:
@@ -144,8 +162,8 @@
Example:
>>> from torchmetrics.functional import bleu_score
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> translate_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(reference_corpus, translate_corpus)
tensor(0.7598)
@@ -156,16 +174,21 @@
[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
"""
translate_corpus_ = [translate_corpus] if isinstance(translate_corpus, str) else translate_corpus
reference_corpus_ = [
[reference_text] if isinstance(reference_text, str) else reference_text for reference_text in reference_corpus
]

if len(translate_corpus_) != len(reference_corpus_):
raise ValueError(f"Corpus has different size {len(translate_corpus_)} != {len(reference_corpus_)}")

if len(translate_corpus) != len(reference_corpus):
raise ValueError(f"Corpus has different size {len(translate_corpus)} != {len(reference_corpus)}")
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
trans_len = tensor(0, dtype=torch.float)
ref_len = tensor(0, dtype=torch.float)

trans_len, ref_len = _bleu_score_update(
reference_corpus, translate_corpus, numerator, denominator, trans_len, ref_len, n_gram
reference_corpus_, translate_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn
)

return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth)
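The coercion block added at the top of `bleu_score` also admits simpler input shapes: a bare-string hypothesis corpus is wrapped into a one-element list, and each bare-string reference entry into a one-element reference list. A minimal sketch of the two accepted forms (behavior read off the `isinstance` checks above):

```python
from torchmetrics.functional import bleu_score

# Fully nested form: one list of alternative references per hypothesis.
score = bleu_score(
    [["there is a cat on the mat", "a cat is on the mat"]],
    ["the cat is on the mat"],
)

# Coerced form: a bare-string hypothesis and a bare-string reference entry
# are wrapped into one-element lists before _bleu_score_update runs.
score_single = bleu_score(["there is a cat on the mat"], "the cat is on the mat")
```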
19 changes: 10 additions & 9 deletions torchmetrics/functional/text/sacre_bleu.py
@@ -39,6 +39,7 @@


import re
from functools import partial
from typing import Sequence

import torch
@@ -335,21 +336,21 @@ def sacre_bleu_score(
"torchmetrics[text]`."
)

reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
[_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase) for line in reference]
for reference in reference_corpus
]
translate_corpus_: Sequence[Sequence[str]] = [
_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase) for line in translate_corpus
]

numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
trans_len = tensor(0, dtype=torch.float)
ref_len = tensor(0, dtype=torch.float)

tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase)
trans_len, ref_len = _bleu_score_update(
reference_corpus_, translate_corpus_, numerator, denominator, trans_len, ref_len, n_gram
reference_corpus,
translate_corpus,
numerator,
denominator,
trans_len,
ref_len,
n_gram,
tokenize_fn,
)

return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth)
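Rather than tokenizing both corpora up front, the Sacre-BLEU path now binds its tokenizer options with `functools.partial` and lets `_bleu_score_update` tokenize line by line. A sketch of the same pattern with a stand-in tokenizer (`whitespace_tokenize` is illustrative, not part of the library):

```python
from functools import partial
from typing import List

def whitespace_tokenize(line: str, lowercase: bool = False) -> List[str]:
    """Stand-in for _SacreBLEUTokenizer.tokenize: options bound once, applied per line."""
    return (line.lower() if lowercase else line).split()

# Bind the configuration once, as sacre_bleu_score does with `tokenize` and `lowercase` ...
tokenize_fn = partial(whitespace_tokenize, lowercase=True)

# ... so downstream code can call it with a single positional argument, like _tokenize_fn.
print(tokenize_fn("The Cat is on the Mat"))  # ['the', 'cat', 'is', 'on', 'the', 'mat']
```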
10 changes: 6 additions & 4 deletions torchmetrics/text/bleu.py
@@ -22,7 +22,7 @@
from torch import Tensor, tensor

from torchmetrics import Metric
from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn


class BLEUScore(Metric):
@@ -45,8 +45,8 @@ class BLEUScore(Metric):
will be used to perform the allgather.
Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> translate_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric = BLEUScore()
>>> metric(reference_corpus, translate_corpus)
tensor(0.7598)
@@ -91,14 +91,15 @@ def __init__(
self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")

def update( # type: ignore
self, reference_corpus: Sequence[Sequence[Sequence[str]]], translate_corpus: Sequence[Sequence[str]]
self, reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str]
) -> None:
"""Compute Precision Scores.
Args:
reference_corpus: An iterable of iterables of reference corpus
translate_corpus: An iterable of machine translated corpus
"""

self.trans_len, self.ref_len = _bleu_score_update(
reference_corpus,
translate_corpus,
@@ -107,6 +108,7 @@ def update( # type: ignore
self.trans_len,
self.ref_len,
self.n_gram,
_tokenize_fn,
)

def compute(self) -> Tensor:
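With the module interface, the same untokenized strings flow through `update`, so state accumulates across batches before `compute`. A short usage sketch based on the docstring example above (assuming the top-level `BLEUScore` export):

```python
from torchmetrics import BLEUScore

metric = BLEUScore(n_gram=4)

# Raw strings go straight to update(); _tokenize_fn splits them internally.
metric.update(
    [["there is a cat on the mat", "a cat is on the mat"]],  # reference_corpus
    ["the cat is on the mat"],                               # translate_corpus
)
print(metric.compute())  # tensor(0.7598)
```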
10 changes: 3 additions & 7 deletions torchmetrics/text/sacre_bleu.py
@@ -118,17 +118,13 @@ def update( # type: ignore
reference_corpus: An iterable of iterables of reference corpus
translate_corpus: An iterable of machine translated corpus
"""
reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
[self.tokenizer(line) for line in reference] for reference in reference_corpus
]
translate_corpus_: Sequence[Sequence[str]] = [self.tokenizer(line) for line in translate_corpus]

self.trans_len, self.ref_len = _bleu_score_update(
reference_corpus_,
translate_corpus_,
reference_corpus,
translate_corpus,
self.numerator,
self.denominator,
self.trans_len,
self.ref_len,
self.n_gram,
self.tokenizer,
)
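`SacreBLEUScore.update` simplifies the same way: it now forwards the raw corpora together with `self.tokenizer` instead of pre-tokenizing them. A hedged usage sketch (the constructor arguments shown, such as `tokenize="13a"`, are assumed defaults for this version and require the `sacrebleu` extra to be installed):

```python
from torchmetrics.text import SacreBLEUScore

# The configured tokenizer is handed to _bleu_score_update unchanged.
sacre_bleu = SacreBLEUScore(n_gram=4, tokenize="13a", lowercase=True)

sacre_bleu.update(
    [["There is a cat on the mat", "A cat is on the mat"]],  # reference_corpus
    ["The cat is on the mat"],                               # translate_corpus
)
print(sacre_bleu.compute())
```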
