
BERTScore raises RuntimeError when device="cuda" #909

Closed
om-hb opened this issue Mar 23, 2022 · 1 comment · Fixed by #912


om-hb commented Mar 23, 2022

Hi,

When I run the following example code:

```python
from torchmetrics.text.bert import BERTScore

bertscore = BERTScore(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cuda")
bertscore = bertscore.to("cuda")

preds = ["the sun is a star"]
target = ["the sun is classified as a star"]

results = bertscore(preds, target)
```

I get the following `RuntimeError`:

```
--------------------------------------------------
RuntimeError     Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_29808\645715684.py in <cell line: 1>()
----> 1 results = bertscore(preds, target)

~\Anaconda3\envs\xxx\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\metric.py in forward(self, *args, **kwargs)
    252         self.reset()
    253         self.update(*args, **kwargs)
--> 254         self._forward_cache = self.compute()
    255 
    256         # restore context

~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\metric.py in wrapped_func(*args, **kwargs)
    418                 should_unsync=self._should_unsync,
    419             ):
--> 420                 value = compute(*args, **kwargs)
    421                 self._computed = _squeeze_if_scalar(value)
    422 

~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\text\bert.py in compute(self)
    231             Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
    232         """
--> 233         return bert_score(
    234             preds=_get_input_dict(self.preds_input_ids, self.preds_attention_mask),
    235             target=_get_input_dict(self.target_input_ids, self.target_attention_mask),

~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\functional\text\bert.py in bert_score(preds, target, model_name_or_path, num_layers, all_layers, model, user_tokenizer, user_forward_fn, verbose, idf, device, max_length, batch_size, num_threads, return_hash, lang, rescale_with_baseline, baseline_path, baseline_url)
    646     )
    647 
--> 648     precision, recall, f1_score = _get_precision_recall_f1(
    649         preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale
    650     )

~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\functional\text\bert.py in _get_precision_recall_f1(preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale)
    373     cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings)
    374     # Final metrics shape = (batch_size * num_layers | batch_size)
--> 375     precision = _get_scaled_precision_or_recall(cos_sim, "precision", preds_idf_scale)
    376     recall = _get_scaled_precision_or_recall(cos_sim, "recall", target_idf_scale)
    377 

~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\functional\text\bert.py in _get_scaled_precision_or_recall(cos_sim, metric, idf_scale)
    350     dim = 3 if metric == "precision" else 2
    351     res = cos_sim.max(dim=dim).values
--> 352     res = torch.einsum("bls, bs -> bls", res, idf_scale).sum(-1)
    353     # We transpose the results and squeeze if possible to match the format of the original BERTScore implementation
    354     res = res.transpose(0, 1).squeeze()

~\Anaconda3\envs\xxx\lib\site-packages\torch\functional.py in einsum(*args)
    328         return einsum(equation, *_operands)
    329 
--> 330     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    331 
    332 # Wrapper around _histogramdd and _histogramdd_bin_edges needed due to (Tensor, Tensor[]) return type.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
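The failing call is the `torch.einsum("bls, bs -> bls", res, idf_scale)` in `_get_scaled_precision_or_recall`, which combines the max-similarity tensor with the IDF scale. For context, here is a minimal standalone sketch that triggers the same class of error (assuming a CUDA device is available; the shapes and device placement are illustrative only, not torchmetrics' actual internal state):

```python
import torch

# Shapes follow the "bls, bs -> bls" equation from the traceback.
# Assumption: which operand ends up on which device inside torchmetrics
# is not determined here; any mix of cpu and cuda:0 raises the same error.
res = torch.randn(1, 1, 5)                    # (batch, layers, seq) on cpu
idf_scale = torch.randn(1, 5, device="cuda")  # (batch, seq) on cuda:0

# Raises: RuntimeError: Expected all tensors to be on the same device,
# but found at least two devices, cuda:0 and cpu!
torch.einsum("bls, bs -> bls", res, idf_scale)
```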

As a workaround, changing `cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings)` to `cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings).to("cuda")` in `torchmetrics\functional\text\bert.py` seems to work for me.
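A more device-agnostic variant of the same idea would align the two operands explicitly instead of hard-coding `"cuda"`, e.g. by moving `idf_scale` onto the device of `res` right before the failing einsum. A sketch only, following the helper shown in the traceback; the actual fix landed via #912 and may differ:

```python
import torch

def _get_scaled_precision_or_recall_patched(
    cos_sim: torch.Tensor, metric: str, idf_scale: torch.Tensor
) -> torch.Tensor:
    """Sketch of the traceback's helper with explicit device alignment.

    Hypothetical patched variant; the real fix in #912 may differ.
    """
    dim = 3 if metric == "precision" else 2
    res = cos_sim.max(dim=dim).values
    # Move idf_scale onto res's device instead of hard-coding "cuda".
    res = torch.einsum("bls, bs -> bls", res, idf_scale.to(res.device)).sum(-1)
    # Transpose and squeeze to match the original BERTScore output format.
    res = res.transpose(0, 1).squeeze()
    return res
```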

Encountered with:

  • Python: 3.8.12
  • torchmetrics: 0.7.2 & 0.8.0dev
  • OS: Windows 11
@github-actions

Hi! Thanks for your contribution, great first issue!

@SkafteNicki mentioned this issue Mar 24, 2022
@Borda added the bug / fix label Mar 24, 2022
@Borda added this to the v0.7 milestone Mar 24, 2022