Double support for stat score metrics (#1023)
* dtype robust
* changelog
SkafteNicki authored May 12, 2022
1 parent 6d7ee02 commit 6f5ac1e
Showing 3 changed files with 18 additions and 4 deletions.
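Why the change was needed: under PyTorch's default settings, `torch.tensor(0.0)` created without an explicit `dtype` is always `torch.float32`, so the scalar fill values built inside `_reduce_stat_scores` could clash with `torch.double` inputs. The snippet below is a minimal illustration of that hazard in plain PyTorch, not code from this commit:

```python
import torch

x = torch.rand(4, dtype=torch.double)  # user works in double precision
mask = x > 0.5

# Scalar tensors default to float32 when no dtype is given:
fill = torch.tensor(0.0)
assert fill.dtype == torch.float32

# Depending on the PyTorch version, mixing dtypes in torch.where either
# raises a RuntimeError or relies on implicit type promotion. Passing the
# input's dtype explicitly keeps the whole computation in double:
safe_fill = torch.tensor(0.0, dtype=x.dtype, device=x.device)
out = torch.where(mask, safe_fill, x)
assert out.dtype == torch.double
```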
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -58,6 +58,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed non-empty state dict for a few metrics ([#1012](https://github.com/PyTorchLightning/metrics/pull/1012))
 
+- Fixed `torch.double` support in stat score metrics ([#1023](https://github.com/PyTorchLightning/metrics/pull/1023))
+
+
 
 ## [0.8.2] - 2022-05-06
 
7 changes: 7 additions & 0 deletions tests/classification/test_stat_scores.py
@@ -173,10 +173,12 @@ class TestStatScores(MetricTester):
     # DDP tests temporarily disabled due to hanging issues
     @pytest.mark.parametrize("ddp", [False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    @pytest.mark.parametrize("dtype", [torch.float, torch.double])
     def test_stat_scores_class(
         self,
         ddp: bool,
         dist_sync_on_step: bool,
+        dtype: torch.dtype,
         sk_fn: Callable,
         preds: Tensor,
         target: Tensor,
@@ -191,6 +193,11 @@ def test_stat_scores_class(
         if ignore_index is not None and preds.ndim == 2:
             pytest.skip("Skipping ignore_index test with binary inputs.")
 
+        if preds.is_floating_point():
+            preds = preds.to(dtype)
+        if target.is_floating_point():
+            target = target.to(dtype)
+
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
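The `is_floating_point()` guard in the test matters because only floating-point predictions should be cast; integer class labels must keep their integral dtype. A minimal standalone sketch of the same pattern (shapes and names are illustrative, not taken from the test suite):

```python
import torch

preds = torch.rand(10, 3)            # float32 probabilities
target = torch.randint(0, 3, (10,))  # int64 class labels

dtype = torch.double
if preds.is_floating_point():
    preds = preds.to(dtype)          # cast: preds are floating point
if target.is_floating_point():
    target = target.to(dtype)        # skipped: labels stay integral

assert preds.dtype == torch.double and target.dtype == torch.int64
```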
12 changes: 8 additions & 4 deletions torchmetrics/functional/classification/stat_scores.py
@@ -261,17 +261,21 @@ def _reduce_stat_scores(
     else:
         weights = weights.float()
 
-    numerator = torch.where(zero_div_mask, tensor(float(zero_division), device=numerator.device), numerator)
-    denominator = torch.where(zero_div_mask | ignore_mask, tensor(1.0, device=denominator.device), denominator)
-    weights = torch.where(ignore_mask, tensor(0.0, device=weights.device), weights)
+    numerator = torch.where(
+        zero_div_mask, tensor(zero_division, dtype=numerator.dtype, device=numerator.device), numerator
+    )
+    denominator = torch.where(
+        zero_div_mask | ignore_mask, tensor(1.0, dtype=denominator.dtype, device=denominator.device), denominator
+    )
+    weights = torch.where(ignore_mask, tensor(0.0, dtype=weights.dtype, device=weights.device), weights)
 
     if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
         weights = weights / weights.sum(dim=-1, keepdim=True)
 
     scores = weights * (numerator / denominator)
 
     # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
-    scores = torch.where(torch.isnan(scores), tensor(float(zero_division), device=scores.device), scores)
+    scores = torch.where(torch.isnan(scores), tensor(zero_division, dtype=scores.dtype, device=scores.device), scores)
 
     if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
         scores = scores.mean(dim=0)
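As a quick sanity check that the new pattern preserves precision, here is a hedged sketch run outside the library; the variable names mirror `_reduce_stat_scores`, but the values are invented:

```python
import torch
from torch import tensor

zero_division = 0
numerator = torch.rand(3, dtype=torch.double)
denominator = torch.tensor([0.5, 0.0, 2.0], dtype=torch.double)
zero_div_mask = denominator == 0

# The fixed pattern: scalar fills inherit the input tensor's dtype.
numerator = torch.where(
    zero_div_mask, tensor(zero_division, dtype=numerator.dtype, device=numerator.device), numerator
)
denominator = torch.where(
    zero_div_mask, tensor(1.0, dtype=denominator.dtype, device=denominator.device), denominator
)

# With dtype passed through explicitly, the result stays in double precision.
assert (numerator / denominator).dtype == torch.double
```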
