From 4ce908b5814d4d90c20614bc256c721ab82df4d5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 24 Mar 2022 10:38:05 +0100 Subject: [PATCH 1/5] fix --- tests/text/test_bertscore.py | 5 +++-- torchmetrics/functional/text/bert.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py index 6b34fdfad2c..567b1917dbc 100644 --- a/tests/text/test_bertscore.py +++ b/tests/text/test_bertscore.py @@ -94,8 +94,9 @@ def test_score_fn_with_idf(preds, targets): "preds,targets", [(preds, targets)], ) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_fn_all_layers(preds, targets): +def test_score_fn_all_layers(preds, targets, device): """Tests for functional and all layers.""" original_score = original_bert_score( preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3 @@ -103,7 +104,7 @@ def test_score_fn_all_layers(preds, targets): original_score = _parse_original_bert_score(original_score) metrics_score = metrics_bert_score( - preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3 + preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3, device=device ) for metric in _METRICS: diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index e3f8844dd81..64a56631064 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -330,7 +330,7 @@ def _get_embeddings_and_idf_scale( processed_attention_mask = _process_attention_mask_for_special_tokens(attention_mask) # Multiply embeddings with attention_mask (b=batch_size, l=num_layers, s=seq_len, d=emb_dim) out = torch.einsum("blsd, bs -> blsd", out, processed_attention_mask) - embeddings_list.append(out.cpu()) + embeddings_list.append(out) # Calculate weighted (w.r.t. 
sentence length) input_ids IDF matrix input_ids_idf = ( From fa16fe6f081f11acc9117fb7350ef76986d9e40b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 24 Mar 2022 10:40:16 +0100 Subject: [PATCH 2/5] skip --- tests/text/test_bertscore.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py index 567b1917dbc..33ea67486a7 100644 --- a/tests/text/test_bertscore.py +++ b/tests/text/test_bertscore.py @@ -98,6 +98,9 @@ def test_score_fn_with_idf(preds, targets): @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") def test_score_fn_all_layers(preds, targets, device): """Tests for functional and all layers.""" + if not torch.cuda.is_available() and device == "cuda": + pytest.skip("Test requires GPU support") + original_score = original_bert_score( preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3 ) From 08c2ae812c543c6d9e6ca4bfee7ad1111e187692 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 24 Mar 2022 11:32:19 +0100 Subject: [PATCH 3/5] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f73bc17fc2..44236336936 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,6 +121,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed compatibility of `ClasswiseWrapper` with the `prefix` argument of `MetricCollection` ([#843](https://github.com/PyTorchLightning/metrics/pull/843)) +- Fixed `BERTScore` on GPU ([#912](https://github.com/PyTorchLightning/metrics/pull/912)) + + ## [0.7.3] - 2022-03-23 ### Fixed From 1f60aadf46871cc4621d94e2c162da53a0938378 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Fri, 25 Mar 2022 15:29:50 +0100 Subject: [PATCH 4/5] Enforce storing both embeddings and input_ids_idf on CPU --- torchmetrics/functional/text/bert.py | 4 ++-- 1 file changed, 2 insertions(+), 2
deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 64a56631064..0e74fee6b55 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -330,14 +330,14 @@ def _get_embeddings_and_idf_scale( processed_attention_mask = _process_attention_mask_for_special_tokens(attention_mask) # Multiply embeddings with attention_mask (b=batch_size, l=num_layers, s=seq_len, d=emb_dim) out = torch.einsum("blsd, bs -> blsd", out, processed_attention_mask) - embeddings_list.append(out) + embeddings_list.append(out.cpu()) # Calculate weighted (w.r.t. sentence length) input_ids IDF matrix input_ids_idf = ( batch["input_ids_idf"] * processed_attention_mask if idf else processed_attention_mask.type(out.dtype) ) input_ids_idf /= input_ids_idf.sum(-1, keepdim=True) - idf_scale_list.append(input_ids_idf) + idf_scale_list.append(input_ids_idf.cpu()) embeddings = torch.cat(embeddings_list) idf_scale = torch.cat(idf_scale_list) From 4aa41b5d1bf24959fe080c1a889f6b948f58c126 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 31 Mar 2022 09:28:43 +0200 Subject: [PATCH 5/5] test None Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- tests/text/test_bertscore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py index 33ea67486a7..c6eb391eba3 100644 --- a/tests/text/test_bertscore.py +++ b/tests/text/test_bertscore.py @@ -94,7 +94,7 @@ def test_score_fn_with_idf(preds, targets): "preds,targets", [(preds, targets)], ) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("device", ["cpu", "cuda", None]) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") def test_score_fn_all_layers(preds, targets, device): """Tests for functional and all layers."""