Warning on wrong call order (#164)
* fix

* remove space

* pep8

* fix tests

* changelog

* test

* update based on discussion

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people authored May 3, 2021
1 parent 07eeca2 commit 8284657
Showing 4 changed files with 73 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Calling `compute` before `update` will now give a warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164))


### Deprecated

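The user-visible effect of the changelog entry above, as a minimal sketch. `SumMetric` is a hypothetical helper mirroring the tests' `DummyMetricSum`; behavior assumes torchmetrics at this commit:

```python
import warnings

import torch
from torchmetrics import Metric


class SumMetric(Metric):
    # hypothetical minimal metric, mirroring DummyMetricSum from the tests
    def __init__(self):
        super().__init__()
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.x += value

    def compute(self):
        return self.x


metric = SumMetric()

# ``compute`` before any ``update`` now emits the new UserWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    metric.compute()
assert any(issubclass(w.category, UserWarning) for w in caught)

# after ``update`` the call is silent
metric.update(torch.tensor(2.0))
assert metric.compute() == 2.0

# ``reset`` clears the flag, so a stale ``compute`` warns again
metric.reset()
```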
43 changes: 37 additions & 6 deletions tests/bases/test_composition.py
@@ -27,6 +27,7 @@ def __init__(self, val_to_return):
super().__init__()
self._num_updates = 0
self._val_to_return = val_to_return
self._update_called = True

def update(self, *args, **kwargs) -> None:
self._num_updates += 1
@@ -57,6 +58,9 @@ def test_metrics_add(second_operand, expected_result):
assert isinstance(final_add, CompositionalMetric)
assert isinstance(final_radd, CompositionalMetric)

final_add.update()
final_radd.update()

assert torch.allclose(expected_result, final_add.compute())
assert torch.allclose(expected_result, final_radd.compute())

@@ -75,6 +79,8 @@ def test_metrics_and(second_operand, expected_result):
assert isinstance(final_and, CompositionalMetric)
assert isinstance(final_rand, CompositionalMetric)

final_and.update()
final_rand.update()
assert torch.allclose(expected_result, final_and.compute())
assert torch.allclose(expected_result, final_rand.compute())

@@ -95,6 +101,7 @@ def test_metrics_eq(second_operand, expected_result):

assert isinstance(final_eq, CompositionalMetric)

final_eq.update()
# can't use allclose for bool tensors
assert (expected_result == final_eq.compute()).all()

@@ -116,6 +123,7 @@ def test_metrics_floordiv(second_operand, expected_result):

assert isinstance(final_floordiv, CompositionalMetric)

final_floordiv.update()
assert torch.allclose(expected_result, final_floordiv.compute())


@@ -135,6 +143,7 @@ def test_metrics_ge(second_operand, expected_result):

assert isinstance(final_ge, CompositionalMetric)

final_ge.update()
# can't use allclose for bool tensors
assert (expected_result == final_ge.compute()).all()

@@ -155,6 +164,7 @@ def test_metrics_gt(second_operand, expected_result):

assert isinstance(final_gt, CompositionalMetric)

final_gt.update()
# can't use allclose for bool tensors
assert (expected_result == final_gt.compute()).all()

@@ -175,6 +185,7 @@ def test_metrics_le(second_operand, expected_result):

assert isinstance(final_le, CompositionalMetric)

final_le.update()
# can't use allclose for bool tensors
assert (expected_result == final_le.compute()).all()

@@ -195,6 +206,7 @@ def test_metrics_lt(second_operand, expected_result):

assert isinstance(final_lt, CompositionalMetric)

final_lt.update()
# can't use allclose for bool tensors
assert (expected_result == final_lt.compute()).all()

@@ -210,6 +222,7 @@ def test_metrics_matmul(second_operand, expected_result):

assert isinstance(final_matmul, CompositionalMetric)

final_matmul.update()
assert torch.allclose(expected_result, final_matmul.compute())


@@ -228,6 +241,8 @@ def test_metrics_mod(second_operand, expected_result):
final_mod = first_metric % second_operand

assert isinstance(final_mod, CompositionalMetric)

final_mod.update()
# prevent Runtime error for PT 1.8 - Long did not match Float
assert torch.allclose(expected_result.to(float), final_mod.compute().to(float))

@@ -250,6 +265,8 @@ def test_metrics_mul(second_operand, expected_result):
assert isinstance(final_mul, CompositionalMetric)
assert isinstance(final_rmul, CompositionalMetric)

final_mul.update()
final_rmul.update()
assert torch.allclose(expected_result, final_mul.compute())
assert torch.allclose(expected_result, final_rmul.compute())

@@ -270,6 +287,7 @@ def test_metrics_ne(second_operand, expected_result):

assert isinstance(final_ne, CompositionalMetric)

final_ne.update()
# can't use allclose for bool tensors
assert (expected_result == final_ne.compute()).all()

@@ -288,6 +306,8 @@ def test_metrics_or(second_operand, expected_result):
assert isinstance(final_or, CompositionalMetric)
assert isinstance(final_ror, CompositionalMetric)

final_or.update()
final_ror.update()
assert torch.allclose(expected_result, final_or.compute())
assert torch.allclose(expected_result, final_ror.compute())

@@ -308,6 +328,7 @@ def test_metrics_pow(second_operand, expected_result):

assert isinstance(final_pow, CompositionalMetric)

final_pow.update()
assert torch.allclose(expected_result, final_pow.compute())


@@ -322,6 +343,8 @@ def test_metrics_rfloordiv(first_operand, expected_result):
final_rfloordiv = first_operand // second_operand

assert isinstance(final_rfloordiv, CompositionalMetric)

final_rfloordiv.update()
assert torch.allclose(expected_result, final_rfloordiv.compute())


@@ -336,6 +359,7 @@ def test_metrics_rmatmul(first_operand, expected_result):

assert isinstance(final_rmatmul, CompositionalMetric)

final_rmatmul.update()
assert torch.allclose(expected_result, final_rmatmul.compute())


@@ -350,6 +374,7 @@ def test_metrics_rmod(first_operand, expected_result):

assert isinstance(final_rmod, CompositionalMetric)

final_rmod.update()
assert torch.allclose(expected_result, final_rmod.compute())


@@ -367,7 +392,7 @@ def test_metrics_rpow(first_operand, expected_result):
final_rpow = first_operand**second_operand

assert isinstance(final_rpow, CompositionalMetric)

final_rpow.update()
assert torch.allclose(expected_result, final_rpow.compute())


@@ -386,7 +411,7 @@ def test_metrics_rsub(first_operand, expected_result):
final_rsub = first_operand - second_operand

assert isinstance(final_rsub, CompositionalMetric)

final_rsub.update()
assert torch.allclose(expected_result, final_rsub.compute())


@@ -406,7 +431,7 @@ def test_metrics_rtruediv(first_operand, expected_result):
final_rtruediv = first_operand / second_operand

assert isinstance(final_rtruediv, CompositionalMetric)

final_rtruediv.update()
assert torch.allclose(expected_result, final_rtruediv.compute())


@@ -425,7 +450,7 @@ def test_metrics_sub(second_operand, expected_result):
final_sub = first_metric - second_operand

assert isinstance(final_sub, CompositionalMetric)

final_sub.update()
assert torch.allclose(expected_result, final_sub.compute())


@@ -445,7 +470,7 @@ def test_metrics_truediv(second_operand, expected_result):
final_truediv = first_metric / second_operand

assert isinstance(final_truediv, CompositionalMetric)

final_truediv.update()
assert torch.allclose(expected_result, final_truediv.compute())


@@ -463,6 +488,8 @@ def test_metrics_xor(second_operand, expected_result):
assert isinstance(final_xor, CompositionalMetric)
assert isinstance(final_rxor, CompositionalMetric)

final_xor.update()
final_rxor.update()
assert torch.allclose(expected_result, final_xor.compute())
assert torch.allclose(expected_result, final_rxor.compute())

@@ -473,7 +500,7 @@ def test_metrics_abs():
final_abs = abs(first_metric)

assert isinstance(final_abs, CompositionalMetric)

final_abs.update()
assert torch.allclose(tensor(1), final_abs.compute())


@@ -482,6 +509,7 @@ def test_metrics_invert():

final_inverse = ~first_metric
assert isinstance(final_inverse, CompositionalMetric)
final_inverse.update()
assert torch.allclose(tensor(-2), final_inverse.compute())


@@ -490,6 +518,7 @@ def test_metrics_neg():

final_neg = neg(first_metric)
assert isinstance(final_neg, CompositionalMetric)
final_neg.update()
assert torch.allclose(tensor(-1), final_neg.compute())


@@ -498,6 +527,7 @@ def test_metrics_pos():

final_pos = pos(first_metric)
assert isinstance(final_pos, CompositionalMetric)
final_pos.update()
assert torch.allclose(tensor(1), final_pos.compute())


@@ -510,6 +540,7 @@ def test_metrics_getitem(value, idx, expected_result):

final_getitem = first_metric[idx]
assert isinstance(final_getitem, CompositionalMetric)
final_getitem.update()
assert torch.allclose(expected_result, final_getitem.compute())


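Every test change above follows one pattern: a `CompositionalMetric`'s `compute` is wrapped too, so `update` must be called on the composition before computing. A standalone sketch of that pattern, where `ConstMetric` is a hypothetical stand-in for the tests' `DummyMetric`:

```python
import torch
from torchmetrics import Metric


class ConstMetric(Metric):
    # hypothetical stand-in for the tests' DummyMetric
    def __init__(self, val):
        super().__init__()
        self._val = torch.tensor(val)
        self._update_called = True  # pre-marked as updated, as in the diff above

    def update(self, *args, **kwargs):
        pass

    def compute(self):
        return self._val


composed = ConstMetric(2.0) + ConstMetric(3.0)  # a CompositionalMetric
composed.update()  # required before ``compute`` since this commit
assert torch.allclose(composed.compute(), torch.tensor(5.0))
```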
22 changes: 22 additions & 0 deletions tests/bases/test_metric.py
@@ -274,6 +274,28 @@ def test_device_and_dtype_transfer(tmpdir):
assert metric.x.dtype == torch.float16


def test_warning_on_compute_before_update():
metric = DummyMetricSum()

# make sure everything is fine with forward
with pytest.warns(None) as record:
val = metric(1)
assert not record

metric.reset()

with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'):
val = metric.compute()
assert val == 0.0

# after update things should be fine
metric.update(2.0)
with pytest.warns(None) as record:
val = metric.compute()
assert not record
assert val == 2.0


def test_metric_scripts():
torch.jit.script(DummyMetric())
torch.jit.script(DummyMetricSum())
13 changes: 12 additions & 1 deletion torchmetrics/metric.py
@@ -22,7 +22,7 @@
import torch
from torch import Tensor, nn

from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities import apply_to_collection, rank_zero_warn
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version
@@ -89,6 +89,7 @@ def __init__(
self.compute = self._wrap_compute(self.compute)
self._computed = None
self._forward_cache = None
self._update_called = False

# initialize state
self._defaults = {}
@@ -211,6 +212,7 @@ def _wrap_update(self, update):
@functools.wraps(update)
def wrapped_func(*args, **kwargs):
self._computed = None
self._update_called = True
return update(*args, **kwargs)

return wrapped_func
@@ -219,6 +221,14 @@ def _wrap_compute(self, compute):

@functools.wraps(compute)
def wrapped_func(*args, **kwargs):
if not self._update_called:
rank_zero_warn(
f"The ``compute`` method of metric {self.__class__.__name__}"
" was called before the ``update`` method which may lead to errors,"
" as metric states have not yet been updated.",
UserWarning
)

# return cached value
if self._computed is not None:
return self._computed
@@ -267,6 +277,7 @@ def reset(self):
"""
This method automatically resets the metric state variables to their default value.
"""
self._update_called = False
# lower lightning versions requires this implicitly to log metric objects correctly in self.log
if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3:
self._computed = None
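Stripped of caching and distributed syncing, the guard added to `metric.py` reduces to the pattern below. This is an illustrative, framework-free sketch (all names are made up for this example, not torchmetrics API): `__init__` rebinds `update` and `compute` through decorators, so every subclass inherits the check without overriding anything.

```python
import functools
import warnings


class CallOrderGuard:
    # illustrative stand-in for torchmetrics.Metric's method wrapping
    def __init__(self):
        self._update_called = False
        # rebind the subclass methods on the instance, as Metric.__init__ does
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)

    def _wrap_update(self, update):
        @functools.wraps(update)
        def wrapped_func(*args, **kwargs):
            self._update_called = True  # state is populated from here on
            return update(*args, **kwargs)

        return wrapped_func

    def _wrap_compute(self, compute):
        @functools.wraps(compute)
        def wrapped_func(*args, **kwargs):
            if not self._update_called:
                warnings.warn(
                    f"The ``compute`` method of metric {type(self).__name__}"
                    " was called before the ``update`` method.",
                    UserWarning,
                )
            return compute(*args, **kwargs)

        return wrapped_func

    def reset(self):
        self._update_called = False  # a reset metric warns again


class Counter(CallOrderGuard):
    def __init__(self):
        self.n = 0
        super().__init__()

    def update(self):
        self.n += 1

    def compute(self):
        return self.n


c = Counter()
c.compute()  # warns: compute before update
c.update()
c.compute()  # silent
```

Keeping the check in the wrappers, rather than in every `compute` implementation, is what lets the one-line change in `reset` above restore the warning once state is cleared.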
