Skip to content

Commit

Permalink
fix Bug in BertScore calculation: pred target misalignment (#2347)
Browse files Browse the repository at this point in the history
* fix pred target misalignment
* add test

---------

Co-authored-by: Xinyan Guan <[email protected]>
Co-authored-by: Bas Krahmer <[email protected]>
(cherry picked from commit 75c33ea)
  • Loading branch information
gxy-gxy authored and Borda committed Aug 2, 2024
1 parent 3ef3451 commit f5e31c2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575))


- Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347))


## [1.4.0] - 2024-05-03

### Added
Expand Down
15 changes: 6 additions & 9 deletions src/torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,18 +436,15 @@ def bert_score(
preds_loader, preds_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
)

preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices]
target_embeddings = target_embeddings[target_loader.dataset.sorting_indices]

preds_idf_scale = preds_idf_scale[preds_loader.dataset.sorting_indices]
target_idf_scale = target_idf_scale[target_loader.dataset.sorting_indices]

precision, recall, f1_score = _get_precision_recall_f1(
preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale
)
# Sort predictions
if len(precision.shape) == 1: # i.e. when all_layers = False
precision = precision[preds_loader.dataset.sorting_indices]
recall = recall[preds_loader.dataset.sorting_indices]
f1_score = f1_score[preds_loader.dataset.sorting_indices]
elif len(precision.shape) == 2: # i.e. when all_layers = True
precision = precision[:, preds_loader.dataset.sorting_indices]
recall = recall[:, preds_loader.dataset.sorting_indices]
f1_score = f1_score[:, preds_loader.dataset.sorting_indices]

if baseline is not None:
precision, recall, f1_score = _rescale_metrics_with_baseline(
Expand Down
21 changes: 21 additions & 0 deletions tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,24 @@ def test_bertscore_differentiability(
metric_args=metric_args,
key=metric_key,
)


@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
"idf",
[(False,), (True,)],
)
def test_bertscore_sorting(idf: bool):
"""Test that BERTScore is invariant to the order of the inputs."""
short = "Short text"
long = "This is a longer text"

preds = [long, long]
targets = [long, short]

metric = BERTScore(idf=idf)
score = metric(preds, targets)

# First index should be the self-comparison - sorting by length should not shuffle this
assert score["f1"][0] > score["f1"][1]

0 comments on commit f5e31c2

Please sign in to comment.