Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/metriccollection_update_count
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored May 27, 2022
2 parents 962e8f4 + 8263e27 commit ebd7079
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 4 deletions.
13 changes: 13 additions & 0 deletions tests/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torchmetrics import Metric


class MetricWrapper(Metric):
def __init__(self, metric):
super().__init__()
self.metric = metric

def update(self, *args, **kwargs):
self.metric.update(*args, **kwargs)

def compute(self, *args, **kwargs):
return self.metric.compute(*args, **kwargs)
10 changes: 10 additions & 0 deletions tests/classification/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,13 @@ def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_S
_temp[_temp == _class_remove] = _class_replace

_input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone())


_negmetric_noneavg = {
"pred1": torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]),
"target1": torch.tensor([0, 1]),
"res1": torch.tensor([0.0, 0.0, float("nan")]),
"pred2": torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]),
"target2": torch.tensor([0, 2]),
"res2": torch.tensor([0.0, 0.0, 0.0]),
}
10 changes: 10 additions & 0 deletions tests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sklearn.metrics import accuracy_score as sk_accuracy
from torch import tensor

from tests.classification import MetricWrapper
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
Expand All @@ -32,6 +33,7 @@
from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.classification.inputs import _negmetric_noneavg
from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Accuracy
Expand Down Expand Up @@ -438,3 +440,11 @@ def test_negative_ignore_index(preds, target, ignore_index, result):
# Test functional
with pytest.raises(ValueError, match="^[The `target` has to be a non-negative tensor.]"):
acc_score = accuracy(preds, target, num_classes=num_classes, ignore_index=ignore_index)


def test_negmetric_noneavg(noneavg=_negmetric_noneavg):
acc = MetricWrapper(Accuracy(average="none", num_classes=noneavg["pred1"].shape[1]))
result1 = acc(noneavg["pred1"], noneavg["target1"])
assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
result2 = acc(noneavg["pred2"], noneavg["target2"])
assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
11 changes: 11 additions & 0 deletions tests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sklearn.metrics import precision_score, recall_score
from torch import Tensor, tensor

from tests.classification import MetricWrapper
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
Expand All @@ -30,6 +31,7 @@
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.classification.inputs import _negmetric_noneavg
from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Metric, Precision, Recall
Expand Down Expand Up @@ -457,3 +459,12 @@ def test_same_input(metric_class, metric_functional, sk_fn, average):

assert torch.allclose(class_res, torch.tensor(sk_res).float())
assert torch.allclose(func_res, torch.tensor(sk_res).float())


@pytest.mark.parametrize("metric_cls", [Precision, Recall])
def test_noneavg(metric_cls, noneavg=_negmetric_noneavg):
prec = MetricWrapper(metric_cls(average="none", num_classes=noneavg["pred1"].shape[1]))
result1 = prec(noneavg["pred1"], noneavg["target1"])
assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
result2 = prec(noneavg["pred2"], noneavg["target2"])
assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _accuracy_compute(
numerator = tp + tn
denominator = tp + tn + fp + fn
else:
numerator = tp
numerator = tp.clone()
denominator = tp + fn

if mdmc_average != MDMCAverageMethod.SAMPLEWISE:
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/functional/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _precision_compute(
tensor(0.2500)
"""

numerator = tp
numerator = tp.clone()
denominator = tp + fp

if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
Expand Down Expand Up @@ -241,7 +241,7 @@ def _recall_compute(
>>> _recall_compute(tp, fp, fn, average='micro', mdmc_average=None)
tensor(0.2500)
"""
numerator = tp
numerator = tp.clone()
denominator = tp + fn

if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/specificity.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _specificity_compute(
tensor(0.6250)
"""

numerator = tn
numerator = tn.clone()
denominator = tn + fp
if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
# a class is not present if there exists no TPs, no FPs, and no FNs
Expand Down

0 comments on commit ebd7079

Please sign in to comment.