Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace thredshold argument to binned metrics #322

Merged
merged 11 commits into from
Jun 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))
- Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301))
- Added `sync` and `sync_context` methods for manually controlling when metric states are synced ([#302](https://github.com/PyTorchLightning/metrics/pull/302))
- Added `thresholds` argument to binned metrics for manually controlling the thresholds ([#322](https://github.com/PyTorchLightning/metrics/pull/322))
- Added `KLDivergence` metric ([#247](https://github.com/PyTorchLightning/metrics/pull/247))

### Changed
Expand All @@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated

- Deprecated `torchmetrics.functional.mean_relative_error` ([#248](https://github.com/PyTorchLightning/metrics/pull/248))
- Deprecated `num_thresholds` argument in `BinnedPrecisionRecallCurve` ([#322](https://github.com/PyTorchLightning/metrics/pull/322))

### Removed

Expand Down
34 changes: 31 additions & 3 deletions tests/classification/test_binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
from tests.classification.inputs import _input_multilabel_prob_plausible as _input_mlb_prob_ok
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision
from torchmetrics.classification.binned_precision_recall import (
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
)

seed_all(42)

Expand Down Expand Up @@ -112,8 +116,10 @@ class TestBinnedAveragePrecision(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("num_thresholds", [101, 301])
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds):
@pytest.mark.parametrize(
"num_thresholds, thresholds", ([101, None], [301, None], [None, torch.linspace(0.0, 1.0, 101)])
)
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds, thresholds):
# rounding will simulate binning for both implementations
preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6

Expand All @@ -127,5 +133,27 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o
metric_args={
"num_classes": num_classes,
"num_thresholds": num_thresholds,
"thresholds": thresholds
},
)


@pytest.mark.parametrize(
"metric_class", [BinnedAveragePrecision, BinnedRecallAtFixedPrecision, BinnedPrecisionRecallCurve]
)
def test_raises_errors_and_warning(metric_class):
if metric_class == BinnedRecallAtFixedPrecision:
metric_class = partial(metric_class, min_precision=0.5)

with pytest.warns(
DeprecationWarning,
match="Argument `num_thresholds` "
"is deprecated in v0.4 and will be removed in v0.5. Use `thresholds` instead."
):
metric_class(num_classes=10, num_thresholds=100)

with pytest.raises(
ValueError, match="Expected argument `thresholds` to either"
" be an integer, list of floats or a tensor"
):
metric_class(num_classes=10, thresholds={'temp': [10, 20]})
75 changes: 61 additions & 14 deletions torchmetrics/classification/binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor
Expand Down Expand Up @@ -52,8 +53,14 @@ class BinnedPrecisionRecallCurve(Metric):

Args:
num_classes: integer with number of classes. For binary, set to 1.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
num_thresholds: number of bins used for computation.

.. deprecated:: v0.4
Use `thresholds`. Will be removed in v0.5.

thresholds: list or tensor with specific thresholds or a number of bins from linear sampling.
It is used for computation will lead to more detailed curve and accurate estimates,
but will be slower and consume more memory.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Expand All @@ -62,11 +69,15 @@ class BinnedPrecisionRecallCurve(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If ``thresholds`` is not a int, list or tensor

Example (binary case):
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, num_thresholds=5)
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
Expand All @@ -81,7 +92,7 @@ class BinnedPrecisionRecallCurve(Metric):
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, num_thresholds=3)
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, thresholds=3)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision # doctest: +NORMALIZE_WHITESPACE
[tensor([0.2500, 1.0000, 1.0000, 1.0000]),
Expand All @@ -106,10 +117,11 @@ class BinnedPrecisionRecallCurve(Metric):
def __init__(
self,
num_classes: int,
num_thresholds: int = 100,
thresholds: Optional[Union[Tensor, List[float]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
num_thresholds: Optional[int] = 100, # ToDo: remove in v0.5
):
super().__init__(
compute_on_step=compute_on_step,
Expand All @@ -118,14 +130,27 @@ def __init__(
)

self.num_classes = num_classes
self.num_thresholds = num_thresholds
thresholds = torch.linspace(0, 1.0, num_thresholds)
self.register_buffer("thresholds", thresholds)
if thresholds is None and num_thresholds is not None:
warn(
"Argument `num_thresholds` is deprecated in v0.4 and will be removed in v0.5."
" Use `thresholds` instead.", DeprecationWarning
)
thresholds = num_thresholds
if isinstance(thresholds, int):
self.num_thresholds = thresholds
thresholds = torch.linspace(0, 1.0, thresholds)
self.register_buffer("thresholds", thresholds)
elif thresholds is not None:
if not isinstance(thresholds, (list, Tensor)):
raise ValueError('Expected argument `thresholds` to either be an integer, list of floats or a tensor')
thresholds = torch.tensor(thresholds) if isinstance(thresholds, list) else thresholds
self.num_thresholds = thresholds.numel()
self.register_buffer("thresholds", thresholds)

for name in ("TPs", "FPs", "FNs"):
self.add_state(
name=name,
default=torch.zeros(num_classes, num_thresholds, dtype=torch.float32),
default=torch.zeros(num_classes, self.num_thresholds, dtype=torch.float32),
dist_reduce_fx="sum",
)

Expand Down Expand Up @@ -185,13 +210,23 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
Args:
num_classes: integer with number of classes. Not nessesary to provide
for binary problems.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
num_thresholds: number of bins used for computation.

.. deprecated:: v0.4
Use `thresholds`. Will be removed in v0.5.

thresholds: list or tensor with specific thresholds or a number of bins from linear sampling.
It is used for computation will lead to more detailed curve and accurate estimates,
but will be slower and consume more memory
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If ``thresholds`` is not a list or tensor

Example (binary case):
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
Expand Down Expand Up @@ -233,13 +268,23 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
Args:
num_classes: integer with number of classes. Provide 1 for for binary problems.
min_precision: float value specifying minimum precision threshold.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
num_thresholds: number of bins used for computation.

.. deprecated:: v0.4
Use `thresholds`. Will be removed in v0.5.

thresholds: list or tensor with specific thresholds or a number of bins from linear sampling.
It is used for computation will lead to more detailed curve and accurate estimates,
but will be slower and consume more memory
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If ``thresholds`` is not a list or tensor

Example (binary case):
>>> from torchmetrics import BinnedRecallAtFixedPrecision
>>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
Expand All @@ -264,14 +309,16 @@ def __init__(
self,
num_classes: int,
min_precision: float,
num_thresholds: int = 100,
thresholds: Optional[Union[Tensor, List[float]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
num_thresholds: int = 100, # ToDo: remove in v0.5
):
super().__init__(
num_classes=num_classes,
num_thresholds=num_thresholds,
thresholds=thresholds,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/regression/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

class PSNR(_PSNR):
"""
.. deprecated:: 0.4
The PSNR was moved to `torchmetrics.image.psnr`.
.. deprecated:: v0.4
The PSNR was moved to `torchmetrics.image.psnr`. It will be removed in v0.5.

"""

Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/regression/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

class SSIM(_SSIM):
"""
.. deprecated:: 0.4
The SSIM was moved to `torchmetrics.image.ssim`.
.. deprecated:: v0.4
The SSIM was moved to `torchmetrics.image.ssim`. It will be removed in v0.5.

"""

Expand Down