Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Untokenized Bleu score to stay consistent with all the other text metrics #640

Merged
merged 21 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions tests/text/test_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import pytest
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from torch import tensor
Expand All @@ -25,23 +23,19 @@
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
# EXAMPLE 1
HYPOTHESIS_A = tuple(
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE_1A = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE_2A = tuple(
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE_3A = tuple("It is the practical guide for the army always to heed the directions of the party".split())
HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"
REFERENCE_1A = "It is a guide to action that ensures that the military will forever heed Party commands"
REFERENCE_2A = "It is a guiding principle which makes the military forces always being under the command of the Party"
REFERENCE_3A = "It is the practical guide for the army always to heed the directions of the party"

# EXAMPLE 2
HYPOTHESIS_B = tuple("he read the book because he was interested in world history".split())
REFERENCE_1B = tuple("he was interested in world history because he read the book".split())
HYPOTHESIS_B = "he read the book because he was interested in world history"
REFERENCE_1B = "he was interested in world history because he read the book"

# EXAMPLE 3
HYPOTHESIS_C = tuple("the cat the cat on the mat".split())
REFERENCE_1C = tuple("the cat is on the mat".split())
REFERENCE_2C = tuple("there is a cat on the mat".split())
HYPOTHESIS_C = "the cat the cat on the mat"
REFERENCE_1C = "the cat is on the mat"
REFERENCE_2C = "there is a cat on the mat"

TUPLE_OF_REFERENCES = (
((REFERENCE_1A, REFERENCE_2A, REFERENCE_3A), tuple([REFERENCE_1B])),
Expand Down Expand Up @@ -76,7 +70,16 @@ class TestBLEUScore(TextTester):
def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, n_gram, smooth_func, smooth):
metric_args = {"n_gram": n_gram, "smooth": smooth}

nltk_metric = partial(corpus_bleu, weights=weights, smoothing_function=smooth_func)
def nltk_metric(list_of_references, hypotheses, weights=weights, smoothing_function=smooth_func, **kwargs):
    """Reference BLEU via NLTK: tokenize raw sentences on whitespace, then delegate to ``corpus_bleu``."""
    # NLTK expects pre-tokenized input (lists of words), while the metric under
    # test now takes raw strings — split here to bridge the two interfaces.
    tokenized_hypotheses = [hypothesis.split() for hypothesis in hypotheses]
    tokenized_references = [[reference.split() for reference in references] for references in list_of_references]
    return corpus_bleu(
        list_of_references=tokenized_references,
        hypotheses=tokenized_hypotheses,
        weights=weights,
        smoothing_function=smoothing_function,
        **kwargs
    )

self.run_class_metric_test(
ddp=ddp,
Expand All @@ -91,7 +94,17 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights,

def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth):
metric_args = {"n_gram": n_gram, "smooth": smooth}
nltk_metric = partial(corpus_bleu, weights=weights, smoothing_function=smooth_func)

def nltk_metric(list_of_references, hypotheses, weights=weights, smoothing_function=smooth_func, **kwargs):
    """Reference BLEU via NLTK for the functional-API test.

    NLTK's ``corpus_bleu`` expects pre-tokenized input (lists of words), while
    the metric under test now accepts raw sentence strings, so both the
    hypotheses and every reference are split on whitespace before delegating.
    """
    hypotheses_ = [hypothesis.split() for hypothesis in hypotheses]
    list_of_references_ = [[line.split() for line in ref] for ref in list_of_references]
    return corpus_bleu(
        list_of_references=list_of_references_,
        hypotheses=hypotheses_,
        weights=weights,
        smoothing_function=smoothing_function,
        **kwargs
    )

self.run_functional_metric_test(
preds,
Expand Down Expand Up @@ -122,8 +135,8 @@ def test_bleu_empty_functional():


def test_no_4_gram_functional():
    """BLEU (functional API) is 0 when hypothesis and references share no 4-gram.

    The hypothesis has only three words, so no 4-gram can match and the default
    4-gram BLEU score must be exactly zero. Inputs are raw (untokenized)
    sentence strings, matching the metric's updated interface.
    """
    hyps = ["My full pytorch-lightning"]
    refs = [["My full pytorch-lightning test", "Completely Different"]]
    assert bleu_score(refs, hyps) == tensor(0.0)


Expand All @@ -136,6 +149,6 @@ def test_bleu_empty_class():

def test_no_4_gram_class():
    """BLEU (class API) is 0 when hypothesis and references share no 4-gram.

    The hypothesis has only three words, so no 4-gram can match and the default
    4-gram ``BLEUScore`` must return exactly zero. Inputs are raw (untokenized)
    sentence strings, matching the metric's updated interface.
    """
    bleu = BLEUScore()
    hyps = ["My full pytorch-lightning"]
    refs = [["My full pytorch-lightning test", "Completely Different"]]
    assert bleu(refs, hyps) == tensor(0.0)
18 changes: 11 additions & 7 deletions torchmetrics/functional/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:


def _bleu_score_update(
reference_corpus: Sequence[Sequence[Sequence[str]]],
translate_corpus: Sequence[Sequence[str]],
reference_corpus: Sequence[Sequence[str]],
translate_corpus: Sequence[str],
numerator: Tensor,
denominator: Tensor,
trans_len: Tensor,
Expand All @@ -64,8 +64,12 @@ def _bleu_score_update(
ref_len: count of words in a reference translation
n_gram: gram value ranged 1 to 4
"""
reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
[line.split() if line else [] for line in reference] for reference in reference_corpus
]
translate_corpus_: Sequence[Sequence[str]] = [line.split() if line else [] for line in translate_corpus]

for (translation, references) in zip(translate_corpus, reference_corpus):
for (translation, references) in zip(translate_corpus_, reference_corpus_):
trans_len += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
Expand Down Expand Up @@ -122,8 +126,8 @@ def _bleu_score_compute(


def bleu_score(
reference_corpus: Sequence[Sequence[Sequence[str]]],
translate_corpus: Sequence[Sequence[str]],
reference_corpus: Sequence[Sequence[str]],
translate_corpus: Sequence[str],
Borda marked this conversation as resolved.
Show resolved Hide resolved
n_gram: int = 4,
smooth: bool = False,
) -> Tensor:
Expand All @@ -144,8 +148,8 @@ def bleu_score(

Example:
>>> from torchmetrics.functional import bleu_score
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> translate_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(reference_corpus, translate_corpus)
tensor(0.7598)

Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/functional/text/sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ def sacre_bleu_score(
"torchmetrics[text]`."
)

reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
[_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase) for line in reference]
reference_corpus_: Sequence[Sequence[str]] = [
[" ".join(_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase)) for line in reference]
for reference in reference_corpus
]
translate_corpus_: Sequence[Sequence[str]] = [
_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase) for line in translate_corpus
translate_corpus_: Sequence[str] = [
" ".join(_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase)) for line in translate_corpus
]

numerator = torch.zeros(n_gram)
Expand Down
7 changes: 4 additions & 3 deletions torchmetrics/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class BLEUScore(Metric):
will be used to perform the allgather.

Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> translate_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric = BLEUScore()
>>> metric(reference_corpus, translate_corpus)
tensor(0.7598)
Expand Down Expand Up @@ -91,14 +91,15 @@ def __init__(
self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")

def update( # type: ignore
self, reference_corpus: Sequence[Sequence[Sequence[str]]], translate_corpus: Sequence[Sequence[str]]
self, reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str]
) -> None:
"""Compute Precision Scores.

Args:
reference_corpus: An iterable of iterables of reference corpus
translate_corpus: An iterable of machine translated corpus
"""

self.trans_len, self.ref_len = _bleu_score_update(
reference_corpus,
translate_corpus,
Expand Down
6 changes: 3 additions & 3 deletions torchmetrics/text/sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def update( # type: ignore
reference_corpus: An iterable of iterables of reference corpus
translate_corpus: An iterable of machine translated corpus
"""
reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
[self.tokenizer(line) for line in reference] for reference in reference_corpus
reference_corpus_: Sequence[Sequence[str]] = [
[" ".join(self.tokenizer(line)) for line in reference] for reference in reference_corpus
Borda marked this conversation as resolved.
Show resolved Hide resolved
]
translate_corpus_: Sequence[Sequence[str]] = [self.tokenizer(line) for line in translate_corpus]
translate_corpus_: Sequence[str] = [" ".join(self.tokenizer(line)) for line in translate_corpus]

self.trans_len, self.ref_len = _bleu_score_update(
reference_corpus_,
Expand Down