Added differentiability for metrics - 4/n (#253)
* added differentiability for metrics

* fixed typo

* fixed double function call

* fix tests + changelog

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Jun 8, 2021
1 parent 5d6b0ce commit 6f0ef3b
Showing 24 changed files with 336 additions and 95 deletions.
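In short: each metric listed in the changelog entry below gains an `is_differentiable` property, and the shared test helpers gain a `run_differentiability_test` that checks the property against whether gradients actually flow through the metric output. A minimal sketch of the resulting API (a hedged example, not part of the diff; it assumes the torchmetrics package layout at the time of this commit):

from torchmetrics import ConfusionMatrix, Hinge

# The property can be queried without computing anything; the values below
# match the `is_differentiable` implementations added in this commit.
assert ConfusionMatrix(num_classes=2).is_differentiable is False
assert Hinge().is_differentiable is True
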
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249))


- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253))


### Changed

- Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260))
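
As context for the `squared` entry above: setting `squared=False` makes `MeanSquaredError` return the root of the mean squared error. A small hedged example (assuming the torchmetrics API at the time of this commit):

import torch
from torchmetrics import MeanSquaredError

rmse = MeanSquaredError(squared=False)  # sqrt(MSE), i.e. RMSE
preds = torch.tensor([2.5, 0.0])
target = torch.tensor([3.0, -0.5])
print(rmse(preds, target))  # tensor(0.5000) = sqrt((0.5**2 + 0.5**2) / 2)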
18 changes: 16 additions & 2 deletions tests/classification/test_confusion_matrix.py
@@ -149,8 +149,8 @@ def test_confusion_matrix(

    def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel):
        self.run_functional_metric_test(
-            preds,
-            target,
+            preds=preds,
+            target=target,
            metric_functional=confusion_matrix,
            sk_metric=partial(sk_metric, normalize=normalize),
            metric_args={
@@ -161,6 +161,20 @@ def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric,
            }
        )

    def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=ConfusionMatrix,
            metric_functional=confusion_matrix,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
                "normalize": normalize,
                "multilabel": multilabel
            }
        )


def test_warning_on_nan(tmpdir):
    preds = torch.randint(3, size=(20, ))
38 changes: 38 additions & 0 deletions tests/classification/test_f_beta.py
@@ -301,6 +301,44 @@ def test_fbeta_f1_functional(
            },
        )

    def test_fbeta_f1_differentiability(
        self,
        preds: Tensor,
        target: Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

        self.run_differentiability_test(
            preds,
            target,
            metric_functional=metric_fn,
            metric_module=metric_class,
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "multiclass": multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
        )


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
13 changes: 11 additions & 2 deletions tests/classification/test_hamming_distance.py
@@ -77,13 +77,22 @@ def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target):

    def test_hamming_distance_fn(self, preds, target):
        self.run_functional_metric_test(
-            preds,
-            target,
+            preds=preds,
+            target=target,
            metric_functional=hamming_distance,
            sk_metric=_sk_hamming_loss,
            metric_args={"threshold": THRESHOLD},
        )

    def test_hamming_distance_differentiability(self, preds, target):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=HammingDistance,
            metric_functional=hamming_distance,
            metric_args={"threshold": THRESHOLD},
        )


@pytest.mark.parametrize("threshold", [1.5])
def test_wrong_params(threshold):
12 changes: 10 additions & 2 deletions tests/classification/test_hinge.py
@@ -108,12 +108,20 @@ def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multi

    def test_hinge_fn(self, preds, target, squared, multiclass_mode):
        self.run_functional_metric_test(
-            preds,
-            target,
+            preds=preds,
+            target=target,
            metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode),
            sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode),
        )

    def test_hinge_differentiability(self, preds, target, squared, multiclass_mode):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=Hinge,
            metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode)
        )


_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2)))
13 changes: 13 additions & 0 deletions tests/classification/test_iou.py
@@ -133,6 +133,19 @@ def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric,
            }
        )

    def test_confusion_matrix_differentiability(self, reduction, preds, target, sk_metric, num_classes):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=IoU,
            metric_functional=iou,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
                "reduction": reduction
            }
        )


@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
    pytest.param(False, 'none', None, Tensor([1, 1, 1])),
12 changes: 12 additions & 0 deletions tests/classification/test_matthews_corrcoef.py
@@ -127,3 +127,15 @@ def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classe
                "threshold": THRESHOLD,
            }
        )

    def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=MatthewsCorrcoef,
            metric_functional=matthews_corrcoef,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
            }
        )
39 changes: 39 additions & 0 deletions tests/classification/test_precision_recall.py
@@ -302,6 +302,45 @@ def test_precision_recall_fn(
            },
        )

    def test_precision_recall_differentiability(
        self,
        preds: Tensor,
        target: Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        # todo: `sk_wrapper` and `sk_fn` are unused
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=metric_class,
            metric_functional=metric_fn,
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "multiclass": multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
        )


@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
def test_precision_recall_joint(average):
9 changes: 9 additions & 0 deletions tests/classification/test_precision_recall_curve.py
@@ -97,6 +97,15 @@ def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_c
            metric_args={"num_classes": num_classes},
        )

    def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes):
        self.run_differentiability_test(
            preds,
            target,
            metric_module=PrecisionRecallCurve,
            metric_functional=precision_recall_curve,
            metric_args={"num_classes": num_classes},
        )


@pytest.mark.parametrize(
    ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'],
9 changes: 9 additions & 0 deletions tests/classification/test_roc.py
@@ -116,6 +116,15 @@ def test_roc_functional(self, preds, target, sk_metric, num_classes):
            metric_args={"num_classes": num_classes},
        )

    def test_roc_differentiability(self, preds, target, sk_metric, num_classes):
        self.run_differentiability_test(
            preds,
            target,
            metric_module=ROC,
            metric_functional=roc,
            metric_args={"num_classes": num_classes},
        )


@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
    pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
31 changes: 31 additions & 0 deletions tests/classification/test_stat_scores.py
@@ -246,6 +246,37 @@ def test_stat_scores_fn(
            },
        )

    def test_stat_scores_differentiability(
        self,
        sk_fn: Callable,
        preds: Tensor,
        target: Tensor,
        reduce: str,
        mdmc_reduce: Optional[str],
        num_classes: Optional[int],
        multiclass: Optional[bool],
        ignore_index: Optional[int],
        top_k: Optional[int],
    ):
        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        self.run_differentiability_test(
            preds,
            target,
            metric_module=StatScores,
            metric_functional=stat_scores,
            metric_args={
                "num_classes": num_classes,
                "reduce": reduce,
                "mdmc_reduce": mdmc_reduce,
                "threshold": THRESHOLD,
                "multiclass": multiclass,
                "ignore_index": ignore_index,
                "top_k": top_k,
            },
        )


_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
29 changes: 19 additions & 10 deletions tests/helpers/testers.py
@@ -15,7 +15,7 @@
import pickle
import sys
from functools import partial
-from typing import Any, Callable
+from typing import Any, Callable, Sequence

import numpy as np
import pytest
@@ -57,28 +57,39 @@ def setup_ddp(rank, world_size):
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


-def _assert_allclose(pl_result, sk_result, atol: float = 1e-8):
+def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
    """Utility function for recursively asserting that two results are within a certain tolerance """
    # single output compare
    if isinstance(pl_result, Tensor):
        assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
    # multi output compare
-    elif isinstance(pl_result, (tuple, list)):
+    elif isinstance(pl_result, Sequence):
        for pl_res, sk_res in zip(pl_result, sk_result):
            _assert_allclose(pl_res, sk_res, atol=atol)
    else:
        raise ValueError("Unknown format for comparison")


-def _assert_tensor(pl_result):
+def _assert_tensor(pl_result: Any):
    """ Utility function for recursively checking that some input only consists of torch tensors """
-    if isinstance(pl_result, (list, tuple)):
+    if isinstance(pl_result, Sequence):
        for plr in pl_result:
            _assert_tensor(plr)
    else:
        assert isinstance(pl_result, Tensor)


def _assert_requires_grad(metric: Metric, pl_result: Any):
    """ Utility function for recursively asserting that metric output is consistent
    with the `is_differentiable` attribute
    """
    if isinstance(pl_result, Sequence):
        for plr in pl_result:
            _assert_requires_grad(metric, plr)
    else:
        assert metric.is_differentiable == pl_result.requires_grad


def _class_test(
    rank: int,
    worldsize: int,
@@ -472,11 +483,9 @@ def run_differentiability_test(
        if preds.is_floating_point():
            preds.requires_grad = True
            out = metric(preds[0], target[0])
-            # metrics can return list of values
-            if isinstance(out, list):
-                assert all(metric.is_differentiable == o.requires_grad for o in out)
-            else:
-                assert metric.is_differentiable == out.requires_grad
+            # Check if requires_grad matches is_differentiable attribute
+            _assert_requires_grad(metric, out)

            if metric.is_differentiable:
                # check for numerical correctness
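
The new `_assert_requires_grad` helper recurses because some metrics return a tuple or list of tensors rather than a single tensor. A hedged illustration (not part of the diff) using the functional `precision_recall_curve`, whose binary form returns three tensors that the helper would check one by one against `metric.is_differentiable`:

import torch
from torchmetrics.functional import precision_recall_curve

preds = torch.rand(20, requires_grad=True)
target = torch.randint(2, (20, ))
# Three outputs, not one -- this is why the assertion helpers recurse.
precision, recall, thresholds = precision_recall_curve(preds, target)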
4 changes: 4 additions & 0 deletions torchmetrics/classification/confusion_matrix.py
@@ -141,3 +141,7 @@ def compute(self) -> Tensor:
            this will be a `[n_classes, 2, 2]` tensor
        """
        return _confusion_matrix_compute(self.confmat, self.normalize)

    @property
    def is_differentiable(self) -> bool:
        return False
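
A hedged sketch of what `run_differentiability_test` effectively asserts for a non-differentiable metric such as `ConfusionMatrix`: binning predictions into integer counts detaches the output from the autograd graph, so the property and the actual gradient behaviour should both be False (assuming the torchmetrics API at the time of this commit):

import torch
from torchmetrics import ConfusionMatrix

confmat = ConfusionMatrix(num_classes=3)
preds = torch.rand(10, 3, requires_grad=True)  # per-class probabilities
target = torch.randint(3, (10, ))
out = confmat(preds, target)
assert confmat.is_differentiable is False
assert out.requires_grad is False  # property matches actual behaviour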
4 changes: 4 additions & 0 deletions torchmetrics/classification/f_beta.py
@@ -309,3 +309,7 @@ def __init__(
            process_group=process_group,
            dist_sync_fn=dist_sync_fn
        )

    @property
    def is_differentiable(self) -> bool:
        return False
4 changes: 4 additions & 0 deletions torchmetrics/classification/hamming_distance.py
@@ -109,3 +109,7 @@ def compute(self) -> Tensor:
        Computes hamming distance based on inputs passed in to ``update`` previously.
        """
        return _hamming_distance_compute(self.correct, self.total)

    @property
    def is_differentiable(self) -> bool:
        return False
4 changes: 4 additions & 0 deletions torchmetrics/classification/hinge.py
@@ -121,3 +121,7 @@ def update(self, preds: Tensor, target: Tensor):

    def compute(self) -> Tensor:
        return _hinge_compute(self.measure, self.total)

    @property
    def is_differentiable(self) -> bool:
        return True
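
`Hinge` is marked differentiable, so gradients can be taken through its output, which is exactly what the differentiability test exercises. A hedged sketch in binary mode (assuming the torchmetrics API at the time of this commit):

import torch
from torchmetrics import Hinge

hinge = Hinge()
preds = torch.randn(8, requires_grad=True)  # real-valued decision scores
target = torch.randint(2, (8, ))            # binary targets in {0, 1}
loss = hinge(preds, target)
assert hinge.is_differentiable and loss.requires_grad
loss.backward()
assert preds.grad is not None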
4 changes: 4 additions & 0 deletions torchmetrics/classification/iou.py
@@ -106,3 +106,7 @@ def compute(self) -> Tensor:
        Computes intersection over union (IoU)
        """
        return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)

    @property
    def is_differentiable(self) -> bool:
        return False