Skip to content

Commit

Permalink
Remove compute_on_step from aggregation and tests (#990)
Browse files Browse the repository at this point in the history
* update

* changelog

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
SkafteNicki and mergify[bot] authored Apr 29, 2022
1 parent 452864b commit eb4cfaf
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 68 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

- Removed deprecated `compute_on_step` argument ([#962](https://github.com/PyTorchLightning/metrics/pull/962))
- Removed deprecated `compute_on_step` argument in base classes ([#962](https://github.com/PyTorchLightning/metrics/pull/962))


- Removed deprecated `compute_on_step` argument in Regression ([#967](https://github.com/PyTorchLightning/metrics/pull/967))
Expand All @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed deprecated `compute_on_step` argument in Wrappers ([#991](https://github.com/PyTorchLightning/metrics/pull/991))


- Removed deprecated `compute_on_step` argument in aggregation ([#990](https://github.com/PyTorchLightning/metrics/pull/990))


### Fixed

-
Expand Down
2 changes: 0 additions & 2 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,6 @@ def test_fbeta_f1(
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
check_dist_sync_on_step=True,
check_batch=True,
)

def test_fbeta_f1_functional(
Expand Down
2 changes: 0 additions & 2 deletions tests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ def test_precision_recall_class(
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
check_dist_sync_on_step=True,
check_batch=True,
)

def test_precision_recall_fn(
Expand Down
2 changes: 0 additions & 2 deletions tests/classification/test_specificity.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,6 @@ def test_specificity_class(
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
check_dist_sync_on_step=True,
check_batch=True,
)

def test_specificity_fn(
Expand Down
2 changes: 0 additions & 2 deletions tests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ def test_stat_scores_class(
"ignore_index": ignore_index,
"top_k": top_k,
},
check_dist_sync_on_step=True,
check_batch=True,
)

def test_stat_scores_fn(
Expand Down
4 changes: 1 addition & 3 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def _class_test(
metric_args = {}

# Instantiate metric
metric = metric_class(
compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args
)
metric = metric_class(dist_sync_on_step=dist_sync_on_step, **metric_args)
with pytest.raises(RuntimeError):
metric.is_differentiable = not metric.is_differentiable
with pytest.raises(RuntimeError):
Expand Down
4 changes: 1 addition & 3 deletions tests/text/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ def _class_test(
metric_args = {}

# Instantiate metric
metric = metric_class(
compute_on_step=check_dist_sync_on_step or check_batch, dist_sync_on_step=dist_sync_on_step, **metric_args
)
metric = metric_class(dist_sync_on_step=dist_sync_on_step, **metric_args)

# check that the metric is scriptable
if check_scriptable:
Expand Down
52 changes: 3 additions & 49 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Union

import torch
from torch import Tensor
Expand All @@ -33,12 +33,6 @@ class BaseAggregator(Metric):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand All @@ -55,10 +49,9 @@ def __init__(
fn: Union[Callable, str],
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
):
super().__init__(compute_on_step=compute_on_step, **kwargs)
super().__init__(**kwargs)
allowed_nan_strategy = ("error", "warn", "ignore")
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
raise ValueError(
Expand Down Expand Up @@ -108,12 +101,6 @@ class MaxMetric(BaseAggregator):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand All @@ -132,14 +119,12 @@ class MaxMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"max",
-torch.tensor(float("inf")),
nan_strategy,
compute_on_step,
**kwargs,
)

Expand All @@ -165,12 +150,6 @@ class MinMetric(BaseAggregator):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand All @@ -189,14 +168,12 @@ class MinMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"min",
torch.tensor(float("inf")),
nan_strategy,
compute_on_step,
**kwargs,
)

Expand All @@ -222,12 +199,6 @@ class SumMetric(BaseAggregator):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand All @@ -246,14 +217,12 @@ class SumMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"sum",
torch.tensor(0.0),
nan_strategy,
compute_on_step,
**kwargs,
)

Expand All @@ -278,12 +247,6 @@ class CatMetric(BaseAggregator):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand All @@ -302,10 +265,9 @@ class CatMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
):
super().__init__("cat", [], nan_strategy, compute_on_step, **kwargs)
super().__init__("cat", [], nan_strategy, **kwargs)

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
"""Update state with data.
Expand Down Expand Up @@ -335,12 +297,6 @@ class MeanMetric(BaseAggregator):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
Expand All @@ -359,14 +315,12 @@ class MeanMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"sum",
torch.tensor(0.0),
nan_strategy,
compute_on_step,
**kwargs,
)
self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")
Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,11 @@ def add_state(

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Automatically calls ``update()``.
"""``forward`` serves the dual purpose of both computing the metric on the current batch of inputs but also
add the batch statistics to the overall accumululating metric state.
Returns the metric value over inputs if ``compute_on_step`` is True.
Input arguments are the exact same as corresponding ``update`` method. The returned output is the exact same as
the output of ``compute``.
"""
# add current step
if self._is_synced:
Expand Down Expand Up @@ -799,12 +801,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
)

if val_a is None:
# compute_on_step of metric_a is False
return None

if val_b is None:
if isinstance(self.metric_b, Metric):
# compute_on_step of metric_b is False
return None

# Unary op
Expand Down

0 comments on commit eb4cfaf

Please sign in to comment.