diff --git a/CHANGELOG.md b/CHANGELOG.md
index fddf548d679..0c2bcd321bc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))
 
+- Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))
+
+
 ### Changed
 
 - Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py
index 25ff0f826c8..d838976819a 100644
--- a/tests/classification/test_average_precision.py
+++ b/tests/classification/test_average_precision.py
@@ -21,6 +21,7 @@
 from tests.classification.inputs import _input_binary_prob
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
+from tests.classification.inputs import _input_multilabel
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, MetricTester
 from torchmetrics.classification.average_precision import AveragePrecision
@@ -55,6 +56,12 @@ def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
     return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
 
 
+def _sk_avg_prec_multilabel_prob(preds, target, num_classes):
+    sk_preds = preds.reshape(-1, num_classes).numpy()
+    sk_target = target.view(-1, num_classes).numpy()
+    return sk_average_precision_score(sk_target, sk_preds, average=None)
+
+
 def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
     sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
     sk_target = target.view(-1).numpy()
@@ -66,6 +73,7 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
         (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1),
         (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES),
         (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES),
+        (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES),
     ]
 )
 class TestAveragePrecision(MetricTester):
diff --git a/torchmetrics/functional/classification/precision_recall_curve.py b/torchmetrics/functional/classification/precision_recall_curve.py
index 0846e270cc8..65653ea5b43 100644
--- a/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/torchmetrics/functional/classification/precision_recall_curve.py
@@ -148,13 +148,20 @@ def _precision_recall_curve_compute_multi_class(
     precision, recall, thresholds = [], [], []
     for cls in range(num_classes):
         preds_cls = preds[:, cls]
-        res = precision_recall_curve(
+
+        prc_args = dict(
             preds=preds_cls,
             target=target,
             num_classes=1,
             pos_label=cls,
             sample_weights=sample_weights,
         )
+        if target.ndim > 1:
+            prc_args.update(dict(
+                target=target[:, cls],
+                pos_label=1,
+            ))
+        res = precision_recall_curve(**prc_args)
         precision.append(res[0])
         recall.append(res[1])
         thresholds.append(res[2])
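For context, a minimal usage sketch of the behaviour this PR enables: with a 2-D target, each class column is scored against its own labels (`target[:, cls]` with `pos_label=1`) instead of treating the 1-D target as class indices. The shapes, values, and `num_classes=3` setup below are illustrative assumptions, not taken from the PR's test fixtures; the sklearn call mirrors the new `_sk_avg_prec_multilabel_prob` helper.

```python
import torch
from sklearn.metrics import average_precision_score
from torchmetrics.classification.average_precision import AveragePrecision

# Illustrative data (assumed, not from the PR): 10 samples, 3 classes.
preds = torch.rand(10, 3)              # per-class scores, shape (n_samples, n_classes)
target = torch.randint(0, 2, (10, 3))  # multilabel targets, shape (n_samples, n_classes)

# With (n_samples, n_classes) targets, each class is evaluated against its own
# label column, yielding one average-precision score per class.
average_precision = AveragePrecision(num_classes=3)
per_class_ap = average_precision(preds, target)

# Reference values, mirroring the comparison done by _sk_avg_prec_multilabel_prob.
sk_ap = average_precision_score(target.numpy(), preds.numpy(), average=None)
```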