Skip to content

Commit

Permalink
Merge branch 'bugfix/classwise_wrapper_forward' of https://github.com…
Browse files Browse the repository at this point in the history
…/PyTorchLightning/metrics into bugfix/classwise_wrapper_forward
  • Loading branch information
SkafteNicki committed Sep 16, 2022
2 parents 3aba677 + 3a07922 commit c6b2e39
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `ClasswiseWrapper` such that `compute` gave wrong result ([#1225](https://github.com/Lightning-AI/metrics/pull/1225))


- Fixed synchronization of empty list states ([#1219](https://github.com/Lightning-AI/metrics/pull/1219))


## [0.9.3] - 2022-08-22

### Added
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,10 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)

if isinstance(output_dict[attr], list) and len(output_dict[attr]) == 0:
setattr(self, attr, [])
continue

if isinstance(output_dict[attr][0], Tensor):
output_dict[attr] = torch.stack(output_dict[attr])
elif isinstance(output_dict[attr][0], list):
Expand Down
13 changes: 13 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,16 @@ def _test_sync_on_compute_list_state(rank, worldsize, sync_on_compute):
def test_sync_on_compute(sync_on_compute, test_func):
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(test_func, args=(2, sync_on_compute), nprocs=2)


def _test_sync_with_empty_lists(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = DummyListMetric()
val = dummy.compute()
assert val == []


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_empty_lists():
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(_test_sync_with_empty_lists, args=(2,), nprocs=2)

0 comments on commit c6b2e39

Please sign in to comment.