-
Notifications
You must be signed in to change notification settings - Fork 411
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Adding BertScore metric * Improvements * Update CHANGELOG.md * Final fixes * Apply suggestions from code review * Docstring updates Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Jirka <[email protected]>
- Loading branch information
1 parent
525642d
commit 25e261a
Showing
11 changed files
with
350 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
jiwer>=2.2.0 | ||
nltk>=3.6 | ||
rouge-score>=0.0.4 | ||
bert-score==0.3.10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from torchmetrics.functional import bert_score | ||
from torchmetrics.text import BERTScore | ||
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE | ||
|
||
# Examples and expected values taken from: | ||
# https://github.com/Tiiiger/bert_score/blob/master/tests/test_scorer.py | ||
preds = [ | ||
"28-year-old chef found dead in San Francisco mall", | ||
"A 28-year-old chef who recently moved to San Francisco was " | ||
"found dead in the staircase of a local shopping center.", | ||
"The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at " | ||
'him."', | ||
] | ||
refs = [ | ||
"28-Year-Old Chef Found Dead at San Francisco Mall", | ||
"A 28-year-old chef who had recently moved to San Francisco was found dead in the stairwell of a local mall this " | ||
"week.", | ||
"But the victim's brother says he can't think of anyone who would want to hurt him, saying, \"Things were finally " | ||
'going well for him."', | ||
] | ||
|
||
|
||
def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8): | ||
"""Assert two lists are equal.""" | ||
assert np.allclose(preds, refs, atol=threshold, equal_nan=True) | ||
|
||
|
||
preds_batched = [preds[0:2], preds[2:]] | ||
refs_batched = [refs[0:2], refs[2:]] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds,refs", | ||
[(preds, refs)], | ||
) | ||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") | ||
def test_score_fn(preds, refs): | ||
"""Tests for functional.""" | ||
Score = bert_score(preds, refs, model_type="roberta-large", num_layers=17, idf=False, batch_size=3) | ||
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516]) | ||
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749]) | ||
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds,refs", | ||
[(preds, refs)], | ||
) | ||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") | ||
def test_score(preds, refs): | ||
"""Tests for metric.""" | ||
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3) | ||
Scorer.update(predictions=preds, references=refs) | ||
Score = Scorer.compute() | ||
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516]) | ||
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749]) | ||
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds,refs", | ||
[(preds_batched, refs_batched)], | ||
) | ||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") | ||
def test_accumulation(preds, refs): | ||
"""Tests for metric works with accumulation.""" | ||
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3) | ||
for p, r in zip(preds, refs): | ||
Scorer.update(predictions=p, references=r) | ||
Score = Scorer.compute() | ||
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516]) | ||
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749]) | ||
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Dict, List, Optional | ||
|
||
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE | ||
|
||
if _BERTSCORE_AVAILABLE: | ||
from bert_score import BERTScorer, get_hash, lang2model, model2layers | ||
|
||
|
||
def bert_score( | ||
predictions: List[str], | ||
references: List[str], | ||
lang: str = "en", | ||
model_type: Optional[str] = None, | ||
num_layers: int = None, | ||
verbose: bool = False, | ||
idf: bool = False, | ||
device: Optional[str] = None, | ||
batch_size: int = 64, | ||
num_threads: int = 4, | ||
all_layers: bool = False, | ||
rescale_with_baseline: bool = False, | ||
baseline_path: Optional[str] = None, | ||
) -> Dict: | ||
"""`BERTScore <https://arxiv.org/abs/1904.09675>`_ leverages the pre-trained contextual embeddings from BERT | ||
and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate | ||
with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, | ||
recall, and F1 measure, which can be useful for evaluating different language generation tasks. | ||
Args: | ||
predictions: candidate sentences | ||
references: reference sentences | ||
model_type: bert specification | ||
num_layers: the layer of representation to use. | ||
verbose: turn on intermediate status update | ||
idf: use idf weighting, can also be a precomputed idf_dict | ||
device: on which the contextual embedding model will be allocated on. | ||
num_threads: number of threads | ||
batch_size: bert score processing batch size | ||
lang: language of the sentences | ||
rescale_with_baseline: rescale bertscore with pre-computed baseline | ||
baseline_path: customized baseline file | ||
Returns: | ||
Dict containing the keys `precision`, `recall`, `f1` and `hashcode` with corresponding values | ||
Example: | ||
>>> predictions = ["hello there", "general kenobi"] | ||
>>> references = ["hello there", "master kenobi"] | ||
>>> bert_score(predictions=predictions, references=references, lang="en") # doctest: +SKIP | ||
{'f1': [0.99..., 0.99...], | ||
'hashcode': '...', | ||
'precision': [0.99..., 0.99...], | ||
'recall': [0.99..., 0.99...]} | ||
""" | ||
|
||
if not _BERTSCORE_AVAILABLE: | ||
raise ValueError( | ||
"bert_score metric requires that bert-score package is installed." | ||
" Either install with `pip install bert-score` or `pip install torchmetrics[text]`" | ||
) | ||
|
||
if model_type is None: | ||
model_type = lang2model[lang.lower()] | ||
|
||
if num_layers is None: | ||
num_layers = model2layers[model_type] | ||
|
||
hashcode = get_hash( | ||
model=model_type, | ||
num_layers=num_layers, | ||
idf=idf, | ||
rescale_with_baseline=rescale_with_baseline, | ||
use_custom_baseline=baseline_path is not None, | ||
use_fast_tokenizer=True, | ||
) | ||
|
||
cached_bertscorer = BERTScorer( | ||
model_type=model_type, | ||
num_layers=num_layers, | ||
batch_size=batch_size, | ||
nthreads=num_threads, | ||
all_layers=all_layers, | ||
idf=idf, | ||
device=device, | ||
lang=lang, | ||
rescale_with_baseline=rescale_with_baseline, | ||
baseline_path=baseline_path, | ||
) | ||
|
||
prec, recall, f1 = cached_bertscorer.score( | ||
cands=predictions, | ||
refs=references, | ||
verbose=verbose, | ||
batch_size=batch_size, | ||
) | ||
output_dict = { | ||
"precision": prec.tolist(), | ||
"recall": recall.tolist(), | ||
"f1": f1.tolist(), | ||
"hashcode": hashcode, | ||
} | ||
return output_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.