
Adding BERTScore metric (#424)
* Adding BertScore metric

* Improvements

* Update CHANGELOG.md

* Final fixes

* Apply suggestions from code review

* Docstring updates

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka <[email protected]>
5 people authored Aug 8, 2021
1 parent 525642d commit 25e261a
Showing 11 changed files with 350 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- ROUGE ([#399](https://github.com/PyTorchLightning/metrics/issues/399))

- BERT score ([#365](https://github.com/PyTorchLightning/metrics/issues/365))


- Added `MetricTracker` wrapper metric for keeping track of the same metric over multiple epochs ([#238](https://github.com/PyTorchLightning/metrics/pull/238))

5 changes: 5 additions & 0 deletions docs/source/references/functional.rst
@@ -366,6 +366,11 @@ retrieval_normalized_dcg [func]
Text
****

bert_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bert_score

bleu_score [func]
~~~~~~~~~~~~~~~~~

7 changes: 6 additions & 1 deletion docs/source/references/modules.rst
@@ -524,6 +524,12 @@ RetrievalNormalizedDCG
Text
****

BERTScore
~~~~~~~~~

.. autoclass:: torchmetrics.BERTScore
:noindex:

BLEUScore
~~~~~~~~~

@@ -543,7 +549,6 @@ WER
.. autoclass:: torchmetrics.WER
:noindex:


********
Wrappers
********
1 change: 1 addition & 0 deletions requirements/text.txt
@@ -1,3 +1,4 @@
jiwer>=2.2.0
nltk>=3.6
rouge-score>=0.0.4
bert-score==0.3.10
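
The bert-score pin above is only pulled in through the optional text extras, so it may be absent at runtime; a minimal pre-flight sketch, assuming only what this diff introduces (the `_BERTSCORE_AVAILABLE` flag and the `bert_score` functional shown further below):

    # Hedged sketch: guard usage of the new metric behind the availability flag
    # that torchmetrics/functional/text/bert.py imports below.
    from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

    if _BERTSCORE_AVAILABLE:
        from torchmetrics.functional import bert_score

        score = bert_score(predictions=["hello there"], references=["hello there"], lang="en")
        print(score["f1"])
    else:
        print("bert-score is missing; install it with `pip install torchmetrics[text]`")
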
78 changes: 78 additions & 0 deletions tests/text/test_bertscore.py
@@ -0,0 +1,78 @@
from typing import Any

import numpy as np
import pytest

from torchmetrics.functional import bert_score
from torchmetrics.text import BERTScore
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

# Examples and expected values taken from:
# https://github.com/Tiiiger/bert_score/blob/master/tests/test_scorer.py
preds = [
    "28-year-old chef found dead in San Francisco mall",
    "A 28-year-old chef who recently moved to San Francisco was "
    "found dead in the staircase of a local shopping center.",
    "The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at "
    'him."',
]
refs = [
    "28-Year-Old Chef Found Dead at San Francisco Mall",
    "A 28-year-old chef who had recently moved to San Francisco was found dead in the stairwell of a local mall this "
    "week.",
    "But the victim's brother says he can't think of anyone who would want to hurt him, saying, \"Things were finally "
    'going well for him."',
]


def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8):
    """Assert that two lists of scores are approximately equal."""
    assert np.allclose(preds, refs, atol=threshold, equal_nan=True)


preds_batched = [preds[0:2], preds[2:]]
refs_batched = [refs[0:2], refs[2:]]


@pytest.mark.parametrize(
    "preds,refs",
    [(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn(preds, refs):
    """Tests for functional."""
    Score = bert_score(preds, refs, model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
    _assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
    _assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
    _assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])


@pytest.mark.parametrize(
    "preds,refs",
    [(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score(preds, refs):
    """Tests for metric."""
    Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
    Scorer.update(predictions=preds, references=refs)
    Score = Scorer.compute()
    _assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
    _assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
    _assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])


@pytest.mark.parametrize(
    "preds,refs",
    [(preds_batched, refs_batched)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_accumulation(preds, refs):
    """Tests that the metric works with batch accumulation."""
    Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
    for p, r in zip(preds, refs):
        Scorer.update(predictions=p, references=r)
    Score = Scorer.compute()
    _assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
    _assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
    _assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
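
The expected precision/recall/f1 values in these tests come from the reference bert_score implementation (see the test_scorer.py link above); a sketch of how they could be regenerated from the `preds`/`refs` lists defined earlier, assuming the same BERTScorer arguments the functional implementation below passes through:

    # Hedged sketch for regenerating the reference values used above; mirrors the
    # BERTScorer arguments forwarded by torchmetrics/functional/text/bert.py.
    from bert_score import BERTScorer

    scorer = BERTScorer(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
    prec, recall, f1 = scorer.score(cands=preds, refs=refs)
    print(prec.tolist(), recall.tolist(), f1.tolist())
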
3 changes: 2 additions & 1 deletion torchmetrics/__init__.py
@@ -62,7 +62,7 @@
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.text import WER, BLEUScore, ROUGEScore # noqa: E402
from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402

__all__ = [
@@ -75,6 +75,7 @@
"BinnedAveragePrecision",
"BinnedPrecisionRecallCurve",
"BinnedRecallAtFixedPrecision",
"BERTScore",
"BLEUScore",
"BootStrapper",
"CalibrationError",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -57,6 +57,7 @@
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.functional.self_supervised import embedding_similarity
from torchmetrics.functional.text.bert import bert_score
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.wer import wer
@@ -66,6 +67,7 @@
"auc",
"auroc",
"average_precision",
"bert_score",
"bleu_score",
"calibration_error",
"cohen_kappa",
115 changes: 115 additions & 0 deletions torchmetrics/functional/text/bert.py
@@ -0,0 +1,115 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional

from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

if _BERTSCORE_AVAILABLE:
    from bert_score import BERTScorer, get_hash, lang2model, model2layers


def bert_score(
    predictions: List[str],
    references: List[str],
    lang: str = "en",
    model_type: Optional[str] = None,
    num_layers: Optional[int] = None,
    verbose: bool = False,
    idf: bool = False,
    device: Optional[str] = None,
    batch_size: int = 64,
    num_threads: int = 4,
    all_layers: bool = False,
    rescale_with_baseline: bool = False,
    baseline_path: Optional[str] = None,
) -> Dict:
    """`BERTScore <https://arxiv.org/abs/1904.09675>`_ leverages the pre-trained contextual embeddings from BERT
    and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate
    with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision,
    recall, and F1 measure, which can be useful for evaluating different language generation tasks.

    Args:
        predictions: candidate sentences
        references: reference sentences
        lang: language of the sentences
        model_type: bert specification
        num_layers: the layer of representation to use
        verbose: turn on intermediate status updates
        idf: use idf weighting; can also be a precomputed idf_dict
        device: device on which the contextual embedding model will be allocated
        batch_size: bert score processing batch size
        num_threads: number of threads
        all_layers: use scores from all of the model's layers instead of the one given by ``num_layers``
        rescale_with_baseline: rescale bertscore with pre-computed baseline
        baseline_path: customized baseline file

    Returns:
        Dict containing the keys `precision`, `recall`, `f1` and `hashcode` with corresponding values

    Example:
        >>> predictions = ["hello there", "general kenobi"]
        >>> references = ["hello there", "master kenobi"]
        >>> bert_score(predictions=predictions, references=references, lang="en")  # doctest: +SKIP
        {'f1': [0.99..., 0.99...],
         'hashcode': '...',
         'precision': [0.99..., 0.99...],
         'recall': [0.99..., 0.99...]}
    """

    if not _BERTSCORE_AVAILABLE:
        raise ValueError(
            "bert_score metric requires that bert-score package is installed."
            " Either install with `pip install bert-score` or `pip install torchmetrics[text]`"
        )

    if model_type is None:
        model_type = lang2model[lang.lower()]

    if num_layers is None:
        num_layers = model2layers[model_type]

    hashcode = get_hash(
        model=model_type,
        num_layers=num_layers,
        idf=idf,
        rescale_with_baseline=rescale_with_baseline,
        use_custom_baseline=baseline_path is not None,
        use_fast_tokenizer=True,
    )

    cached_bertscorer = BERTScorer(
        model_type=model_type,
        num_layers=num_layers,
        batch_size=batch_size,
        nthreads=num_threads,
        all_layers=all_layers,
        idf=idf,
        device=device,
        lang=lang,
        rescale_with_baseline=rescale_with_baseline,
        baseline_path=baseline_path,
    )

    prec, recall, f1 = cached_bertscorer.score(
        cands=predictions,
        refs=references,
        verbose=verbose,
        batch_size=batch_size,
    )
    output_dict = {
        "precision": prec.tolist(),
        "recall": recall.tolist(),
        "f1": f1.tolist(),
        "hashcode": hashcode,
    }
    return output_dict
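
The availability guard at the top of this file imports `_BERTSCORE_AVAILABLE` from torchmetrics.utilities.imports, which this commit does not touch; a minimal sketch of how such a flag could be defined, assuming it is simply a package-availability check:

    # Hedged sketch only; the real flag lives in torchmetrics/utilities/imports.py,
    # which is not part of this diff.
    from importlib.util import find_spec

    _BERTSCORE_AVAILABLE = find_spec("bert_score") is not None
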
1 change: 1 addition & 0 deletions torchmetrics/text/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.text.bert import BERTScore # noqa: F401
from torchmetrics.text.bleu import BLEUScore # noqa: F401
from torchmetrics.text.rouge import ROUGEScore # noqa: F401
from torchmetrics.text.wer import WER # noqa: F401