diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee1bc72d827..93316f2ec29 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
--
+- Added an option to specify custom n-gram weights for `BLEUScore` and `SacreBLEUScore` instead of using only uniform weights. ([#1075](https://github.com/PyTorchLightning/metrics/pull/1075))
 
-
diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py
index 029b5c34d1d..6d9307cc74c 100644
--- a/tests/text/test_bleu.py
+++ b/tests/text/test_bleu.py
@@ -113,3 +113,22 @@ def test_no_4_gram_class():
     preds = ["My full pytorch-lightning"]
     targets = [["My full pytorch-lightning test", "Completely Different"]]
     assert bleu(preds, targets) == tensor(0.0)
+
+
+def test_no_and_uniform_weights_functional():
+    preds = ["My full pytorch-lightning"]
+    targets = [["My full pytorch-lightning test", "Completely Different"]]
+    no_weights_score = bleu_score(preds, targets, n_gram=2)
+    uniform_weights_score = bleu_score(preds, targets, n_gram=2, weights=[0.5, 0.5])
+    assert no_weights_score == uniform_weights_score
+
+
+def test_no_and_uniform_weights_class():
+    no_weights_bleu = BLEUScore(n_gram=2)
+    uniform_weights_bleu = BLEUScore(n_gram=2, weights=[0.5, 0.5])
+
+    preds = ["My full pytorch-lightning"]
+    targets = [["My full pytorch-lightning test", "Completely Different"]]
+    no_weights_score = no_weights_bleu(preds, targets)
+    uniform_weights_score = uniform_weights_bleu(preds, targets)
+    assert no_weights_score == uniform_weights_score
diff --git a/tests/text/test_sacre_bleu.py b/tests/text/test_sacre_bleu.py
index f49e5798c6c..94648610c24 100644
--- a/tests/text/test_sacre_bleu.py
+++ b/tests/text/test_sacre_bleu.py
@@ -85,3 +85,22 @@ def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase)
         metric_functional=sacre_bleu_score,
         metric_args=metric_args,
     )
+
+
+def test_no_and_uniform_weights_functional():
+    preds = ["My full pytorch-lightning"]
+    targets = [["My full pytorch-lightning test", "Completely Different"]]
+    no_weights_score = sacre_bleu_score(preds, targets, n_gram=2)
+    uniform_weights_score = sacre_bleu_score(preds, targets, n_gram=2, weights=[0.5, 0.5])
+    assert no_weights_score == uniform_weights_score
+
+
+def test_no_and_uniform_weights_class():
+    no_weights_bleu = SacreBLEUScore(n_gram=2)
+    uniform_weights_bleu = SacreBLEUScore(n_gram=2, weights=[0.5, 0.5])
+
+    preds = ["My full pytorch-lightning"]
+    targets = [["My full pytorch-lightning test", "Completely Different"]]
+    no_weights_score = no_weights_bleu(preds, targets)
+    uniform_weights_score = uniform_weights_bleu(preds, targets)
+    assert no_weights_score == uniform_weights_score
diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py
index fca35aae92d..93423c091bf 100644
--- a/torchmetrics/functional/text/bleu.py
+++ b/torchmetrics/functional/text/bleu.py
@@ -17,7 +17,7 @@
 # Date: 2020-07-18
 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
 from collections import Counter
-from typing import Callable, Sequence, Tuple, Union
+from typing import Callable, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor, tensor
@@ -108,8 +108,9 @@ def _bleu_score_compute(
     target_len: Tensor,
     numerator: Tensor,
     denominator: Tensor,
-    n_gram: int = 4,
-    smooth: bool = False,
+    n_gram: int,
+    weights: Sequence[float],
+    smooth: bool,
 ) -> Tensor:
     """Computes the BLEU score.
@@ -119,6 +120,7 @@ def _bleu_score_compute(
         numerator: Numerator of precision score (true positives)
         denominator: Denominator of precision score (true positives + false positives)
         n_gram: gram value ranged 1 to 4
+        weights: Weights used for unigrams, bigrams, etc. to calculate BLEU score.
         smooth: Whether to apply smoothing
     """
     device = numerator.device
@@ -134,7 +136,7 @@ def _bleu_score_compute(
     else:
         precision_scores = numerator / denominator
 
-    log_precision_scores = tensor([1.0 / n_gram] * n_gram, device=device) * torch.log(precision_scores)
+    log_precision_scores = tensor(weights, device=device) * torch.log(precision_scores)
     geometric_mean = torch.exp(torch.sum(log_precision_scores))
     brevity_penalty = tensor(1.0, device=device) if preds_len > target_len else torch.exp(1 - (target_len / preds_len))
     bleu = brevity_penalty * geometric_mean
@@ -147,6 +149,7 @@ def bleu_score(
     target: Sequence[Union[str, Sequence[str]]],
     n_gram: int = 4,
     smooth: bool = False,
+    weights: Optional[Sequence[float]] = None,
 ) -> Tensor:
     """Calculate `BLEU score`_ of machine translated text with one or more references.
 
@@ -155,10 +158,17 @@ def bleu_score(
     Args:
         preds: An iterable of machine translated corpus
         target: An iterable of iterables of reference corpus
         n_gram: Gram value ranged from 1 to 4
         smooth: Whether to apply smoothing – see [2]
+        weights:
+            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
+            If not provided, uniform weights are used.
 
     Return:
         Tensor with BLEU Score
 
+    Raises:
+        ValueError: If ``preds`` and ``target`` corpus have different lengths.
+        ValueError: If ``weights`` is not ``None`` and the length of ``weights`` is not equal to ``n_gram``.
+
     Example:
         >>> from torchmetrics.functional import bleu_score
         >>> preds = ['the cat is on the mat']
@@ -179,6 +189,11 @@ def bleu_score(
     if len(preds_) != len(target_):
         raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}")
 
+    if weights is not None and len(weights) != n_gram:
+        raise ValueError(f"List of weights has different length than `n_gram`: {len(weights)} != {n_gram}")
+    if weights is None:
+        weights = [1.0 / n_gram] * n_gram
+
     numerator = torch.zeros(n_gram)
     denominator = torch.zeros(n_gram)
     preds_len = tensor(0.0)
@@ -188,4 +203,4 @@ def bleu_score(
         preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn
     )
 
-    return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth)
+    return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)
diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py
index 8b21477e7a7..b9c166cd8da 100644
--- a/torchmetrics/functional/text/sacre_bleu.py
+++ b/torchmetrics/functional/text/sacre_bleu.py
@@ -39,7 +39,7 @@
 import re
 from functools import partial
-from typing import Sequence
+from typing import Optional, Sequence
 
 import torch
 from torch import Tensor, tensor
@@ -283,6 +283,7 @@ def sacre_bleu_score(
     smooth: bool = False,
     tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
     lowercase: bool = False,
+    weights: Optional[Sequence[float]] = None,
 ) -> Tensor:
     """Calculate `BLEU score`_ [1] of machine translated text with one or more references. This implementation
     follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu.
@@ -295,10 +296,17 @@ def sacre_bleu_score(
         tokenize: Tokenization technique to be used. Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
         lowercase: If ``True``, BLEU score over lowercased text is calculated.
+        weights:
+            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
+            If not provided, uniform weights are used.
 
     Return:
         Tensor with BLEU Score
 
+    Raises:
+        ValueError: If ``preds`` and ``target`` corpus have different lengths.
+        ValueError: If ``weights`` is not ``None`` and the length of ``weights`` is not equal to ``n_gram``.
+
     Example:
         >>> from torchmetrics.functional import sacre_bleu_score
         >>> preds = ['the cat is on the mat']
@@ -331,6 +339,11 @@ def sacre_bleu_score(
             " Use `pip install regex` or `pip install torchmetrics[text]`."
         )
 
+    if weights is not None and len(weights) != n_gram:
+        raise ValueError(f"List of weights has different length than `n_gram`: {len(weights)} != {n_gram}")
+    if weights is None:
+        weights = [1.0 / n_gram] * n_gram
+
     numerator = torch.zeros(n_gram)
     denominator = torch.zeros(n_gram)
     preds_len = tensor(0.0)
@@ -348,4 +361,4 @@ def sacre_bleu_score(
         tokenize_fn,
     )
 
-    return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth)
+    return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)
diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py
index 0810f370acf..1f365d326e9 100644
--- a/torchmetrics/text/bleu.py
+++ b/torchmetrics/text/bleu.py
@@ -16,7 +16,7 @@
 # Authors: torchtext authors and @sluks
 # Date: 2020-07-18
 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Optional, Sequence
 
 import torch
 from torch import Tensor, tensor
@@ -32,6 +32,12 @@ class BLEUScore(Metric):
         n_gram: Gram value ranged from 1 to 4
         smooth: Whether or not to apply smoothing, see [2]
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        weights:
+            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
+            If not provided, uniform weights are used.
+
+    Raises:
+        ValueError: If ``weights`` is not ``None`` and the length of ``weights`` is not equal to ``n_gram``.
 
     Example:
         >>> from torchmetrics import BLEUScore
@@ -62,11 +68,15 @@ def __init__(
         self,
         n_gram: int = 4,
         smooth: bool = False,
+        weights: Optional[Sequence[float]] = None,
         **kwargs: Dict[str, Any],
     ):
         super().__init__(**kwargs)
         self.n_gram = n_gram
         self.smooth = smooth
+        if weights is not None and len(weights) != n_gram:
+            raise ValueError(f"List of weights has different length than `n_gram`: {len(weights)} != {n_gram}")
+        self.weights = weights if weights is not None else [1.0 / n_gram] * n_gram
 
         self.add_state("preds_len", tensor(0.0), dist_reduce_fx="sum")
         self.add_state("target_len", tensor(0.0), dist_reduce_fx="sum")
@@ -98,5 +108,5 @@ def compute(self) -> Tensor:
             Tensor with BLEU Score
         """
         return _bleu_score_compute(
-            self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.smooth
+            self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.weights, self.smooth
         )
diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py
index 628843698c2..9707beeb528 100644
--- a/torchmetrics/text/sacre_bleu.py
+++ b/torchmetrics/text/sacre_bleu.py
@@ -17,7 +17,7 @@
 # Authors: torchtext authors and @sluks
 # Date: 2020-07-18
 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Optional, Sequence
 
 from typing_extensions import Literal
 
@@ -42,12 +42,17 @@ class SacreBLEUScore(BLEUScore):
             Supported tokenization: ``['none', '13a', 'zh', 'intl', 'char']``
         lowercase: If ``True``, BLEU score over lowercased text is calculated.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        weights:
+            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
+            If not provided, uniform weights are used.
 
     Raises:
         ValueError:
            If ``tokenize`` not one of 'none', '13a', 'zh', 'intl' or 'char'
         ValueError:
            If ``tokenize`` is set to 'intl' and `regex` is not installed
+        ValueError:
+           If ``weights`` is not ``None`` and the length of ``weights`` is not equal to ``n_gram``.
 
     Example:
 
@@ -78,9 +83,10 @@ def __init__(
         smooth: bool = False,
         tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
         lowercase: bool = False,
+        weights: Optional[Sequence[float]] = None,
         **kwargs: Dict[str, Any],
     ):
-        super().__init__(n_gram=n_gram, smooth=smooth, **kwargs)
+        super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
         if tokenize not in AVAILABLE_TOKENIZERS:
             raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")
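
For context, here is a minimal usage sketch of the new `weights` argument. The API calls mirror the patch above; the example sentences, variable names, and chosen weight values are illustrative only. Since `_bleu_score_compute` now multiplies `torch.log(precision_scores)` by `tensor(weights)`, the score becomes `BP * exp(sum_i w_i * log p_i)`, and passing uniform weights `[1/n_gram] * n_gram` reproduces the default behaviour, which is exactly what the new tests assert:

```python
# Sketch only: assumes a torchmetrics build that includes this patch.
from torchmetrics import BLEUScore
from torchmetrics.functional import bleu_score

preds = ["the cat is on the mat"]
targets = [["there is a cat on the mat", "a cat is on the mat"]]

# Functional API: emphasize bigram precision over unigram precision
# instead of the uniform default.
custom = bleu_score(preds, targets, n_gram=2, weights=[0.25, 0.75])

# Omitting `weights` falls back to uniform weights, so these two are equal.
default = bleu_score(preds, targets, n_gram=2)
uniform = bleu_score(preds, targets, n_gram=2, weights=[0.5, 0.5])
assert default == uniform

# Module API: weights are validated in __init__; a list whose length does
# not match `n_gram` raises ValueError.
metric = BLEUScore(n_gram=2, weights=[0.25, 0.75])
score = metric(preds, targets)
```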