Skip to content

Commit

Permalink
Metric sweeping (#544)
Browse files Browse the repository at this point in the history
* static property

* typing

* setter

* chlog

* typing

* retrievel

* audio

* update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
4 people authored Sep 27, 2021
1 parent ac52dd7 commit 1e29035
Show file tree
Hide file tree
Showing 16 changed files with 27 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510))


- Added metric sweeping `higher_is_better` as constant attribute ([#544](https://github.com/PyTorchLightning/metrics/pull/544))


### Changed

- `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
Expand Down
2 changes: 2 additions & 0 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def _class_test(
)
with pytest.raises(RuntimeError):
metric.is_differentiable = not metric.is_differentiable
with pytest.raises(RuntimeError):
metric.higher_is_better = not metric.higher_is_better

# check that the metric is scriptable
if check_scriptable:
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/audio/si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class SI_SDR(Metric):
"""

is_differentiable = True
higher_is_better = True
sum_si_sdr: Tensor
total: Tensor

Expand Down
1 change: 1 addition & 0 deletions torchmetrics/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class SI_SNR(Metric):
is_differentiable = True
sum_si_snr: Tensor
total: Tensor
higher_is_better = True

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Metric(Module, ABC):
__jit_ignored_attributes__ = ["device"]
__jit_unused_properties__ = ["is_differentiable"]
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/retrieval/mean_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,7 @@ class RetrievalMAP(RetrievalMetric):
tensor(0.7917)
"""

higher_is_better = True

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_average_precision(preds, target)
2 changes: 2 additions & 0 deletions torchmetrics/retrieval/mean_reciprocal_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,7 @@ class RetrievalMRR(RetrievalMetric):
tensor(0.7500)
"""

higher_is_better = True

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_reciprocal_rank(preds, target)
2 changes: 2 additions & 0 deletions torchmetrics/retrieval/retrieval_fallout.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class RetrievalFallOut(RetrievalMetric):
tensor(0.5000)
"""

higher_is_better = False

def __init__(
self,
empty_target_action: str = "pos",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/retrieval/retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class RetrievalMetric(Metric, ABC):
indexes: List[Tensor]
preds: List[Tensor]
target: List[Tensor]
higher_is_better = True

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/retrieval/retrieval_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class RetrievalNormalizedDCG(RetrievalMetric):
tensor(0.8467)
"""

higher_is_better = True

def __init__(
self,
empty_target_action: str = "neg",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/retrieval/retrieval_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class RetrievalPrecision(RetrievalMetric):
tensor(0.5000)
"""

higher_is_better = True

def __init__(
self,
empty_target_action: str = "neg",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/retrieval/retrieval_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class RetrievalRecall(RetrievalMetric):
tensor(0.7500)
"""

higher_is_better = True

def __init__(
self,
empty_target_action: str = "neg",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class BERTScore(Metric):
'f1': [0.99..., 0.99...]}
"""

higher_is_better = True

def __init__(
self,
model_name_or_path: Optional[str] = None,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class BLEUScore(Metric):
"""

is_differentiable = False
higher_is_better = True
trans_len: Tensor
ref_len: Tensor
numerator: Tensor
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class ROUGEScore(Metric):
[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin `Rouge Detail`_
"""

higher_is_better = True

def __init__(
self,
newline_sep: Optional[bool] = None, # remove in v0.7
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/text/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class WER(Metric):
tensor(0.5000)
"""
is_differentiable = False
higher_is_better = False
error: Tensor
total: Tensor

Expand Down

0 comments on commit 1e29035

Please sign in to comment.