diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py
index 25ff0f826c8..aec3b1e4377 100644
--- a/tests/classification/test_average_precision.py
+++ b/tests/classification/test_average_precision.py
@@ -16,13 +16,14 @@
 import numpy as np
 import pytest
 from sklearn.metrics import average_precision_score as sk_average_precision_score
+import torch
 from torch import tensor
 
 from tests.classification.inputs import _input_binary_prob
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, MetricTester
+from tests.helpers.testers import NUM_CLASSES, MetricTester, BATCH_SIZE
 from torchmetrics.classification.average_precision import AveragePrecision
 from torchmetrics.functional import average_precision
 
@@ -118,3 +119,24 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
 )
 def test_average_precision(scores, target, expected_score):
     assert average_precision(scores, target) == expected_score
+
+
+@pytest.mark.parametrize(
+    "preds, target, sk_metric, num_classes", [
+        (torch.randn(BATCH_SIZE, NUM_CLASSES),
+         torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES)),
+         _sk_avg_prec_multiclass_prob,
+         NUM_CLASSES),
+    ]
+)
+def test_average_precision_multilabel(preds, target, sk_metric, num_classes):
+    AP = AveragePrecision(
+        num_classes=num_classes,
+        pos_label=1,
+        compute_on_step=False,
+    )
+    preds = torch.sigmoid(preds)
+    AP.update(preds, target)
+    result = torch.tensor(AP.compute())
+    expected = sk_average_precision_score(target.numpy(), preds.numpy(), average=None)
+    assert np.allclose(result.numpy(), expected)
diff --git a/torchmetrics/functional/classification/precision_recall_curve.py b/torchmetrics/functional/classification/precision_recall_curve.py
index 0846e270cc8..19e22f94025 100644
--- a/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/torchmetrics/functional/classification/precision_recall_curve.py
@@ -148,13 +148,23 @@ def _precision_recall_curve_compute_multi_class(
     precision, recall, thresholds = [], [], []
     for cls in range(num_classes):
         preds_cls = preds[:, cls]
-        res = precision_recall_curve(
-            preds=preds_cls,
-            target=target,
-            num_classes=1,
-            pos_label=cls,
-            sample_weights=sample_weights,
-        )
+
+        if target.ndim > 1:
+            res = precision_recall_curve(
+                preds=preds_cls,
+                target=target[:, cls],
+                num_classes=1,
+                pos_label=1,
+                sample_weights=sample_weights,
+            )
+        else:
+            res = precision_recall_curve(
+                preds=preds_cls,
+                target=target,
+                num_classes=1,
+                pos_label=cls,
+                sample_weights=sample_weights,
+            )
         precision.append(res[0])
         recall.append(res[1])
         thresholds.append(res[2])
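
A minimal usage sketch, not part of the diff, of what the change to _precision_recall_curve_compute_multi_class enables: with 2-D targets, each class column is scored against its own target column with pos_label=1 instead of against a flat multiclass target. The sizes below are arbitrary, and the call assumes the shared update step accepts same-shape (N, C) preds and targets with num_classes > 1, which is exactly what the new test relies on through the AveragePrecision module.

import torch
from torchmetrics.functional import precision_recall_curve

preds = torch.sigmoid(torch.randn(8, 3))   # (N, C) per-label probabilities
target = torch.randint(0, 2, (8, 3))       # (N, C) multilabel 0/1 targets

# With the patched compute path, class c is paired with target[:, c]
# rather than with the flat multiclass target vector.
precision, recall, thresholds = precision_recall_curve(preds, target, num_classes=3, pos_label=1)
print(len(precision))  # 3, i.e. one precision/recall curve per label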