Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added differentiability for metrics - 4/n #253

Merged
merged 12 commits into from
Jun 8, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249))


- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253))


### Changed

- Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260))
Expand Down
18 changes: 16 additions & 2 deletions tests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def test_confusion_matrix(

def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel):
self.run_functional_metric_test(
preds,
target,
preds=preds,
target=target,
metric_functional=confusion_matrix,
sk_metric=partial(sk_metric, normalize=normalize),
metric_args={
Expand All @@ -161,6 +161,20 @@ def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric,
}
)

def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=ConfusionMatrix,
metric_functional=confusion_matrix,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize,
"multilabel": multilabel
}
)


def test_warning_on_nan(tmpdir):
preds = torch.randint(3, size=(20, ))
Expand Down
38 changes: 38 additions & 0 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,44 @@ def test_fbeta_f1_functional(
},
)

def test_fbeta_f1_differentiability(
self,
preds: Tensor,
target: Tensor,
sk_wrapper: Callable,
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
ignore_index: Optional[int],
):
if num_classes == 1 and average != "micro":
pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

if ignore_index is not None and preds.ndim == 2:
pytest.skip("Skipping ignore_index test with binary inputs.")

if average == "weighted" and ignore_index is not None and mdmc_average is not None:
pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

self.run_differentiability_test(
preds,
target,
metric_functional=metric_fn,
metric_module=metric_class,
metric_args={
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
)


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
Expand Down
13 changes: 11 additions & 2 deletions tests/classification/test_hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,22 @@ def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target):

def test_hamming_distance_fn(self, preds, target):
self.run_functional_metric_test(
preds,
target,
preds=preds,
target=target,
Borda marked this conversation as resolved.
Show resolved Hide resolved
metric_functional=hamming_distance,
sk_metric=_sk_hamming_loss,
metric_args={"threshold": THRESHOLD},
)

def test_hamming_distance_differentiability(self, preds, target):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=HammingDistance,
metric_functional=hamming_distance,
metric_args={"threshold": THRESHOLD},
)


@pytest.mark.parametrize("threshold", [1.5])
def test_wrong_params(threshold):
Expand Down
12 changes: 10 additions & 2 deletions tests/classification/test_hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,20 @@ def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multi

def test_hinge_fn(self, preds, target, squared, multiclass_mode):
self.run_functional_metric_test(
preds,
target,
preds=preds,
target=target,
metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode),
sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode),
)

def test_hinge_differentiability(self, preds, target, squared, multiclass_mode):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=Hinge,
metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode)
)


_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2)))

Expand Down
13 changes: 13 additions & 0 deletions tests/classification/test_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric,
}
)

def test_confusion_matrix_differentiability(self, reduction, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=IoU,
metric_functional=iou,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction
}
)


@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, Tensor([1, 1, 1])),
Expand Down
12 changes: 12 additions & 0 deletions tests/classification/test_matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,15 @@ def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classe
"threshold": THRESHOLD,
}
)

def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=MatthewsCorrcoef,
metric_functional=matthews_corrcoef,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
}
)
39 changes: 39 additions & 0 deletions tests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,45 @@ def test_precision_recall_fn(
},
)

def test_precision_recall_differentiability(
self,
preds: Tensor,
target: Tensor,
sk_wrapper: Callable,
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
ignore_index: Optional[int],
):
# todo: `metric_class` is unused
if num_classes == 1 and average != "micro":
pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

if ignore_index is not None and preds.ndim == 2:
pytest.skip("Skipping ignore_index test with binary inputs.")

if average == "weighted" and ignore_index is not None and mdmc_average is not None:
pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

self.run_differentiability_test(
preds=preds,
target=target,
metric_module=metric_class,
metric_functional=metric_fn,
metric_args={
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
)


@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
def test_precision_recall_joint(average):
Expand Down
9 changes: 9 additions & 0 deletions tests/classification/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_c
metric_args={"num_classes": num_classes},
)

def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds,
target,
metric_module=PrecisionRecallCurve,
metric_functional=precision_recall_curve,
metric_args={"num_classes": num_classes},
)


@pytest.mark.parametrize(
['pred', 'target', 'expected_p', 'expected_r', 'expected_t'],
Expand Down
9 changes: 9 additions & 0 deletions tests/classification/test_roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def test_roc_functional(self, preds, target, sk_metric, num_classes):
metric_args={"num_classes": num_classes},
)

