You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
from torchmetrics.text.bert import BERTScore

bertscore = BERTScore(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cuda")
bertscore = bertscore.to("cuda")
preds = ["the sun is a star"]
target = ["the sun is classified as a star"]
results = bertscore(preds, target)
I get the following RuntimeError:
--------------------------------------------------
RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_29808\645715684.py in <cell line: 1>()
----> 1 results = bertscore(preds, target)
~\Anaconda3\envs\xxx\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\metric.py in forward(self, *args, **kwargs)
252 self.reset()
253 self.update(*args, **kwargs)
--> 254 self._forward_cache = self.compute()
255
256 # restore context
~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\metric.py in wrapped_func(*args, **kwargs)
418 should_unsync=self._should_unsync,
419 ):
--> 420 value = compute(*args, **kwargs)
421 self._computed = _squeeze_if_scalar(value)
422
~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\text\bert.py in compute(self)
231 Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
232 """
--> 233 return bert_score(
234 preds=_get_input_dict(self.preds_input_ids, self.preds_attention_mask),
235 target=_get_input_dict(self.target_input_ids, self.target_attention_mask),
~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\functional\text\bert.py in bert_score(preds, target, model_name_or_path, num_layers, all_layers, model, user_tokenizer, user_forward_fn, verbose, idf, device, max_length, batch_size, num_threads, return_hash, lang, rescale_with_baseline, baseline_path, baseline_url)
646 )
647
--> 648 precision, recall, f1_score = _get_precision_recall_f1(
649 preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale
650 )
~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\functional\text\bert.py in _get_precision_recall_f1(preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale)
373 cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings)
374 # Final metrics shape = (batch_size * num_layers | batch_size)
--> 375 precision = _get_scaled_precision_or_recall(cos_sim, "precision", preds_idf_scale)
376 recall = _get_scaled_precision_or_recall(cos_sim, "recall", target_idf_scale)
377
~\Anaconda3\envs\xxx\lib\site-packages\torchmetrics\functional\text\bert.py in _get_scaled_precision_or_recall(cos_sim, metric, idf_scale)
350 dim = 3 if metric == "precision" else 2
351 res = cos_sim.max(dim=dim).values
--> 352 res = torch.einsum("bls, bs -> bls", res, idf_scale).sum(-1)
353 # We transpose the results and squeeze if possible to match the format of the original BERTScore implementation
354 res = res.transpose(0, 1).squeeze()
~\Anaconda3\envs\xxx\lib\site-packages\torch\functional.py in einsum(*args)
328 return einsum(equation, *_operands)
329
--> 330 return _VF.einsum(equation, operands) # type: ignore[attr-defined]
331
332 # Wrapper around _histogramdd and _histogramdd_bin_edges needed due to (Tensor, Tensor[]) return type.
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
As a workaround, changing cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings) to cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings).to("cuda") in torchmetrics\functional\text\bert.py seems to work for me.
Encountered with:
Python: 3.8.12
torchmetrics: 0.7.2 & 0.8.0dev
OS: Windows 11
The text was updated successfully, but these errors were encountered:
Hi,
when I run the following (exemplary) code,
I get the following RuntimeError:
As a workaround, changing
cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings)
to cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings).to("cuda")
in torchmetrics\functional\text\bert.py
seems to work for me.

Encountered with:
The text was updated successfully, but these errors were encountered: