From c494f3575f542e7c264555a71fb97e8676f64700 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 26 Apr 2022 15:45:49 +0200 Subject: [PATCH] Fix the use of logits in calibration error (#985) * fix * Update CHANGELOG.md * fix logical operator Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 +- tests/classification/test_calibration_error.py | 12 ++++++++++-- .../functional/classification/calibration_error.py | 4 ++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d4b52f6383..b5137f2b3cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,7 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `BinnedPrecisionRecallCurve` when `thresholds` argument is not provided ([#968](https://github.com/PyTorchLightning/metrics/pull/968)) -- +- Fixed `CalibrationError` to work on logit input ([#985](https://github.com/PyTorchLightning/metrics/pull/985)) ## [0.8.0] - 2022-04-14 diff --git a/tests/classification/test_calibration_error.py b/tests/classification/test_calibration_error.py index 3a20bd3616e..2d419f6a3f2 100644 --- a/tests/classification/test_calibration_error.py +++ b/tests/classification/test_calibration_error.py @@ -3,8 +3,10 @@ import numpy as np import pytest +from scipy.special import softmax as _softmax -from tests.classification.inputs import _input_binary_prob +from tests.classification.inputs import _input_binary_logits, _input_binary_prob +from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob @@ -24,8 +26,12 @@ def _sk_calibration(preds, target, n_bins, norm, debias=False): _, _, mode = _input_format_classification(preds, target, threshold=THRESHOLD) sk_preds, sk_target = preds.numpy(), target.numpy() - + if mode == DataType.BINARY: + if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all(): + sk_preds = 1.0 / (1 + np.exp(-sk_preds)) # sigmoid transform if mode == DataType.MULTICLASS: + if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all(): + sk_preds = _softmax(sk_preds, axis=1) # binary label is whether or not the predicted class is correct sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target) sk_preds = np.max(sk_preds, axis=1) @@ -46,7 +52,9 @@ def _sk_calibration(preds, target, n_bins, norm, debias=False): "preds, target", [ (_input_binary_prob.preds, _input_binary_prob.target), + (_input_binary_logits.preds, _input_binary_logits.target), (_input_mcls_prob.preds, _input_mcls_prob.target), + (_input_mcls_logits.preds, _input_mcls_logits.target), (_input_mdmc_prob.preds, _input_mdmc_prob.target), ], ) diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 25dd6d64a79..5f08cf73400 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -143,8 +143,12 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: _, _, mode = _input_format_classification(preds, target) if mode == DataType.BINARY: + if not ((0 <= preds) * (preds <= 1)).all(): + preds = preds.sigmoid() confidences, accuracies = preds, target elif mode == DataType.MULTICLASS: + if not ((0 <= preds) * (preds <= 1)).all(): + preds = preds.softmax(dim=1) confidences, predictions = preds.max(dim=1) accuracies = predictions.eq(target) elif mode == DataType.MULTIDIM_MULTICLASS: