From 6f0ef3b7ba5dae9aa82573e36b25c43b36179a84 Mon Sep 17 00:00:00 2001 From: Bhadresh Savani Date: Wed, 9 Jun 2021 03:53:07 +0530 Subject: [PATCH] Added differentiability for metrics - 4/n (#253) * added differentiability for metrics * fixed typo * fixed double function call * fix tests + changelog Co-authored-by: Nicki Skafte Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 + tests/classification/test_confusion_matrix.py | 18 +++- tests/classification/test_f_beta.py | 38 +++++++++ tests/classification/test_hamming_distance.py | 13 ++- tests/classification/test_hinge.py | 12 ++- tests/classification/test_iou.py | 13 +++ .../classification/test_matthews_corrcoef.py | 12 +++ tests/classification/test_precision_recall.py | 39 +++++++++ .../test_precision_recall_curve.py | 9 ++ tests/classification/test_roc.py | 9 ++ tests/classification/test_stat_scores.py | 31 +++++++ tests/helpers/testers.py | 29 ++++--- .../classification/confusion_matrix.py | 4 + torchmetrics/classification/f_beta.py | 4 + .../classification/hamming_distance.py | 4 + torchmetrics/classification/hinge.py | 4 + torchmetrics/classification/iou.py | 4 + .../classification/matthews_corrcoef.py | 4 + .../classification/precision_recall.py | 8 ++ .../classification/precision_recall_curve.py | 4 + torchmetrics/classification/roc.py | 4 + torchmetrics/classification/stat_scores.py | 4 + .../classification/precision_recall_curve.py | 77 +++++++++-------- torchmetrics/functional/classification/roc.py | 84 +++++++++---------- 24 files changed, 336 insertions(+), 95 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1568c58abcb..38c8b90397e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249)) +- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253)) + + ### Changed - Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260)) diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 44f599eedcd..4dc28652158 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -149,8 +149,8 @@ def test_confusion_matrix( def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): self.run_functional_metric_test( - preds, - target, + preds=preds, + target=target, metric_functional=confusion_matrix, sk_metric=partial(sk_metric, normalize=normalize), metric_args={ @@ -161,6 +161,20 @@ def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, } ) + def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=ConfusionMatrix, + metric_functional=confusion_matrix, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + "normalize": normalize, + "multilabel": multilabel + } + ) + def test_warning_on_nan(tmpdir): preds = torch.randint(3, size=(20, )) diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 693c445dd4b..83b99074fb4 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -301,6 +301,44 @@ def test_fbeta_f1_functional( }, ) + def test_fbeta_f1_differentiability( + self, + preds: Tensor, + target: Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + sk_fn: Callable, + multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ): + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_differentiability_test( + preds, + target, + metric_functional=metric_fn, + metric_module=metric_class, + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "multiclass": multiclass, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + ) + _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) diff --git a/tests/classification/test_hamming_distance.py b/tests/classification/test_hamming_distance.py index 8f629e7d6f3..c2be69eb554 100644 --- a/tests/classification/test_hamming_distance.py +++ b/tests/classification/test_hamming_distance.py @@ -77,13 +77,22 @@ def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): def test_hamming_distance_fn(self, preds, target): self.run_functional_metric_test( - preds, - target, + preds=preds, + target=target, metric_functional=hamming_distance, sk_metric=_sk_hamming_loss, metric_args={"threshold": THRESHOLD}, ) + def test_hamming_distance_differentiability(self, preds, target): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=HammingDistance, + metric_functional=hamming_distance, + metric_args={"threshold": THRESHOLD}, + ) + @pytest.mark.parametrize("threshold", [1.5]) def test_wrong_params(threshold): diff --git a/tests/classification/test_hinge.py b/tests/classification/test_hinge.py index 568948df37e..d4b9af016e3 100644 --- a/tests/classification/test_hinge.py +++ b/tests/classification/test_hinge.py @@ -108,12 +108,20 @@ def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multi def test_hinge_fn(self, preds, target, squared, multiclass_mode): self.run_functional_metric_test( - preds, - target, + preds=preds, + target=target, metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode), sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), ) + def test_hinge_differentiability(self, preds, target, squared, multiclass_mode): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=Hinge, + metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode) + ) + _input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) diff --git a/tests/classification/test_iou.py b/tests/classification/test_iou.py index 2e21ab1b1a9..27b3eaabd3d 100644 --- a/tests/classification/test_iou.py +++ b/tests/classification/test_iou.py @@ -133,6 +133,19 @@ def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, } ) + def test_confusion_matrix_differentiability(self, reduction, preds, target, sk_metric, num_classes): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=IoU, + metric_functional=iou, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + "reduction": reduction + } + ) + @pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ pytest.param(False, 'none', None, Tensor([1, 1, 1])), diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py index 19b7aa93624..559b1cda437 100644 --- a/tests/classification/test_matthews_corrcoef.py +++ b/tests/classification/test_matthews_corrcoef.py @@ -127,3 +127,15 @@ def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classe "threshold": THRESHOLD, } ) + + def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MatthewsCorrcoef, + metric_functional=matthews_corrcoef, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + } + ) diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index d627ba67b34..72be11af5b1 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -302,6 +302,45 @@ def test_precision_recall_fn( }, ) + def test_precision_recall_differentiability( + self, + preds: Tensor, + target: Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + sk_fn: Callable, + multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ): + # todo: `metric_class` is unused + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=metric_class, + metric_functional=metric_fn, + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "multiclass": multiclass, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + ) + @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) def test_precision_recall_joint(average): diff --git a/tests/classification/test_precision_recall_curve.py b/tests/classification/test_precision_recall_curve.py index 5fd4a50bdcd..f08f66d71b1 100644 --- a/tests/classification/test_precision_recall_curve.py +++ b/tests/classification/test_precision_recall_curve.py @@ -97,6 +97,15 @@ def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_c metric_args={"num_classes": num_classes}, ) + def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes): + self.run_differentiability_test( + preds, + target, + metric_module=PrecisionRecallCurve, + metric_functional=precision_recall_curve, + metric_args={"num_classes": num_classes}, + ) + @pytest.mark.parametrize( ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index 7d988e4f1fa..e1dfce0a03f 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -116,6 +116,15 @@ def test_roc_functional(self, preds, target, sk_metric, num_classes): metric_args={"num_classes": num_classes}, ) + def test_roc_differentiability(self, preds, target, sk_metric, num_classes): + self.run_differentiability_test( + preds, + target, + metric_module=ROC, + metric_functional=roc, + metric_args={"num_classes": num_classes}, + ) + @pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 30994a6d499..b85531a8c43 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -246,6 +246,37 @@ def test_stat_scores_fn( }, ) + def test_stat_scores_differentiability( + self, + sk_fn: Callable, + preds: Tensor, + target: Tensor, + reduce: str, + mdmc_reduce: Optional[str], + num_classes: Optional[int], + multiclass: Optional[bool], + ignore_index: Optional[int], + top_k: Optional[int], + ): + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + self.run_differentiability_test( + preds, + target, + metric_module=StatScores, + metric_functional=stat_scores, + metric_args={ + "num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "multiclass": multiclass, + "ignore_index": ignore_index, + "top_k": top_k, + }, + ) + _mc_k_target = tensor([0, 1, 2]) _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 7e911c8220b..a448f53371c 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -15,7 +15,7 @@ import pickle import sys from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Sequence import numpy as np import pytest @@ -57,28 +57,39 @@ def setup_ddp(rank, world_size): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) -def _assert_allclose(pl_result, sk_result, atol: float = 1e-8): +def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8): """Utility function for recursively asserting that two results are within a certain tolerance """ # single output compare if isinstance(pl_result, Tensor): assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True) # multi output compare - elif isinstance(pl_result, (tuple, list)): + elif isinstance(pl_result, Sequence): for pl_res, sk_res in zip(pl_result, sk_result): _assert_allclose(pl_res, sk_res, atol=atol) else: raise ValueError("Unknown format for comparison") -def _assert_tensor(pl_result): +def _assert_tensor(pl_result: Any): """ Utility function for recursively checking that some input only consists of torch tensors """ - if isinstance(pl_result, (list, tuple)): + if isinstance(pl_result, Sequence): for plr in pl_result: _assert_tensor(plr) else: assert isinstance(pl_result, Tensor) +def _assert_requires_grad(metric: Metric, pl_result: Any): + """ Utility function for recursively asserting that metric output is consistent + with the `is_differentiable` attribute + """ + if isinstance(pl_result, Sequence): + for plr in pl_result: + _assert_requires_grad(metric, plr) + else: + assert metric.is_differentiable == pl_result.requires_grad + + def _class_test( rank: int, worldsize: int, @@ -472,11 +483,9 @@ def run_differentiability_test( if preds.is_floating_point(): preds.requires_grad = True out = metric(preds[0], target[0]) - # metrics can return list of values - if isinstance(out, list): - assert all(metric.is_differentiable == o.requires_grad for o in out) - else: - assert metric.is_differentiable == out.requires_grad + + # Check if requires_grad matches is_differentiable attribute + _assert_requires_grad(metric, out) if metric.is_differentiable: # check for numerical correctness diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 5dc2f03abdb..b108aa56074 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -141,3 +141,7 @@ def compute(self) -> Tensor: this will be a `[n_classes, 2, 2]` tensor """ return _confusion_matrix_compute(self.confmat, self.normalize) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index c07b110d5f0..8dcaa0f7e89 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -309,3 +309,7 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn ) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/hamming_distance.py b/torchmetrics/classification/hamming_distance.py index 90fc8302edb..b3dd370b49e 100644 --- a/torchmetrics/classification/hamming_distance.py +++ b/torchmetrics/classification/hamming_distance.py @@ -109,3 +109,7 @@ def compute(self) -> Tensor: Computes hamming distance based on inputs passed in to ``update`` previously. """ return _hamming_distance_compute(self.correct, self.total) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/hinge.py b/torchmetrics/classification/hinge.py index bfb2a5ea5b7..c03fa8b8c44 100644 --- a/torchmetrics/classification/hinge.py +++ b/torchmetrics/classification/hinge.py @@ -121,3 +121,7 @@ def update(self, preds: Tensor, target: Tensor): def compute(self) -> Tensor: return _hinge_compute(self.measure, self.total) + + @property + def is_differentiable(self) -> bool: + return True diff --git a/torchmetrics/classification/iou.py b/torchmetrics/classification/iou.py index 326b43d24ad..55253e3854e 100644 --- a/torchmetrics/classification/iou.py +++ b/torchmetrics/classification/iou.py @@ -106,3 +106,7 @@ def compute(self) -> Tensor: Computes intersection over union (IoU) """ return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index 91978e92706..1272718ab19 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -112,3 +112,7 @@ def compute(self) -> Tensor: Computes matthews correlation coefficient """ return _matthews_corrcoef_compute(self.confmat) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index a8a43e25393..bda82b2df2a 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -181,6 +181,10 @@ def compute(self) -> Tensor: tp, fp, tn, fn = self._get_final_stats() return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) + @property + def is_differentiable(self) -> bool: + return False + class Recall(StatScores): r""" @@ -341,3 +345,7 @@ def compute(self) -> Tensor: """ tp, fp, tn, fn = self._get_final_stats() return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index 06b7f1c2e35..98b0a204107 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -146,3 +146,7 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index 524231104d5..091d36ae5a2 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -166,3 +166,7 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li preds = torch.cat(self.preds, dim=0) target = torch.cat(self.target, dim=0) return _roc_compute(preds, target, self.num_classes, self.pos_label) + + @property + def is_differentiable(self) -> bool: + return False diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 1315c76d814..43eccaa9b7c 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -278,6 +278,10 @@ def compute(self) -> Tensor: tp, fp, tn, fn = self._get_final_stats() return _stat_scores_compute(tp, fp, tn, fn) + @property + def is_differentiable(self) -> bool: + return False + def _reduce_stat_scores( numerator: Tensor, diff --git a/torchmetrics/functional/classification/precision_recall_curve.py b/torchmetrics/functional/classification/precision_recall_curve.py index f5a2718dc93..efec99d1332 100644 --- a/torchmetrics/functional/classification/precision_recall_curve.py +++ b/torchmetrics/functional/classification/precision_recall_curve.py @@ -118,47 +118,50 @@ def _precision_recall_curve_compute( pos_label: int, sample_weights: Optional[Sequence] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + with torch.no_grad(): + if num_classes == 1: + fps, tps, thresholds = _binary_clf_curve( + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label + ) - if num_classes == 1: - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - - precision = tps / (tps + fps) - recall = tps / tps[-1] - - # stop when full recall attained - # and reverse the outputs so recall is decreasing - last_ind = torch.where(tps == tps[-1])[0][0] - sl = slice(0, last_ind.item() + 1) - - # need to call reversed explicitly, since including that to slice would - # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) - - recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) - - thresholds = reversed(thresholds[sl]).clone() + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained + # and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([ + reversed(precision[sl]), + torch.ones(1, dtype=precision.dtype, device=precision.device) + ]) + + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + + thresholds = reversed(thresholds[sl]).clone() + + return precision, recall, thresholds + + # Recursively call per class + precision, recall, thresholds = [], [], [] + for c in range(num_classes): + preds_c = preds[:, c] + res = precision_recall_curve( + preds=preds_c, + target=target, + num_classes=1, + pos_label=c, + sample_weights=sample_weights, + ) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) return precision, recall, thresholds - # Recursively call per class - precision, recall, thresholds = [], [], [] - for c in range(num_classes): - preds_c = preds[:, c] - res = precision_recall_curve( - preds=preds_c, - target=target, - num_classes=1, - pos_label=c, - sample_weights=sample_weights, - ) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - - return precision, recall, thresholds - def precision_recall_curve( preds: Tensor, diff --git a/torchmetrics/functional/classification/roc.py b/torchmetrics/functional/classification/roc.py index 4e9509c13b9..53783b590e2 100644 --- a/torchmetrics/functional/classification/roc.py +++ b/torchmetrics/functional/classification/roc.py @@ -39,51 +39,51 @@ def _roc_compute( pos_label: int, sample_weights: Optional[Sequence] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - - if num_classes == 1 and preds.ndim == 1: # binary - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - # Add an extra threshold position - # to make sure that the curve starts at (0, 0) - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) - - if fps[-1] <= 0: - raise ValueError("No negative samples in targets, false positive value should be meaningless") - fpr = fps / fps[-1] - - if tps[-1] <= 0: - raise ValueError("No positive samples in targets, true positive value should be meaningless") - tpr = tps / tps[-1] + with torch.no_grad(): + if num_classes == 1 and preds.ndim == 1: # binary + fps, tps, thresholds = _binary_clf_curve( + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label + ) + # Add an extra threshold position + # to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + raise ValueError("No negative samples in targets, false positive value should be meaningless") + fpr = fps / fps[-1] + + if tps[-1] <= 0: + raise ValueError("No positive samples in targets, true positive value should be meaningless") + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + # Recursively call per class + fpr, tpr, thresholds = [], [], [] + for c in range(num_classes): + if preds.shape == target.shape: + preds_c = preds[:, c] + target_c = target[:, c] + pos_label = 1 + else: + preds_c = preds[:, c] + target_c = target + pos_label = c + res = roc( + preds=preds_c, + target=target_c, + num_classes=1, + pos_label=pos_label, + sample_weights=sample_weights, + ) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) return fpr, tpr, thresholds - # Recursively call per class - fpr, tpr, thresholds = [], [], [] - for c in range(num_classes): - if preds.shape == target.shape: - preds_c = preds[:, c] - target_c = target[:, c] - pos_label = 1 - else: - preds_c = preds[:, c] - target_c = target - pos_label = c - res = roc( - preds=preds_c, - target=target_c, - num_classes=1, - pos_label=pos_label, - sample_weights=sample_weights, - ) - fpr.append(res[0]) - tpr.append(res[1]) - thresholds.append(res[2]) - - return fpr, tpr, thresholds - def roc( preds: Tensor,