def test_roc_differentiability(self, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds,
target,
metric_module=ROC,
metric_functional=roc,
metric_args={"num_classes": num_classes},
)


@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
Expand Down
31 changes: 31 additions & 0 deletions tests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,37 @@ def test_stat_scores_fn(
},
)

def test_stat_scores_differentiability(
self,
sk_fn: Callable,
preds: Tensor,
target: Tensor,
reduce: str,
mdmc_reduce: Optional[str],
num_classes: Optional[int],
multiclass: Optional[bool],
ignore_index: Optional[int],
top_k: Optional[int],
):
if ignore_index is not None and preds.ndim == 2:
pytest.skip("Skipping ignore_index test with binary inputs.")

self.run_differentiability_test(
preds,
target,
metric_module=StatScores,
metric_functional=stat_scores,
metric_args={
"num_classes": num_classes,
"reduce": reduce,
"mdmc_reduce": mdmc_reduce,
"threshold": THRESHOLD,
"multiclass": multiclass,
"ignore_index": ignore_index,
"top_k": top_k,
},
)


_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
Expand Down
29 changes: 19 additions & 10 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pickle
import sys
from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Sequence

import numpy as np
import pytest
Expand Down Expand Up @@ -57,28 +57,39 @@ def setup_ddp(rank, world_size):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _assert_allclose(pl_result, sk_result, atol: float = 1e-8):
def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
"""Utility function for recursively asserting that two results are within a certain tolerance """
# single output compare
if isinstance(pl_result, Tensor):
assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
# multi output compare
elif isinstance(pl_result, (tuple, list)):
elif isinstance(pl_result, Sequence):
for pl_res, sk_res in zip(pl_result, sk_result):
_assert_allclose(pl_res, sk_res, atol=atol)
else:
raise ValueError("Unknown format for comparison")


def _assert_tensor(pl_result):
def _assert_tensor(pl_result: Any):
""" Utility function for recursively checking that some input only consists of torch tensors """
if isinstance(pl_result, (list, tuple)):
if isinstance(pl_result, Sequence):
for plr in pl_result:
_assert_tensor(plr)
else:
assert isinstance(pl_result, Tensor)


def _assert_requires_grad(metric: Metric, pl_result: Any):
""" Utility function for recursively asserting that metric output is consistent
with the `is_differentiable` attribute
"""
if isinstance(pl_result, Sequence):
for plr in pl_result:
_assert_requires_grad(metric, plr)
else:
assert metric.is_differentiable == pl_result.requires_grad


def _class_test(
rank: int,
worldsize: int,
Expand Down Expand Up @@ -472,11 +483,9 @@ def run_differentiability_test(
if preds.is_floating_point():
preds.requires_grad = True
out = metric(preds[0], target[0])
# metrics can return list of values
if isinstance(out, list):
assert all(metric.is_differentiable == o.requires_grad for o in out)
else:
assert metric.is_differentiable == out.requires_grad

# Check if requires_grad matches is_differentiable attribute
_assert_requires_grad(metric, out)

if metric.is_differentiable:
# check for numerical correctness
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,7 @@ def compute(self) -> Tensor:
this will be a `[n_classes, 2, 2]` tensor
"""
return _confusion_matrix_compute(self.confmat, self.normalize)

@property
def is_differentiable(self):
return False
4 changes: 4 additions & 0 deletions torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,7 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn
)

@property
def is_differentiable(self):
return False
4 changes: 4 additions & 0 deletions torchmetrics/classification/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,7 @@ def compute(self) -> Tensor:
Computes hamming distance based on inputs passed in to ``update`` previously.
"""
return _hamming_distance_compute(self.correct, self.total)

@property
def is_differentiable(self):
return False
4 changes: 4 additions & 0 deletions torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,7 @@ def update(self, preds: Tensor, target: Tensor):

def compute(self) -> Tensor:
return _hinge_compute(self.measure, self.total)

@property
def is_differentiable(self):
return True
4 changes: 4 additions & 0 deletions torchmetrics/classification/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,7 @@ def compute(self) -> Tensor:
Computes intersection over union (IoU)
"""
return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)

@property
def is_differentiable(self):
return False
Loading