Skip to content

Commit

Permalink
bleuscore: Add weights argument to allow non-uniform n-gram weights (#…
Browse files Browse the repository at this point in the history
…1075)

* bleuscore: Add weights argument to allow non-uniform n-gram weights
* Update chlog
  • Loading branch information
stancld authored Jun 7, 2022
1 parent c04090b commit f0279bb
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added an option to specify custom n-gram weights for `BLEUScore` and `SacreBLEUScore` instead of always using uniform weights. ([#1075](https://github.com/PyTorchLightning/metrics/pull/1075))

-

Expand Down
19 changes: 19 additions & 0 deletions tests/text/test_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,22 @@ def test_no_4_gram_class():
preds = ["My full pytorch-lightning"]
targets = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu(preds, targets) == tensor(0.0)


def test_no_and_uniform_weights_functional():
    """Omitting ``weights`` must give the same BLEU as explicit uniform weights.

    With ``n_gram=2`` the explicit uniform weights are ``[0.5, 0.5]``.
    """
    hypotheses = ["My full pytorch-lightning"]
    references = [["My full pytorch-lightning test", "Completely Different"]]
    # Default path: no weights passed.
    default_score = bleu_score(hypotheses, references, n_gram=2)
    # Explicit path: uniform weights spelled out by the caller.
    explicit_score = bleu_score(hypotheses, references, n_gram=2, weights=[0.5, 0.5])
    assert default_score == explicit_score


def test_no_and_uniform_weights_class():
    """``BLEUScore`` without ``weights`` must match one given uniform weights.

    With ``n_gram=2`` the explicit uniform weights are ``[0.5, 0.5]``.
    """
    hypotheses = ["My full pytorch-lightning"]
    references = [["My full pytorch-lightning test", "Completely Different"]]

    default_metric = BLEUScore(n_gram=2)
    explicit_metric = BLEUScore(n_gram=2, weights=[0.5, 0.5])

    # Evaluate both metrics on the same corpus, default first.
    default_score = default_metric(hypotheses, references)
    explicit_score = explicit_metric(hypotheses, references)
    assert default_score == explicit_score
19 changes: 19 additions & 0 deletions tests/text/test_sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,22 @@ def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase)
metric_functional=sacre_bleu_score,
metric_args=metric_args,
)


def test_no_and_uniform_weights_functional():
    """Omitting ``weights`` must give the same SacreBLEU as explicit uniform weights.

    With ``n_gram=2`` the explicit uniform weights are ``[0.5, 0.5]``.
    """
    hypotheses = ["My full pytorch-lightning"]
    references = [["My full pytorch-lightning test", "Completely Different"]]
    # Default path: no weights passed.
    default_score = sacre_bleu_score(hypotheses, references, n_gram=2)
    # Explicit path: uniform weights spelled out by the caller.
    explicit_score = sacre_bleu_score(hypotheses, references, n_gram=2, weights=[0.5, 0.5])
    assert default_score == explicit_score


def test_no_and_uniform_weights_class():
    """``SacreBLEUScore`` without ``weights`` must match one given uniform weights.

    With ``n_gram=2`` the explicit uniform weights are ``[0.5, 0.5]``.
    """
    hypotheses = ["My full pytorch-lightning"]
    references = [["My full pytorch-lightning test", "Completely Different"]]

    default_metric = SacreBLEUScore(n_gram=2)
    explicit_metric = SacreBLEUScore(n_gram=2, weights=[0.5, 0.5])

    # Evaluate both metrics on the same corpus, default first.
    default_score = default_metric(hypotheses, references)
    explicit_score = explicit_metric(hypotheses, references)
    assert default_score == explicit_score
