diff --git a/CHANGELOG.md b/CHANGELOG.md
index c09bdab17e0..78e86ed2761 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -48,6 +48,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575))
 
+- Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347))
+
+
 ## [1.4.0] - 2024-05-03
 
 ### Added
diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py
index 1b7a30c68c7..cfdb8c743b4 100644
--- a/src/torchmetrics/functional/text/bert.py
+++ b/src/torchmetrics/functional/text/bert.py
@@ -436,18 +436,16 @@ def bert_score(
         preds_loader, preds_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
     )
 
+    # Undo each loader's own sorting before scoring, so that preds[i] is still paired with target[i]
+    preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices]
+    target_embeddings = target_embeddings[target_loader.dataset.sorting_indices]
+
+    preds_idf_scale = preds_idf_scale[preds_loader.dataset.sorting_indices]
+    target_idf_scale = target_idf_scale[target_loader.dataset.sorting_indices]
+
     precision, recall, f1_score = _get_precision_recall_f1(
         preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale
     )
-    # Sort predictions
-    if len(precision.shape) == 1:  # i.e. when all_layers = False
-        precision = precision[preds_loader.dataset.sorting_indices]
-        recall = recall[preds_loader.dataset.sorting_indices]
-        f1_score = f1_score[preds_loader.dataset.sorting_indices]
-    elif len(precision.shape) == 2:  # i.e. when all_layers = True
-        precision = precision[:, preds_loader.dataset.sorting_indices]
-        recall = recall[:, preds_loader.dataset.sorting_indices]
-        f1_score = f1_score[:, preds_loader.dataset.sorting_indices]
 
     if baseline is not None:
         precision, recall, f1_score = _rescale_metrics_with_baseline(
diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py
index 0812808426f..dfd6d60a0e5 100644
--- a/tests/unittests/text/test_bertscore.py
+++ b/tests/unittests/text/test_bertscore.py
@@ -170,3 +170,24 @@ def test_bertscore_differentiability(
         metric_args=metric_args,
         key=metric_key,
     )
+
+
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
+@pytest.mark.parametrize(
+    "idf",
+    [False, True],
+)
+def test_bertscore_sorting(idf: bool):
+    """Test that BERTScore is invariant to the order of the inputs."""
+    short = "Short text"
+    long = "This is a longer text"
+
+    preds = [long, long]
+    targets = [long, short]
+
+    metric = BERTScore(idf=idf)
+    score = metric(preds, targets)
+
+    # First index should be the self-comparison - sorting by length should not shuffle this
+    assert score["f1"][0] > score["f1"][1]