multilabel for AveragePrecision #386

Merged
merged 12 commits on Jul 24, 2021
24 changes: 23 additions & 1 deletion tests/classification/test_average_precision.py
@@ -15,14 +15,15 @@

 import numpy as np
 import pytest
+import torch
 from sklearn.metrics import average_precision_score as sk_average_precision_score
 from torch import tensor
 
 from tests.classification.inputs import _input_binary_prob
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, MetricTester
+from tests.helpers.testers import BATCH_SIZE, NUM_CLASSES, MetricTester
 from torchmetrics.classification.average_precision import AveragePrecision
 from torchmetrics.functional import average_precision

@@ -118,3 +119,24 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
 )
 def test_average_precision(scores, target, expected_score):
     assert average_precision(scores, target) == expected_score
+
+
+@pytest.mark.parametrize(
+    "preds, target, sk_metric, num_classes", [
+        (
+            torch.randn(BATCH_SIZE, NUM_CLASSES), torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES)),
+            _sk_avg_prec_multiclass_prob, NUM_CLASSES
+        ),
+    ]
+)
+def test_average_precision_multilabel(preds, target, sk_metric, num_classes):
+    AP = AveragePrecision(
+        num_classes=num_classes,
+        pos_label=1,
+        compute_on_step=False,
+    )
+    preds = torch.sigmoid(preds)
+    AP.update(preds, target)
+    result = torch.tensor(AP.compute())
+    expected = sk_average_precision_score(target.numpy(), preds.numpy(), average=None)
+    assert np.allclose(result.numpy(), expected)
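For reference, the functional counterpart `average_precision` (imported at the top of the test file) benefits from the same change. A minimal usage sketch, with shapes mirroring the test above; it assumes the functional API takes the same `num_classes`/`pos_label` arguments as the module and returns one score per label, which is how it is called elsewhere in this codebase:

```python
import torch
from torchmetrics.functional import average_precision

# Multilabel input: per-label probabilities of shape (N, C) and
# binary targets of the same shape.
preds = torch.sigmoid(torch.randn(10, 5))
target = torch.randint(0, 2, (10, 5))

# Each of the C columns is treated as its own binary problem with pos_label=1,
# so the result is one average-precision value per label.
per_label_ap = average_precision(preds, target, num_classes=5, pos_label=1)
print(per_label_ap)  # a list of 5 scalar tensors, one per label
```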
@@ -148,13 +148,20 @@ def _precision_recall_curve_compute_multi_class(
     precision, recall, thresholds = [], [], []
     for cls in range(num_classes):
         preds_cls = preds[:, cls]
-        res = precision_recall_curve(
-            preds=preds_cls,
-            target=target,
-            num_classes=1,
-            pos_label=cls,
-            sample_weights=sample_weights,
-        )
+
+        prc_args = dict(
+            preds=preds_cls,
+            target=target,
+            num_classes=1,
+            pos_label=cls,
+            sample_weights=sample_weights,
+        )
+        if target.ndim > 1:
+            prc_args.update(dict(
+                target=target[:, cls],
+                pos_label=1,
+            ))
+        res = precision_recall_curve(**prc_args)
         precision.append(res[0])
         recall.append(res[1])
         thresholds.append(res[2])
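To spell out what the new branch does: the keyword arguments are first built for the multiclass case (shared integer target, `pos_label=cls`), and when `target.ndim > 1` the `target`/`pos_label` pair is overridden so that class `cls` is scored against its own binary column. A simplified, self-contained sketch of that dispatch — `per_class_pr_curves` is a hypothetical helper, and sklearn's `precision_recall_curve` stands in for the torchmetrics internals:

```python
import torch
from sklearn.metrics import precision_recall_curve as sk_prc

def per_class_pr_curves(preds: torch.Tensor, target: torch.Tensor):
    """Compute one PR curve per class, mirroring the dispatch above.

    preds: (N, C) scores. target: (N,) int labels (multiclass)
    or (N, C) binary indicators (multilabel).
    """
    curves = []
    for cls in range(preds.shape[1]):
        if target.ndim > 1:
            # Multilabel: score column `cls` against its own binary targets.
            y_true, pos_label = target[:, cls].numpy(), 1
        else:
            # Multiclass: keep the shared integer targets and binarize
            # them against class `cls` via pos_label.
            y_true, pos_label = target.numpy(), cls
        curves.append(sk_prc(y_true, preds[:, cls].numpy(), pos_label=pos_label))
    return curves

# Multilabel example: 4 samples, 3 labels.
preds = torch.sigmoid(torch.randn(4, 3))
target = torch.randint(0, 2, (4, 3))
for cls, (p, r, t) in enumerate(per_class_pr_curves(preds, target)):
    print(cls, p, r, t)

# Multiclass example: shared integer targets of shape (N,).
mc_target = torch.randint(0, 3, (4,))
curves_mc = per_class_pr_curves(preds, mc_target)
```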