Adding BERTScore metric #424

Merged · 72 commits · Aug 8, 2021

Changes from 13 commits

Commits:
dec9571  Adding BertScore metric (gagan3012, Aug 3, 2021)
192c776  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
703b5f5  Lint fix (gagan3012, Aug 3, 2021)
684904f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
4354d64  Update test_bertscore.py (gagan3012, Aug 3, 2021)
fe2cf88  Update test_bertscore.py (gagan3012, Aug 3, 2021)
a4d84f4  Update test_bertscore.py (gagan3012, Aug 3, 2021)
0cbb9a4  fixes (gagan3012, Aug 3, 2021)
bf2811f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
a8a0dc4  Improvements (gagan3012, Aug 3, 2021)
eb1978c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
b40365c  Update test_bertscore.py (gagan3012, Aug 3, 2021)
87539a4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
99f5273  Minor fix (gagan3012, Aug 3, 2021)
90d34b7  Refactoring and applying changes (gagan3012, Aug 4, 2021)
dc58f3a  updates and refactoring (gagan3012, Aug 4, 2021)
04c6fab  fixes (gagan3012, Aug 4, 2021)
1712870  Update functional.rst (gagan3012, Aug 4, 2021)
9f18376  flake8 (gagan3012, Aug 4, 2021)
4f20fcb  Update CHANGELOG.md (gagan3012, Aug 4, 2021)
e8ae706  Merge branch 'master' into feature/bert-score (Borda, Aug 4, 2021)
90eb058  adding changes to files (gagan3012, Aug 4, 2021)
8a307c9  Merge branch 'master' into feature/bert-score (Borda, Aug 4, 2021)
184ef15  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 4, 2021)
28c5c12  test fixes (gagan3012, Aug 4, 2021)
faf67cf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 4, 2021)
14e6fab  Update test_bertscore.py (gagan3012, Aug 4, 2021)
75f602b  Update bert.py (gagan3012, Aug 4, 2021)
e1f0196  Update bert.py (gagan3012, Aug 4, 2021)
5c82141  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 4, 2021)
1a25fb2  Final fixes (gagan3012, Aug 4, 2021)
15ff499  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 4, 2021)
ec755ef  Apply suggestions from code review (SkafteNicki, Aug 5, 2021)
cf7cf70  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
222f932  Merge branch 'master' into feature/bert-score (Borda, Aug 5, 2021)
1f4d0c1  Update tests/text/test_bertscore.py (gagan3012, Aug 5, 2021)
af69f4d  Update tests/text/test_bertscore.py (gagan3012, Aug 5, 2021)
9bc0d6c  Update torchmetrics/text/bert.py (gagan3012, Aug 5, 2021)
e9d3237  Update torchmetrics/text/bert.py (gagan3012, Aug 5, 2021)
d2072fe  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
05c5983  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
5003012  Update test_bertscore.py (gagan3012, Aug 5, 2021)
6000f79  Update bert.py (gagan3012, Aug 5, 2021)
9a2a87a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
26744d6  Update text.txt (gagan3012, Aug 5, 2021)
384a90a  Merge branch 'feature/bert-score' of https://github.com/gagan3012/met… (gagan3012, Aug 5, 2021)
2dd26a0  Changes to docstring (gagan3012, Aug 5, 2021)
5629cc0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
4d8524c  Update bert.py (gagan3012, Aug 5, 2021)
1252589  Merge branch 'feature/bert-score' of https://github.com/gagan3012/met… (gagan3012, Aug 5, 2021)
30854cf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
8605fa0  Merge branch 'master' into feature/bert-score (SkafteNicki, Aug 6, 2021)
6b97a04  Update torchmetrics/functional/text/bert.py (gagan3012, Aug 6, 2021)
eb23158  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
24e1488  Update torchmetrics/text/bert.py (gagan3012, Aug 6, 2021)
37db353  Update torchmetrics/text/bert.py (gagan3012, Aug 6, 2021)
7436479  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
2ff8298  Docstring updates (gagan3012, Aug 6, 2021)
209c46c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
eabf020  Update bert.py (gagan3012, Aug 6, 2021)
955fdd3  Merge branch 'feature/bert-score' of https://github.com/gagan3012/met… (gagan3012, Aug 6, 2021)
bbac727  Fixes (gagan3012, Aug 6, 2021)
0267a55  Docstring error fix (gagan3012, Aug 6, 2021)
2346f17  Fixes again (gagan3012, Aug 6, 2021)
e624e0e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
a1bbcd3  Test fixing (gagan3012, Aug 6, 2021)
608b02d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
5ac11e7  New fixes (gagan3012, Aug 6, 2021)
733d028  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2021)
4209669  docs (Borda, Aug 8, 2021)
55aa327  docs (Borda, Aug 8, 2021)
8801e64  Update torchmetrics/text/bert.py (SkafteNicki, Aug 8, 2021)
5 changes: 5 additions & 0 deletions docs/source/references/functional.rst
@@ -382,3 +382,8 @@ wer [func]

.. autofunction:: torchmetrics.functional.wer
:noindex:

bertscore [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bertscore
5 changes: 5 additions & 0 deletions docs/source/references/modules.rst
@@ -543,6 +543,11 @@ WER
.. autoclass:: torchmetrics.WER
:noindex:

BERTScore
~~~~~~~~~~

.. autoclass:: torchmetrics.BERTScore
:noindex:

********
Wrappers
1 change: 1 addition & 0 deletions requirements/text.txt
@@ -1,3 +1,4 @@
jiwer>=2.2.0
nltk>=3.6
rouge-score>=0.0.4
bert-score>=0.3.8
63 changes: 63 additions & 0 deletions tests/text/test_bertscore.py
@@ -0,0 +1,63 @@
import numpy as np
import pytest

from torchmetrics.functional import bertscore
from torchmetrics.text import BERTScore
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE


def assertTensorsAlmostEqual(expected, actual, decimal=5):
"""
Test tensors are almost equal (EPS = 1e-5 by default)
"""
np.testing.assert_almost_equal(expected, actual, decimal=decimal)


preds = [
"28-year-old chef found dead in San Francisco mall",
"A 28-year-old chef who recently moved to San Francisco was "
"found dead in the staircase of a local shopping center.",
"The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at "
'him."',
]
refs = [
"28-Year-Old Chef Found Dead at San Francisco Mall",
"A 28-year-old chef who had recently moved to San Francisco was found dead in the stairwell of a local mall this "
"week.",
"But the victim's brother says he can't think of anyone who would want to hurt him, saying, \"Things were finally "
'going well for him."',
]


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn(preds, refs):
Score = bertscore(preds, refs, model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
assertTensorsAlmostEqual(
expected=Score["precision"], actual=[0.9843302369117737, 0.9832239747047424, 0.9120386242866516]
)
assertTensorsAlmostEqual(
expected=Score["recall"], actual=[0.9823839068412781, 0.9732863903045654, 0.920428991317749]
)
assertTensorsAlmostEqual(expected=Score["f1"], actual=[0.9833561182022095, 0.9782299995422363, 0.916214644908905])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score(preds, refs):
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
Scorer.update(predictions=preds, references=refs)
Score = Scorer.compute()
assertTensorsAlmostEqual(
expected=Score["precision"], actual=[0.9843302369117737, 0.9832239747047424, 0.9120386242866516]
)
assertTensorsAlmostEqual(
expected=Score["recall"], actual=[0.9823839068412781, 0.9732863903045654, 0.920428991317749]
)
assertTensorsAlmostEqual(expected=Score["f1"], actual=[0.9833561182022095, 0.9782299995422363, 0.916214644908905])
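The expected values in these tests are the sentence-level scores the upstream bert_score package produces for this configuration. A minimal sketch of how such reference values could be regenerated (a hypothetical helper snippet, not part of the PR; it assumes bert-score and the roberta-large weights are available):

    import bert_score

    # First candidate/reference pair from the test data above; passing the
    # full three-element lists reproduces all three expected values.
    preds = ["28-year-old chef found dead in San Francisco mall"]
    refs = ["28-Year-Old Chef Found Dead at San Francisco Mall"]

    P, R, F = bert_score.score(
        cands=preds, refs=refs, model_type="roberta-large", num_layers=17, idf=False, batch_size=3
    )
    print(P.tolist(), R.tolist(), F.tolist())  # per-sentence precision, recall, f1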
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -61,5 +61,5 @@
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.text import WER, BLEUScore, ROUGEScore # noqa: E402, F401
from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore # noqa: E402, F401
from torchmetrics.wrappers import BootStrapper # noqa: E402, F401
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -59,6 +59,7 @@
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
from torchmetrics.functional.text.bertscore import bertscore # noqa: F401
from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
from torchmetrics.functional.text.rouge import rouge_score # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
128 changes: 128 additions & 0 deletions torchmetrics/functional/text/bertscore.py
@@ -0,0 +1,128 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional

from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

if _BERTSCORE_AVAILABLE:
import bert_score


def bertscore(
    predictions: List[str],
    references: List[str],
    model_type: Optional[str] = None,
    num_layers: Optional[int] = None,
    verbose: bool = False,
    idf: bool = False,
    device: Optional[str] = None,
    batch_size: int = 64,
    nthreads: int = 4,
    all_layers: bool = False,
    lang: Optional[str] = None,
    rescale_with_baseline: bool = False,
    baseline_path: Optional[str] = None,
) -> Dict:
"""BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and
reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level
and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be
useful for evaluating different language generation tasks.

    Args:
        predictions: candidate sentences (list of str)
        references: reference sentences (list of str)
        model_type: BERT model specification; defaults to the suggested model for the
            target language. At least one of `model_type` or `lang` must be specified.
        num_layers: the layer of representation to use; defaults to the number of
            layers tuned on WMT16 correlation data for the chosen model
        verbose: turn on intermediate status updates
        idf: use idf weighting; can also be a precomputed idf dict
        device: device on which the contextual embedding model is allocated;
            defaults to cuda:0 if CUDA is available
        batch_size: bert_score processing batch size
        nthreads: number of threads
        all_layers: use representations from all of the model's layers instead of a single one
        lang: language of the sentences. At least one of `model_type` or `lang` must be
            specified, and `lang` is required when `rescale_with_baseline` is True.
        rescale_with_baseline: rescale BERTScore with a pre-computed baseline
        baseline_path: path to a customized baseline file

    Returns:
        Dict with the keys `precision`, `recall` and `f1`, each mapping to a list of
        sentence-level scores (one per candidate-reference pair), plus the `hashcode`
        identifying the scorer configuration. If a candidate has multiple references,
        the returned score for that candidate is the best score among all of its references.

Example:
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "general kenobi"]
>>> results = bertscore(predictions=predictions, references=references, lang="en")
>>> print([round(v, 2) for v in results["f1"]])
[1.0, 1.0]
"""
    if model_type is None:
        if lang is None:
            raise ValueError("Either `model_type` or `lang` must be specified")
        model_type = bert_score.lang2model[lang.lower()]

if num_layers is None:
num_layers = bert_score.model2layers[model_type]

hashcode = bert_score.get_hash(
model=model_type,
num_layers=num_layers,
idf=idf,
rescale_with_baseline=rescale_with_baseline,
use_custom_baseline=baseline_path is not None,
)

    # Build the scorer once; `hashcode` above uniquely identifies this configuration
    cached_bertscorer = bert_score.BERTScorer(
        model_type=model_type,
        num_layers=num_layers,
        batch_size=batch_size,
        nthreads=nthreads,
        all_layers=all_layers,
        idf=idf,
        device=device,
        lang=lang,
        rescale_with_baseline=rescale_with_baseline,
        baseline_path=baseline_path,
    )

(P, R, F) = cached_bertscorer.score(
cands=predictions,
refs=references,
verbose=verbose,
batch_size=batch_size,
)
output_dict = {
"precision": P.tolist(),
"recall": R.tolist(),
"f1": F.tolist(),
"hashcode": hashcode,
}
return output_dict
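For orientation, a minimal usage sketch of the functional interface added above (illustrative only; it assumes the bert-score package is installed and downloads the pretrained model on first use):

    from torchmetrics.functional import bertscore

    predictions = ["the cat sat on the mat"]
    references = ["a cat was sitting on the mat"]

    # At least one of `lang` or `model_type` must be given; `lang="en"` selects
    # the default English model via bert_score.lang2model.
    score = bertscore(predictions=predictions, references=references, lang="en")
    print(sorted(score.keys()))  # ['f1', 'hashcode', 'precision', 'recall']
    print(score["f1"])           # one float per candidate-reference pair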
1 change: 1 addition & 0 deletions torchmetrics/text/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.text.bertscore import BERTScore # noqa: F401
from torchmetrics.text.bleu import BLEUScore # noqa: F401
from torchmetrics.text.rouge import ROUGEScore # noqa: F401
from torchmetrics.text.wer import WER # noqa: F401
129 changes: 129 additions & 0 deletions torchmetrics/text/bertscore.py
@@ -0,0 +1,129 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional

from torchmetrics.functional import bertscore
from torchmetrics.metric import Metric


class BERTScore(Metric):
"""BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and
reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level
and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be
useful for evaluating different language generation tasks.

    Args:
        model_type: BERT model specification; defaults to the suggested model for the
            target language. At least one of `model_type` or `lang` must be specified.
        num_layers: the layer of representation to use; defaults to the number of
            layers tuned on WMT16 correlation data for the chosen model
        verbose: turn on intermediate status updates
        idf: use idf weighting; can also be a precomputed idf dict
        device: device on which the contextual embedding model is allocated;
            defaults to cuda:0 if CUDA is available
        batch_size: bert_score processing batch size
        nthreads: number of threads
        all_layers: use representations from all of the model's layers instead of a single one
        lang: language of the sentences. At least one of `model_type` or `lang` must be
            specified, and `lang` is required when `rescale_with_baseline` is True.
        rescale_with_baseline: rescale BERTScore with a pre-computed baseline
        baseline_path: path to a customized baseline file
        compute_on_step: forward only calls ``update()`` and returns ``None`` if this is set to False
        dist_sync_on_step: synchronize metric state across processes at each ``forward()``
            before returning the value at the step
        process_group: specify the process group on which synchronization is called
        dist_sync_fn: callback that performs the allgather operation on the metric state

    Returns:
        Dict with the keys `precision`, `recall` and `f1`, each mapping to a list of
        sentence-level scores, plus the `hashcode` identifying the scorer configuration.

    Example:
        >>> predictions = ["hello there", "general kenobi"]
        >>> references = ["hello there", "general kenobi"]
        >>> metric = BERTScore(lang="en")
        >>> metric.update(predictions=predictions, references=references)
        >>> results = metric.compute()
        >>> print([round(v, 2) for v in results["f1"]])
        [1.0, 1.0]
"""

def __init__(
self,
        model_type: Optional[str] = None,
        num_layers: Optional[int] = None,
        verbose: bool = False,
        idf: bool = False,
        device: Optional[str] = None,
        batch_size: int = 64,
        nthreads: int = 4,
        all_layers: bool = False,
        lang: Optional[str] = None,
        rescale_with_baseline: bool = False,
        baseline_path: Optional[str] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.baseline_path = baseline_path
self.rescale_with_baseline = rescale_with_baseline
self.lang = lang
self.all_layers = all_layers
self.nthreads = nthreads
self.batch_size = batch_size
        # `Metric.device` is a read-only property, so the target device for the
        # embedding model is stored under a separate attribute name
        self.embedding_device = device
self.idf = idf
self.verbose = verbose
self.num_layers = num_layers
self.model_type = model_type
self.add_state("predictions", [], dist_reduce_fx="cat")
self.add_state("references", [], dist_reduce_fx="cat")

    def update(self, predictions: List[str], references: List[str]) -> None:  # type: ignore
        """Store predictions and references for computing the BERT scores.

        Args:
            predictions: list of predicted sentences
            references: list of reference sentences
        """
        # extend (not append) so the states remain flat lists of sentences
        self.predictions.extend(predictions)
        self.references.extend(references)

    def compute(self) -> Dict:
        """Calculate the BERT scores.

        Returns:
            Dict with the keys `precision`, `recall`, `f1` and `hashcode`.
        """
        return bertscore(
            predictions=self.predictions,
            references=self.references,
            model_type=self.model_type,
            num_layers=self.num_layers,
            verbose=self.verbose,
            idf=self.idf,
            device=self.embedding_device,
            nthreads=self.nthreads,
            batch_size=self.batch_size,
            lang=self.lang,
            all_layers=self.all_layers,
            rescale_with_baseline=self.rescale_with_baseline,
            baseline_path=self.baseline_path,
        )
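And a corresponding sketch for the module-based interface, accumulating inputs over several update() calls before a single compute() (again illustrative, under the same availability assumptions):

    from torchmetrics.text import BERTScore

    metric = BERTScore(lang="en", batch_size=2)
    # States accumulate across batches until compute() is called
    metric.update(predictions=["hello there"], references=["hello there"])
    metric.update(predictions=["general kenobi"], references=["general kenobi"])
    score = metric.compute()  # dict with "precision", "recall", "f1", "hashcode"
    print([round(f, 2) for f in score["f1"]])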
1 change: 1 addition & 0 deletions torchmetrics/utilities/imports.py
@@ -74,6 +74,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
_LIGHTNING_AVAILABLE: bool = _module_available("pytorch_lightning")

_JIWER_AVAILABLE: bool = _module_available("jiwer")
_BERTSCORE_AVAILABLE: bool = _module_available("bert_score")
_NLTK_AVAILABLE = _module_available("nltk")
_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score")
_SCIPY_AVAILABLE: bool = _module_available("scipy")