From 1e290351709aa48ad444f9f4ef8aafb0670355fa Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 27 Sep 2021 15:13:30 +0200 Subject: [PATCH] Metric sweeping (#544) * static property * typing * setter * chlog * typing * retrievel * audio * update Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec Co-authored-by: Jirka --- CHANGELOG.md | 3 +++ tests/helpers/testers.py | 2 ++ torchmetrics/audio/si_sdr.py | 1 + torchmetrics/audio/si_snr.py | 1 + torchmetrics/metric.py | 1 + torchmetrics/retrieval/mean_average_precision.py | 2 ++ torchmetrics/retrieval/mean_reciprocal_rank.py | 2 ++ torchmetrics/retrieval/retrieval_fallout.py | 2 ++ torchmetrics/retrieval/retrieval_metric.py | 1 + torchmetrics/retrieval/retrieval_ndcg.py | 2 ++ torchmetrics/retrieval/retrieval_precision.py | 2 ++ torchmetrics/retrieval/retrieval_recall.py | 2 ++ torchmetrics/text/bert.py | 2 ++ torchmetrics/text/bleu.py | 1 + torchmetrics/text/rouge.py | 2 ++ torchmetrics/text/wer.py | 1 + 16 files changed, 27 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c860575529d..4792c2156b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510)) +- Added metric sweeping `higher_is_better` as constant attribute ([#544](https://github.com/PyTorchLightning/metrics/pull/544)) + + ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 6def41e4609..7641e07ecff 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -151,6 +151,8 @@ def _class_test( ) with pytest.raises(RuntimeError): metric.is_differentiable = not metric.is_differentiable + with pytest.raises(RuntimeError): + metric.higher_is_better = not metric.higher_is_better # check that the metric is scriptable if check_scriptable: diff --git a/torchmetrics/audio/si_sdr.py b/torchmetrics/audio/si_sdr.py index 5c0ac9708b4..a8ed159f3aa 100644 --- a/torchmetrics/audio/si_sdr.py +++ b/torchmetrics/audio/si_sdr.py @@ -65,6 +65,7 @@ class SI_SDR(Metric): """ is_differentiable = True + higher_is_better = True sum_si_sdr: Tensor total: Tensor diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index 6be0e0fd807..9b82f18dda6 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -65,6 +65,7 @@ class SI_SNR(Metric): is_differentiable = True sum_si_snr: Tensor total: Tensor + higher_is_better = True def __init__( self, diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 7743cb21536..d23994d11fe 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -70,6 +70,7 @@ class Metric(Module, ABC): __jit_ignored_attributes__ = ["device"] __jit_unused_properties__ = ["is_differentiable"] is_differentiable: Optional[bool] = None + higher_is_better: Optional[bool] = None def __init__( self, diff --git a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py index a0b4779e8cd..774cdb790e6 100644 --- a/torchmetrics/retrieval/mean_average_precision.py +++ b/torchmetrics/retrieval/mean_average_precision.py @@ -64,5 +64,7 @@ class RetrievalMAP(RetrievalMetric): tensor(0.7917) """ + higher_is_better = True + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: return retrieval_average_precision(preds, target) diff --git a/torchmetrics/retrieval/mean_reciprocal_rank.py b/torchmetrics/retrieval/mean_reciprocal_rank.py index 9ac7706c0bc..43f32a2f586 100644 --- a/torchmetrics/retrieval/mean_reciprocal_rank.py +++ b/torchmetrics/retrieval/mean_reciprocal_rank.py @@ -64,5 +64,7 @@ class RetrievalMRR(RetrievalMetric): tensor(0.7500) """ + higher_is_better = True + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: return retrieval_reciprocal_rank(preds, target) diff --git a/torchmetrics/retrieval/retrieval_fallout.py b/torchmetrics/retrieval/retrieval_fallout.py index 2052d4cbb3a..b99fde3bfa7 100644 --- a/torchmetrics/retrieval/retrieval_fallout.py +++ b/torchmetrics/retrieval/retrieval_fallout.py @@ -69,6 +69,8 @@ class RetrievalFallOut(RetrievalMetric): tensor(0.5000) """ + higher_is_better = False + def __init__( self, empty_target_action: str = "pos", diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index a0926a3e7a3..9f7da296cf6 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -67,6 +67,7 @@ class RetrievalMetric(Metric, ABC): indexes: List[Tensor] preds: List[Tensor] target: List[Tensor] + higher_is_better = True def __init__( self, diff --git a/torchmetrics/retrieval/retrieval_ndcg.py b/torchmetrics/retrieval/retrieval_ndcg.py index 30d96197e93..9ddd87b1a9b 100644 --- a/torchmetrics/retrieval/retrieval_ndcg.py +++ b/torchmetrics/retrieval/retrieval_ndcg.py @@ -67,6 +67,8 @@ class RetrievalNormalizedDCG(RetrievalMetric): tensor(0.8467) """ + higher_is_better = True + def __init__( self, empty_target_action: str = "neg", diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py index d97c2125349..48e2b58088d 100644 --- a/torchmetrics/retrieval/retrieval_precision.py +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -67,6 +67,8 @@ class RetrievalPrecision(RetrievalMetric): tensor(0.5000) """ + higher_is_better = True + def __init__( self, empty_target_action: str = "neg", diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py index b6c1fcff63d..76836e9c4bd 100644 --- a/torchmetrics/retrieval/retrieval_recall.py +++ b/torchmetrics/retrieval/retrieval_recall.py @@ -67,6 +67,8 @@ class RetrievalRecall(RetrievalMetric): tensor(0.7500) """ + higher_is_better = True + def __init__( self, empty_target_action: str = "neg", diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index d8714125529..ff0059e929a 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -121,6 +121,8 @@ class BERTScore(Metric): 'f1': [0.99..., 0.99...]} """ + higher_is_better = True + def __init__( self, model_name_or_path: Optional[str] = None, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index cf582b4a382..6ba530c56ca 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -60,6 +60,7 @@ class BLEUScore(Metric): """ is_differentiable = False + higher_is_better = True trans_len: Tensor ref_len: Tensor numerator: Tensor diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py index a0622969805..31c7d41a2b7 100644 --- a/torchmetrics/text/rouge.py +++ b/torchmetrics/text/rouge.py @@ -78,6 +78,8 @@ class ROUGEScore(Metric): [1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin `Rouge Detail`_ """ + higher_is_better = True + def __init__( self, newline_sep: Optional[bool] = None, # remove in v0.7 diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py index 395446d2c7c..0a360dacc1d 100644 --- a/torchmetrics/text/wer.py +++ b/torchmetrics/text/wer.py @@ -66,6 +66,7 @@ class WER(Metric): tensor(0.5000) """ is_differentiable = False + higher_is_better = False error: Tensor total: Tensor