Skip to content

Commit

Permalink
Fix metric collection missing update of property (#1052)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored May 27, 2022
1 parent 8263e27 commit 0495862
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed default value for `mdmc_average` in `Accuracy` ([#1036](https://github.com/PyTorchLightning/metrics/pull/1036))


- Fixed missing copy of property when using compute groups in `MetricCollection` ([#1052](https://github.com/PyTorchLightning/metrics/pull/1052))


## [0.8.2] - 2022-05-06


Expand Down
42 changes: 26 additions & 16 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,25 +340,35 @@ def test_check_compute_groups(metrics, expected, prefix, postfix):
assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)
for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

assert m.compute_groups == expected
assert m2.compute_groups == {}
for _, member in m.items():
assert member._update_called

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)
assert m.compute_groups == expected
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

# compare results for correctness
res_cg = m.compute()
res_without_cg = m2.compute()
for key in res_cg.keys():
assert torch.allclose(res_cg[key], res_without_cg[key])

# compare results for correctness
res_cg = m.compute()
res_without_cg = m2.compute()
for key in res_cg.keys():
assert torch.allclose(res_cg[key], res_without_cg[key])
m.reset()
m2.reset()


@pytest.mark.parametrize(
Expand Down
3 changes: 3 additions & 0 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def update(self, *args: Any, **kwargs: Any) -> None:
# only update the first member
m0 = getattr(self, cg[0])
m0.update(*args, **m0._filter_kwargs(**kwargs))
for i in range(1, len(cg)): # copy over the update count
mi = getattr(self, cg[i])
mi._update_count = m0._update_count
else: # the first update always do per metric to form compute groups
for _, m in self.items(keep_base=True):
m_kwargs = m._filter_kwargs(**kwargs)
Expand Down

0 comments on commit 0495862

Please sign in to comment.