From 92db9cb5f05a948acacb95f571b58e137bf9bbe1 Mon Sep 17 00:00:00 2001 From: Xinyan Guan Date: Sun, 4 Feb 2024 00:22:26 +0800 Subject: [PATCH 1/4] fix pred target misalignment --- src/torchmetrics/functional/text/bert.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index 3623ff548cc..ff596fe349e 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -419,18 +419,12 @@ def bert_score( preds_loader, preds_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn ) + preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices] + target_embeddings = target_embeddings[target_loader.dataset.sorting_indices] + precision, recall, f1_score = _get_precision_recall_f1( preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale ) - # Sort predictions - if len(precision.shape) == 1: # i.e. when all_layers = False - precision = precision[preds_loader.dataset.sorting_indices] - recall = recall[preds_loader.dataset.sorting_indices] - f1_score = f1_score[preds_loader.dataset.sorting_indices] - elif len(precision.shape) == 2: # i.e. when all_layers = True - precision = precision[:, preds_loader.dataset.sorting_indices] - recall = recall[:, preds_loader.dataset.sorting_indices] - f1_score = f1_score[:, preds_loader.dataset.sorting_indices] if baseline is not None: precision, recall, f1_score = _rescale_metrics_with_baseline( From cfd73d308d5d2983948839eef867b288a4e6fbf2 Mon Sep 17 00:00:00 2001 From: guanxinyan Date: Tue, 16 Jul 2024 18:12:47 +0800 Subject: [PATCH 2/4] unsort idf --- src/torchmetrics/functional/text/bert.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index ff596fe349e..41c33718428 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -422,6 +422,9 @@ def bert_score( preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices] target_embeddings = target_embeddings[target_loader.dataset.sorting_indices] + preds_idf_scale = preds_idf_scale[preds_loader.dataset.sorting_indices] + target_idf_scale = target_idf_scale[target_loader.dataset.sorting_indices] + precision, recall, f1_score = _get_precision_recall_f1( preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale ) From b522ae52667bc3c42244bebaf8db3ede7a0f4ff6 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 19 Jul 2024 17:56:35 +0200 Subject: [PATCH 3/4] Add test --- tests/unittests/text/test_bertscore.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index 0812808426f..dfd6d60a0e5 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -170,3 +170,24 @@ def test_bertscore_differentiability( metric_args=metric_args, key=metric_key, ) + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.parametrize( + "idf", + [(False,), (True,)], +) +def test_bertscore_sorting(idf: bool): + """Test that BERTScore is invariant to the order of the inputs.""" + short = "Short text" + long = "This is a longer text" + + preds = [long, long] + targets = [long, short] + + metric = BERTScore(idf=idf) + score = metric(preds, targets) + + # First index should be the self-comparison - sorting by length should not shuffle this + assert score["f1"][0] > score["f1"][1] From dc3b758a93d6ede492301fd6dd5d429cbb129b5a Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 19 Jul 2024 18:21:12 +0200 Subject: [PATCH 4/4] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c09bdab17e0..78e86ed2761 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575)) +- Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347)) + + ## [1.4.0] - 2024-05-03 ### Added