Skip to content

Commit

Permalink
Merge branch 'master' into critical_success_index
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Dec 21, 2023
2 parents 898d1bc + 357911a commit b5366ff
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch import Tensor

from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13, _XLA_AVAILABLE
from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6
Expand Down Expand Up @@ -115,7 +115,9 @@ def to_onehot(
def _top_k_with_half_precision_support(x: Tensor, k: int = 1, dim: int = 1) -> Tensor:
"""torch.top_k does not support half precision on CPU."""
if x.dtype == torch.half and not x.is_cuda:
idx = torch.argsort(x, dim=dim, descending=True)
if not _TORCH_GREATER_EQUAL_1_13:
raise RuntimeError("Half precision (torch.float16) is not supported on CPU for PyTorch < 1.13.")
idx = torch.argsort(x, dim=dim, stable=True).flip(dim)
return idx.narrow(dim, 0, k)
return x.topk(k=k, dim=dim).indices

Expand All @@ -139,11 +141,11 @@ def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
[1, 1, 0]], dtype=torch.int32)
"""
zeros = torch.zeros_like(prob_tensor)
topk_tensor = torch.zeros_like(prob_tensor, dtype=torch.int)
if topk == 1: # argmax has better performance than topk
topk_tensor = zeros.scatter(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0)
topk_tensor.scatter_(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0)
else:
topk_tensor = zeros.scatter(dim, _top_k_with_half_precision_support(prob_tensor, k=topk, dim=dim), 1.0)
topk_tensor.scatter_(dim, _top_k_with_half_precision_support(prob_tensor, k=topk, dim=dim), 1.0)
return topk_tensor.int()


Expand Down
7 changes: 5 additions & 2 deletions tests/unittests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,10 @@ def _reference_topk(x, dim, k):
one_hot = np.zeros((x.shape[0], x.shape[1]), dtype=int)
if dim == 1:
for i in range(x.shape[0]):
one_hot[i, np.argsort(x[i, :])[::-1][:k]] = 1
one_hot[i, np.argsort(x[i, :], kind="stable")[::-1][:k]] = 1
return one_hot
for i in range(x.shape[1]):
one_hot[np.argsort(x[:, i])[::-1][:k], i] = 1
one_hot[np.argsort(x[:, i], kind="stable")[::-1][:k], i] = 1
return one_hot


Expand All @@ -227,6 +228,8 @@ def _reference_topk(x, dim, k):
@pytest.mark.parametrize("dim", [0, 1])
def test_custom_topk(dtype, k, dim):
"""Test custom topk implementation."""
if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13:
pytest.skip("half precision topk not supported in Pytorch < 1.13")
x = torch.randn(100, 10, dtype=dtype)
top_k = select_topk(x, dim=dim, topk=k)
assert top_k.shape == (100, 10)
Expand Down

0 comments on commit b5366ff

Please sign in to comment.