Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BertScore on GPU #912

Merged
merged 11 commits into from
Mar 31, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed compatibility of `ClasswiseWrapper` with the `prefix` argument of `MetricCollection` ([#843](https://github.com/PyTorchLightning/metrics/pull/843))


- Fixed `BestScore` on GPU ([#912](https://github.com/PyTorchLightning/metrics/pull/912))


## [0.7.3] - 2022-03-23

### Fixed
Expand Down
8 changes: 6 additions & 2 deletions tests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,20 @@ def test_score_fn_with_idf(preds, targets):
"preds,targets",
[(preds, targets)],
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers(preds, targets):
def test_score_fn_all_layers(preds, targets, device):
"""Tests for functional and all layers."""
if not torch.cuda.is_available() and device == "cuda":
pytest.skip("Test requires GPU support")

original_score = original_bert_score(
preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3
preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3, device=device
)

for metric in _METRICS:
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def _get_embeddings_and_idf_scale(
batch["input_ids_idf"] * processed_attention_mask if idf else processed_attention_mask.type(out.dtype)
)
input_ids_idf /= input_ids_idf.sum(-1, keepdim=True)
idf_scale_list.append(input_ids_idf)
idf_scale_list.append(input_ids_idf.cpu())

embeddings = torch.cat(embeddings_list)
idf_scale = torch.cat(idf_scale_list)
Expand Down