Commit
Fix CompositionalMetric.forward() to not only consider the last batch (#645)

* Apply suggestions from code review

Co-authored-by: Björn Barz <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit b4ec3c5)
Callidior authored and Borda committed Dec 5, 2021
1 parent 8f6a733 commit ceb25a1
Showing 3 changed files with 59 additions and 6 deletions.
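
For context, a minimal sketch of the behavior this commit fixes (not part of the diff; the metrics chosen are illustrative and the 0.6-era torchmetrics API is assumed). Before the fix, running a `CompositionalMetric` batch-by-batch through `forward()` left its operands holding state for the last batch only; with the fix, `forward()` delegates to each operand metric, so state accumulates across all batches:

    import torch
    from torchmetrics import MeanAbsoluteError, MeanSquaredError

    # the + operator on two metrics builds a CompositionalMetric
    combined = MeanSquaredError() + MeanAbsoluteError()

    for _ in range(3):
        preds, target = torch.rand(16), torch.rand(16)
        step_val = combined(preds, target)  # per-batch value via forward()

    epoch_val = combined.compute()  # aggregated over all three batches, not just the last
    combined.reset()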
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
 - Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606))
 
 
+- Fixed `forward` in compositional metrics ([#645](https://github.com/PyTorchLightning/metrics/pull/645))
+
 
 ## [0.6.0] - 2021-10-28

31 changes: 25 additions & 6 deletions tests/bases/test_composition.py
@@ -24,7 +24,7 @@
 class DummyMetric(Metric):
     def __init__(self, val_to_return):
         super().__init__()
-        self._num_updates = 0
+        self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
         self._val_to_return = val_to_return
         self._update_called = True
 
@@ -34,10 +34,6 @@ def update(self, *args, **kwargs) -> None:
     def compute(self):
         return tensor(self._val_to_return)
 
-    def reset(self):
-        self._num_updates = 0
-        return super().reset()
-
 
 @pytest.mark.parametrize(
     ["second_operand", "expected_result"],
@@ -544,7 +540,7 @@ def test_metrics_getitem(value, idx, expected_result):
 
 
 def test_compositional_metrics_update():
-
+    """test update method for compositional metrics."""
     compos = DummyMetric(5) + DummyMetric(4)
 
     assert isinstance(compos, CompositionalMetric)
@@ -557,3 +553,26 @@ def test_compositional_metrics_update():
 
     assert compos.metric_a._num_updates == 3
     assert compos.metric_b._num_updates == 3
+
+
+@pytest.mark.parametrize("compute_on_step", [True, False])
+@pytest.mark.parametrize("metric_b", [4, DummyMetric(4)])
+def test_compositional_metrics_forward(compute_on_step, metric_b):
+    """test forward method of compositional metrics."""
+    metric_a = DummyMetric(5)
+    metric_a.compute_on_step = compute_on_step
+    compos = metric_a + metric_b
+
+    assert isinstance(compos, CompositionalMetric)
+    for _ in range(3):
+        val = compos()
+        assert val == 9 if compute_on_step else val is None
+
+    assert isinstance(compos.metric_a, DummyMetric)
+    assert compos.metric_a._num_updates == 3
+
+    if isinstance(metric_b, DummyMetric):
+        assert isinstance(compos.metric_b, DummyMetric)
+        assert compos.metric_b._num_updates == 3
+
+    compos.reset()
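
A side note on the `DummyMetric` change above (a sketch against the torchmetrics `Metric` API; `CountingMetric` is an illustrative name): registering `_num_updates` with `add_state` lets the base `Metric.reset()` restore the default value and lets `dist_reduce_fx="sum"` aggregate the counter across processes, which is why the hand-written `reset()` override could be removed:

    import torch
    from torchmetrics import Metric

    class CountingMetric(Metric):
        def __init__(self):
            super().__init__()
            # registered state: reset() restores tensor(0) automatically,
            # and values are summed across processes in distributed runs
            self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, *args, **kwargs) -> None:
            self.count += 1

        def compute(self):
            return self.count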
32 changes: 32 additions & 0 deletions torchmetrics/metric.py
@@ -747,6 +747,35 @@ def compute(self) -> Any:
 
         return self.op(val_a, val_b)
 
+    @torch.jit.unused
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+
+        val_a = (
+            self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs))
+            if isinstance(self.metric_a, Metric)
+            else self.metric_a
+        )
+        val_b = (
+            self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs))
+            if isinstance(self.metric_b, Metric)
+            else self.metric_b
+        )
+
+        if val_a is None:
+            # compute_on_step of metric_a is False
+            return None
+
+        if val_b is None:
+            if isinstance(self.metric_b, Metric):
+                # compute_on_step of metric_b is False
+                return None
+
+            # Unary op
+            return self.op(val_a)
+
+        # Binary op
+        return self.op(val_a, val_b)
+
     def reset(self) -> None:
         if isinstance(self.metric_a, Metric):
             self.metric_a.reset()
@@ -765,3 +794,6 @@ def __repr__(self) -> str:
         repr_str = self.__class__.__name__ + _op_metrics
 
         return repr_str
+
+    def _wrap_compute(self, compute: Callable) -> Callable:
+        return compute
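
A usage sketch for the new `forward()` (not part of the diff; assumes the 0.6-era `compute_on_step` flag). A unary composition such as `-metric` stores `None` as `metric_b` and takes the "Unary op" branch, while an operand with `compute_on_step=False` returns `None` from its own `forward()`, which makes the composition return `None` for that step:

    import torch
    from torchmetrics import MeanSquaredError

    mse = MeanSquaredError()
    neg_mse = -mse  # unary CompositionalMetric: metric_b is None

    preds, target = torch.rand(8), torch.rand(8)
    step_val = neg_mse(preds, target)  # op applied to mse's per-batch value

    mse.compute_on_step = False
    assert neg_mse(preds, target) is None  # operand returned None for this step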
