Skip to content

Commit

Permalink
closes #359: multiclass for average_precision
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleksii Rychyk authored and Oleksii Rychyk committed Jul 18, 2021
1 parent 5c2069e commit a2a2b94
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
24 changes: 23 additions & 1 deletion tests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
import numpy as np
import pytest
from sklearn.metrics import average_precision_score as sk_average_precision_score
import torch
from torch import tensor

from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from tests.helpers.testers import NUM_CLASSES, MetricTester, BATCH_SIZE
from torchmetrics.classification.average_precision import AveragePrecision
from torchmetrics.functional import average_precision

Expand Down Expand Up @@ -118,3 +119,24 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
)
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score


@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(torch.randn(BATCH_SIZE, NUM_CLASSES),
torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES)),
_sk_avg_prec_multiclass_prob,
NUM_CLASSES),
]
)
def test_average_precision_multilabel(preds, target, sk_metric, num_classes):
AP = AveragePrecision(
num_classes=num_classes,
pos_label=1,
compute_on_step=False,
)
preds = torch.sigmoid(preds)
AP.update(preds, target)
result = torch.tensor(AP.compute())
expected = sk_average_precision_score(target.numpy(), preds.numpy(), average=None)
assert np.allclose(result.numpy(), expected)
24 changes: 17 additions & 7 deletions torchmetrics/functional/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,23 @@ def _precision_recall_curve_compute_multi_class(
precision, recall, thresholds = [], [], []
for cls in range(num_classes):
preds_cls = preds[:, cls]
res = precision_recall_curve(
preds=preds_cls,
target=target,
num_classes=1,
pos_label=cls,
sample_weights=sample_weights,
)

if target.ndim > 1:
res = precision_recall_curve(
preds=preds_cls,
target=target[:, cls],
num_classes=1,
pos_label=1,
sample_weights=sample_weights,
)
else:
res = precision_recall_curve(
preds=preds_cls,
target=target,
num_classes=1,
pos_label=cls,
sample_weights=sample_weights,
)
precision.append(res[0])
recall.append(res[1])
thresholds.append(res[2])
Expand Down

0 comments on commit a2a2b94

Please sign in to comment.