25 changes: 20 additions & 5 deletions torchmetrics/functional/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Callable, Sequence, Tuple, Union
from typing import Callable, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, tensor
Expand Down Expand Up @@ -108,8 +108,9 @@ def _bleu_score_compute(
target_len: Tensor,
numerator: Tensor,
denominator: Tensor,
n_gram: int = 4,
smooth: bool = False,
n_gram: int,
weights: Sequence[float],
smooth: bool,
) -> Tensor:
"""Computes the BLEU score.
Expand All @@ -119,6 +120,7 @@ def _bleu_score_compute(
numerator: Numerator of precision score (true positives)
denominator: Denominator of precision score (true positives + false positives)
n_gram: gram value ranged 1 to 4
weights: Weights used for unigrams, bigrams, etc. to calculate BLEU score.
smooth: Whether to apply smoothing
"""
device = numerator.device
Expand All @@ -134,7 +136,7 @@ def _bleu_score_compute(
else:
precision_scores = numerator / denominator

log_precision_scores = tensor([1.0 / n_gram] * n_gram, device=device) * torch.log(precision_scores)
log_precision_scores = tensor(weights, device=device) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = tensor(1.0, device=device) if preds_len > target_len else torch.exp(1 - (target_len / preds_len))
bleu = brevity_penalty * geometric_mean
Expand All @@ -147,6 +149,7 @@ def bleu_score(
target: Sequence[Union[str, Sequence[str]]],
n_gram: int = 4,
smooth: bool = False,
weights: Optional[Sequence[float]] = None,
) -> Tensor:
"""Calculate `BLEU score`_ of machine translated text with one or more references.
Expand All @@ -155,10 +158,17 @@ def bleu_score(
target: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4
smooth: Whether to apply smoothing – see [2]
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used.
Return:
Tensor with BLEU Score
Raises:
ValueError: If ``preds`` and ``target`` corpus have different lengths.
ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
Example:
>>> from torchmetrics.functional import bleu_score
>>> preds = ['the cat is on the mat']
Expand All @@ -179,6 +189,11 @@ def bleu_score(
if len(preds_) != len(target_):
raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}")

if weights is not None and len(weights) != n_gram:
raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
if weights is None:
weights = [1.0 / n_gram] * n_gram

numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
preds_len = tensor(0.0)
Expand All @@ -188,4 +203,4 @@ def bleu_score(
preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn
)

return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth)
return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)
17 changes: 15 additions & 2 deletions torchmetrics/functional/text/sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

import re
from functools import partial
from typing import Sequence
from typing import Optional, Sequence

import torch
from torch import Tensor, tensor
Expand Down Expand Up @@ -283,6 +283,7 @@ def sacre_bleu_score(
smooth: bool = False,
tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
lowercase: bool = False,
weights: Optional[Sequence[float]] = None,
) -> Tensor:
"""Calculate `BLEU score`_ [1] of machine translated text with one or more references. This implementation
follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu.
Expand All @@ -295,10 +296,17 @@ def sacre_bleu_score(
tokenize: Tokenization technique to be used.
Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
lowercase: If ``True``, BLEU score over lowercased text is calculated.
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used.
Return:
Tensor with BLEU Score
Raises:
ValueError: If ``preds`` and ``target`` corpus have different lengths.
ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
Example:
>>> from torchmetrics.functional import sacre_bleu_score
>>> preds = ['the cat is on the mat']
Expand Down Expand Up @@ -331,6 +339,11 @@ def sacre_bleu_score(
" Use `pip install regex` or `pip install torchmetrics[text]`."
)

if weights is not None and len(weights) != n_gram:
raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
if weights is None:
weights = [1.0 / n_gram] * n_gram

numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
preds_len = tensor(0.0)
Expand All @@ -348,4 +361,4 @@ def sacre_bleu_score(
tokenize_fn,
)

return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth)
return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)
14 changes: 12 additions & 2 deletions torchmetrics/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Any, Dict, Sequence
from typing import Any, Dict, Optional, Sequence

import torch
from torch import Tensor, tensor
Expand All @@ -32,6 +32,12 @@ class BLEUScore(Metric):
n_gram: Gram value ranged from 1 to 4
smooth: Whether or not to apply smoothing, see [2]
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used.
Raises:
ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
Example:
>>> from torchmetrics import BLEUScore
Expand Down Expand Up @@ -62,11 +68,15 @@ def __init__(
self,
n_gram: int = 4,
smooth: bool = False,
weights: Optional[Sequence[float]] = None,
**kwargs: Dict[str, Any],
):
super().__init__(**kwargs)
self.n_gram = n_gram
self.smooth = smooth
if weights is not None and len(weights) != n_gram:
raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
self.weights = weights if weights is not None else [1.0 / n_gram] * n_gram

self.add_state("preds_len", tensor(0.0), dist_reduce_fx="sum")
self.add_state("target_len", tensor(0.0), dist_reduce_fx="sum")
Expand Down Expand Up @@ -98,5 +108,5 @@ def compute(self) -> Tensor:
Tensor with BLEU Score
"""
return _bleu_score_compute(
self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.smooth
self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.weights, self.smooth
)
10 changes: 8 additions & 2 deletions torchmetrics/text/sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Any, Dict, Sequence
from typing import Any, Dict, Optional, Sequence

from typing_extensions import Literal

Expand All @@ -42,12 +42,17 @@ class SacreBLEUScore(BLEUScore):
Supported tokenization: ``['none', '13a', 'zh', 'intl', 'char']``
lowercase: If ``True``, BLEU score over lowercased text is calculated.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used.
Raises:
ValueError:
If ``tokenize`` not one of 'none', '13a', 'zh', 'intl' or 'char'
ValueError:
If ``tokenize`` is set to 'intl' and `regex` is not installed
ValueError:
If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
Example:
Expand Down Expand Up @@ -78,9 +83,10 @@ def __init__(
smooth: bool = False,
tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
lowercase: bool = False,
weights: Optional[Sequence[float]] = None,
**kwargs: Dict[str, Any],
):
super().__init__(n_gram=n_gram, smooth=smooth, **kwargs)
super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
if tokenize not in AVAILABLE_TOKENIZERS:
raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")

Expand Down

0 comments on commit f0279bb

Please sign in to comment.