From 8d5b8bacbe75a4922fbf4097cc171195e57f7b78 Mon Sep 17 00:00:00 2001
From: "Mr. Leu" <40532483+mrleu@users.noreply.github.com>
Date: Mon, 6 Dec 2021 05:26:59 -0800
Subject: [PATCH] Untokenized Bleu score to stay consistent with all the other
 text metrics (#640)

* update bleu be consistent text metrics

* fix examples

* fix the additional parenthesis

* updated black format

* add tokenizer func

* update name for readability

Co-authored-by: Alex Leu
Co-authored-by: Nicki Skafte Detlefsen
Co-authored-by: Jirka Borovec
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
---
 CHANGELOG.md                               |  3 ++
 tests/text/test_bleu.py                    | 54 +++++++++++++---------
 torchmetrics/functional/text/bleu.py       | 45 +++++++++++++-----
 torchmetrics/functional/text/sacre_bleu.py | 19 ++++----
 torchmetrics/text/bleu.py                  | 10 ++--
 torchmetrics/text/sacre_bleu.py            | 10 ++--
 6 files changed, 87 insertions(+), 54 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee4adff3cfe..e6730a05d93 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))
 
+- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))
+
+
 ### Deprecated
 
 
diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py
index c71fae5ccac..df9de22eda9 100644
--- a/tests/text/test_bleu.py
+++ b/tests/text/test_bleu.py
@@ -25,29 +25,26 @@
 # example taken from
 # https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
 
 # EXAMPLE 1
-HYPOTHESIS_A = tuple(
-    "It is a guide to action which ensures that the military always obeys the commands of the party".split()
-)
-REFERENCE_1A = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
-REFERENCE_2A = tuple(
-    "It is a guiding principle which makes the military forces always being under the command of the Party".split()
-)
-REFERENCE_3A = tuple("It is the practical guide for the army always to heed the directions of the party".split())
+HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"
+REFERENCE_1A = "It is a guide to action that ensures that the military will forever heed Party commands"
+REFERENCE_2A = "It is a guiding principle which makes the military forces always being under the command of the Party"
+REFERENCE_3A = "It is the practical guide for the army always to heed the directions of the party"
 
 # EXAMPLE 2
-HYPOTHESIS_B = tuple("he read the book because he was interested in world history".split())
-REFERENCE_1B = tuple("he was interested in world history because he read the book".split())
+HYPOTHESIS_B = "he read the book because he was interested in world history"
+REFERENCE_1B = "he was interested in world history because he read the book"
 
 # EXAMPLE 3
-HYPOTHESIS_C = tuple("the cat the cat on the mat".split())
-REFERENCE_1C = tuple("the cat is on the mat".split())
-REFERENCE_2C = tuple("there is a cat on the mat".split())
+HYPOTHESIS_C = "the cat the cat on the mat"
+REFERENCE_1C = "the cat is on the mat"
+REFERENCE_2C = "there is a cat on the mat"
 
 TUPLE_OF_REFERENCES = (
     ((REFERENCE_1A, REFERENCE_2A, REFERENCE_3A), tuple([REFERENCE_1B])),
     (tuple([REFERENCE_1B]), (REFERENCE_1C, REFERENCE_2C)),
+    (tuple([REFERENCE_1B]),),
 )
-TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C))
+TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C), (HYPOTHESIS_B,))
 
 BATCHES = {"preds": TUPLE_OF_HYPOTHESES, "targets": TUPLE_OF_REFERENCES}
@@ -55,6 +52,18 @@
 smooth_func = SmoothingFunction().method2
 
 
+def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing_function, **kwargs):
+    hypotheses_ = [hypothesis.split() for hypothesis in hypotheses]
+    list_of_references_ = [[line.split() for line in ref] for ref in list_of_references]
+    return corpus_bleu(
+        list_of_references=list_of_references_,
+        hypotheses=hypotheses_,
+        weights=weights,
+        smoothing_function=smoothing_function,
+        **kwargs
+    )
+
+
 @pytest.mark.parametrize(
     ["weights", "n_gram", "smooth_func", "smooth"],
     [
@@ -73,15 +82,14 @@ class TestBLEUScore(TextTester):
     @pytest.mark.parametrize("dist_sync_on_step", [False, True])
     def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, n_gram, smooth_func, smooth):
         metric_args = {"n_gram": n_gram, "smooth": smooth}
-
-        nltk_metric = partial(corpus_bleu, weights=weights, smoothing_function=smooth_func)
+        compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func)
 
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
             targets=targets,
             metric_class=BLEUScore,
-            sk_metric=nltk_metric,
+            sk_metric=compute_bleu_metric_nltk,
             dist_sync_on_step=dist_sync_on_step,
             metric_args=metric_args,
             input_order=INPUT_ORDER.TARGETS_FIRST,
@@ -89,13 +97,13 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights,
 
     def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth):
         metric_args = {"n_gram": n_gram, "smooth": smooth}
-        nltk_metric = partial(corpus_bleu, weights=weights, smoothing_function=smooth_func)
+        compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func)
 
         self.run_functional_metric_test(
             preds,
             targets,
             metric_functional=bleu_score,
-            sk_metric=nltk_metric,
+            sk_metric=compute_bleu_metric_nltk,
             metric_args=metric_args,
             input_order=INPUT_ORDER.TARGETS_FIRST,
         )
@@ -120,8 +128,8 @@ def test_bleu_empty_functional():
 
 
 def test_no_4_gram_functional():
-    hyps = [["My", "full", "pytorch-lightning"]]
-    refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
+    hyps = ["My full pytorch-lightning"]
+    refs = [["My full pytorch-lightning test", "Completely Different"]]
     assert bleu_score(refs, hyps) == tensor(0.0)
 
 
@@ -134,6 +142,6 @@ def test_bleu_empty_class():
 
 def test_no_4_gram_class():
     bleu = BLEUScore()
-    hyps = [["My", "full", "pytorch-lightning"]]
-    refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
+    hyps = ["My full pytorch-lightning"]
+    refs = [["My full pytorch-lightning test", "Completely Different"]]
    assert bleu(refs, hyps) == tensor(0.0)
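For reference, the new `_compute_bleu_metric_nltk` helper above can be exercised on its own; a minimal sketch, assuming `nltk` is installed and that the helper and corpus constants from the test module are in scope:

    from nltk.translate.bleu_score import SmoothingFunction

    # Hypotheses and references are now plain sentences; the helper whitespace-tokenizes
    # them before delegating to NLTK's corpus_bleu, mirroring what BLEUScore does internally.
    hypotheses = ["the cat the cat on the mat"]
    list_of_references = [["the cat is on the mat", "there is a cat on the mat"]]
    nltk_score = _compute_bleu_metric_nltk(
        list_of_references,
        hypotheses,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=SmoothingFunction().method2,
    )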
diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py
index a6102b74a52..81b6e100e9e 100644
--- a/torchmetrics/functional/text/bleu.py
+++ b/torchmetrics/functional/text/bleu.py
@@ -17,7 +17,7 @@
 # Date: 2020-07-18
 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
 from collections import Counter
-from typing import Sequence, Tuple
+from typing import Callable, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor, tensor
@@ -44,14 +44,27 @@ def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
 
     return ngram_counter
 
 
+def _tokenize_fn(sentence: str) -> Sequence[str]:
+    """Tokenizes sentence into list of words.
+
+    Args:
+        sentence: A sentence separated by white space.
+
+    Return:
+        List of words
+    """
+    return sentence.split()
+
+
 def _bleu_score_update(
-    reference_corpus: Sequence[Sequence[Sequence[str]]],
-    translate_corpus: Sequence[Sequence[str]],
+    reference_corpus: Sequence[Sequence[str]],
+    translate_corpus: Sequence[str],
     numerator: Tensor,
     denominator: Tensor,
     trans_len: Tensor,
     ref_len: Tensor,
     n_gram: int = 4,
+    tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn,
 ) -> Tuple[Tensor, Tensor]:
     """Updates and returns variables required to compute the BLEU score.
@@ -63,9 +76,14 @@
         trans_len: count of words in a candidate translation
         ref_len: count of words in a reference translation
         n_gram: gram value ranged 1 to 4
+        tokenizer: A function that turns sentence into list of words
     """
+    reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
+        [tokenizer(line) if line else [] for line in reference] for reference in reference_corpus
+    ]
+    translate_corpus_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in translate_corpus]
 
-    for (translation, references) in zip(translate_corpus, reference_corpus):
+    for (translation, references) in zip(translate_corpus_, reference_corpus_):
         trans_len += len(translation)
         ref_len_list = [len(ref) for ref in references]
         ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
@@ -122,8 +140,8 @@ def _bleu_score_compute(
 
 
 def bleu_score(
-    reference_corpus: Sequence[Sequence[Sequence[str]]],
-    translate_corpus: Sequence[Sequence[str]],
+    reference_corpus: Sequence[Union[str, Sequence[str]]],
+    translate_corpus: Union[str, Sequence[str]],
     n_gram: int = 4,
     smooth: bool = False,
 ) -> Tensor:
@@ -144,8 +162,8 @@ def bleu_score(
 
     Example:
         >>> from torchmetrics.functional import bleu_score
-        >>> translate_corpus = ['the cat is on the mat'.split()]
-        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
+        >>> translate_corpus = ['the cat is on the mat']
+        >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
         >>> bleu_score(reference_corpus, translate_corpus)
         tensor(0.7598)
 
@@ -156,16 +174,21 @@ def bleu_score(
     [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
     and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
     """
+    translate_corpus_ = [translate_corpus] if isinstance(translate_corpus, str) else translate_corpus
+    reference_corpus_ = [
+        [reference_text] if isinstance(reference_text, str) else reference_text for reference_text in reference_corpus
+    ]
+
+    if len(translate_corpus_) != len(reference_corpus_):
+        raise ValueError(f"Corpus has different size {len(translate_corpus_)} != {len(reference_corpus_)}")
 
-    if len(translate_corpus) != len(reference_corpus):
-        raise ValueError(f"Corpus has different size {len(translate_corpus)} != {len(reference_corpus)}")
     numerator = torch.zeros(n_gram)
     denominator = torch.zeros(n_gram)
     trans_len = tensor(0, dtype=torch.float)
     ref_len = tensor(0, dtype=torch.float)
 
     trans_len, ref_len = _bleu_score_update(
-        reference_corpus, translate_corpus, numerator, denominator, trans_len, ref_len, n_gram
+        reference_corpus_, translate_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn
     )
 
     return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth)
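As a quick illustration of the new calling convention for the functional interface, here is a sketch mirroring the updated docstring example (the 0.7598 value is taken from that example):

    from torchmetrics.functional import bleu_score

    # Sentences are passed untokenized; whitespace splitting happens internally via _tokenize_fn.
    translate_corpus = ["the cat is on the mat"]
    reference_corpus = [["there is a cat on the mat", "a cat is on the mat"]]
    print(bleu_score(reference_corpus, translate_corpus))  # tensor(0.7598)

    # A bare string for the translated corpus, or a string inside the reference corpus,
    # is also accepted and wrapped into a one-element corpus before the length check.
    print(bleu_score(["there is a cat on the mat"], "the cat is on the mat"))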
diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py
index ff92210e3e8..e42409c4f3c 100644
--- a/torchmetrics/functional/text/sacre_bleu.py
+++ b/torchmetrics/functional/text/sacre_bleu.py
@@ -39,6 +39,7 @@
 
 import re
+from functools import partial
 from typing import Sequence
 
 import torch
@@ -335,21 +336,21 @@ def sacre_bleu_score(
             "torchmetrics[text]`."
         )
 
-    reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
-        [_SacreBLEUTokenizer.tokenize(line, tokenize, lowercase) for line in reference]
-        for reference in reference_corpus
-    ]
-    translate_corpus_: Sequence[Sequence[str]] = [
-        _SacreBLEUTokenizer.tokenize(line, tokenize, lowercase) for line in translate_corpus
-    ]
-
     numerator = torch.zeros(n_gram)
     denominator = torch.zeros(n_gram)
     trans_len = tensor(0, dtype=torch.float)
     ref_len = tensor(0, dtype=torch.float)
 
+    tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase)
     trans_len, ref_len = _bleu_score_update(
-        reference_corpus_, translate_corpus_, numerator, denominator, trans_len, ref_len, n_gram
+        reference_corpus,
+        translate_corpus,
+        numerator,
+        denominator,
+        trans_len,
+        ref_len,
+        n_gram,
+        tokenize_fn,
     )
 
     return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth)
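With tokenization now funneled through the `tokenizer` argument of `_bleu_score_update`, any `Callable[[str], Sequence[str]]` can be plugged in, just as `sacre_bleu_score` binds `_SacreBLEUTokenizer.tokenize` via `partial`. A hypothetical sketch with a custom lowercasing tokenizer (the tokenizer below is illustrative only, not part of the library API):

    import torch
    from torch import tensor

    from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update

    def lowercase_tokenizer(sentence: str):
        # Illustrative tokenizer: lowercase, then split on whitespace.
        return sentence.lower().split()

    n_gram = 4
    numerator, denominator = torch.zeros(n_gram), torch.zeros(n_gram)
    trans_len, ref_len = tensor(0.0), tensor(0.0)
    trans_len, ref_len = _bleu_score_update(
        [["There is a cat on the mat"]],  # reference corpus: one sample, one reference sentence
        ["the cat is on the mat"],        # translated corpus: one sample
        numerator,
        denominator,
        trans_len,
        ref_len,
        n_gram,
        lowercase_tokenizer,
    )
    score = _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, False)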
diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py
index 6ba530c56ca..da2d1415bc8 100644
--- a/torchmetrics/text/bleu.py
+++ b/torchmetrics/text/bleu.py
@@ -22,7 +22,7 @@
 from torch import Tensor, tensor
 
 from torchmetrics import Metric
-from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
+from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn
 
 
 class BLEUScore(Metric):
@@ -45,8 +45,8 @@ class BLEUScore(Metric):
             will be used to perform the allgather.
 
     Example:
-        >>> translate_corpus = ['the cat is on the mat'.split()]
-        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
+        >>> translate_corpus = ['the cat is on the mat']
+        >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
         >>> metric = BLEUScore()
         >>> metric(reference_corpus, translate_corpus)
         tensor(0.7598)
@@ -91,7 +91,7 @@ def __init__(
         self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
 
     def update(  # type: ignore
-        self, reference_corpus: Sequence[Sequence[Sequence[str]]], translate_corpus: Sequence[Sequence[str]]
+        self, reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str]
     ) -> None:
         """Compute Precision Scores.
 
@@ -99,6 +99,7 @@ def update(  # type: ignore
             reference_corpus: An iterable of iterables of reference corpus
             translate_corpus: An iterable of machine translated corpus
         """
+
         self.trans_len, self.ref_len = _bleu_score_update(
             reference_corpus,
             translate_corpus,
@@ -107,6 +108,7 @@ def update(  # type: ignore
             self.trans_len,
             self.ref_len,
             self.n_gram,
+            _tokenize_fn,
         )
 
     def compute(self) -> Tensor:
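The module interface follows suit; a sketch of the class-based usage with raw sentences, mirroring the updated docstring example (including its expected value):

    from torchmetrics import BLEUScore

    metric = BLEUScore()
    # update()/forward() now receive raw sentences; tokenization is handled internally
    # by _tokenize_fn, which is forwarded to _bleu_score_update.
    translate_corpus = ["the cat is on the mat"]
    reference_corpus = [["there is a cat on the mat", "a cat is on the mat"]]
    metric.update(reference_corpus, translate_corpus)
    print(metric.compute())  # tensor(0.7598)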
diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py
index 71925d40224..b68d12106a6 100644
--- a/torchmetrics/text/sacre_bleu.py
+++ b/torchmetrics/text/sacre_bleu.py
@@ -118,17 +118,13 @@ def update(  # type: ignore
             reference_corpus: An iterable of iterables of reference corpus
             translate_corpus: An iterable of machine translated corpus
         """
-        reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
-            [self.tokenizer(line) for line in reference] for reference in reference_corpus
-        ]
-        translate_corpus_: Sequence[Sequence[str]] = [self.tokenizer(line) for line in translate_corpus]
-
         self.trans_len, self.ref_len = _bleu_score_update(
-            reference_corpus_,
-            translate_corpus_,
+            reference_corpus,
+            translate_corpus,
             self.numerator,
             self.denominator,
             self.trans_len,
             self.ref_len,
             self.n_gram,
+            self.tokenizer,
         )
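`SacreBLEUScore` follows the same pattern: `update` now takes raw sentences and forwards `self.tokenizer` to `_bleu_score_update`. A sketch, assuming the `sacrebleu` extra is installed and that the constructor accepts the usual `tokenize`/`lowercase` arguments:

    from torchmetrics import SacreBLEUScore

    # Requires the sacrebleu backend, e.g. `pip install torchmetrics[text]`.
    metric = SacreBLEUScore(tokenize="13a", lowercase=False)
    translate_corpus = ["the cat is on the mat"]
    reference_corpus = [["there is a cat on the mat", "a cat is on the mat"]]
    print(metric(reference_corpus, translate_corpus))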