multilabel for AveragePrecision #386

Merged 12 commits on Jul 24, 2021
CHANGELOG.md (3 additions & 0 deletions)
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))


- Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))


### Changed

- Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
tests/classification/test_average_precision.py (8 additions & 0 deletions)
@@ -21,6 +21,7 @@
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.average_precision import AveragePrecision
@@ -55,6 +56,12 @@ def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


def _sk_avg_prec_multilabel_prob(preds, target, num_classes):
    # sklearn expects flat (n_samples, n_classes) arrays; average=None yields per-class AP
    sk_preds = preds.reshape(-1, num_classes).numpy()
    sk_target = target.view(-1, num_classes).numpy()
    return sk_average_precision_score(sk_target, sk_preds, average=None)


def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.view(-1).numpy()
@@ -66,6 +73,7 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES),
        (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES),
    ]
)
class TestAveragePrecision(MetricTester):
torchmetrics/functional/classification/precision_recall_curve.py
@@ -148,13 +148,20 @@ def _precision_recall_curve_compute_multi_class(
    precision, recall, thresholds = [], [], []
    for cls in range(num_classes):
        preds_cls = preds[:, cls]

        prc_args = dict(
            preds=preds_cls,
            target=target,
            num_classes=1,
            pos_label=cls,
            sample_weights=sample_weights,
        )
        # multilabel input: each column of the target is its own binary problem
        if target.ndim > 1:
            prc_args.update(dict(
                target=target[:, cls],
                pos_label=1,
            ))
        res = precision_recall_curve(**prc_args)
        precision.append(res[0])
        recall.append(res[1])
        thresholds.append(res[2])