Skip to content

Commit

Permalink
Bugfix: fix compute when called on empty lists (#1219)
Browse files Browse the repository at this point in the history
* fix ##1218
* tests
* changelog

Co-authored-by: Aws user for bootstrap <aws_install@ip-10-102-147-234.ef73-poctrustai.aws.cloud.airbus-v.corp>
Co-authored-by: SkafteNicki <[email protected]>
  • Loading branch information
3 people authored Sep 15, 2022
1 parent 8086097 commit 3fd3dc4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Jaccard index ([#1205](https://github.com/Lightning-AI/metrics/pull/1205))


- Fixed synchronization of empty list states ([#1219](https://github.com/Lightning-AI/metrics/pull/1219))


## [0.9.3] - 2022-08-22

### Added
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,10 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)

if isinstance(output_dict[attr], list) and len(output_dict[attr]) == 0:
setattr(self, attr, [])
continue

if isinstance(output_dict[attr][0], Tensor):
output_dict[attr] = torch.stack(output_dict[attr])
elif isinstance(output_dict[attr][0], list):
Expand Down
13 changes: 13 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,16 @@ def _test_sync_on_compute_list_state(rank, worldsize, sync_on_compute):
def test_sync_on_compute(sync_on_compute, test_func):
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(test_func, args=(2, sync_on_compute), nprocs=2)


def _test_sync_with_empty_lists(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = DummyListMetric()
val = dummy.compute()
assert val == []


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_empty_lists():
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(_test_sync_with_empty_lists, args=(2,), nprocs=2)

0 comments on commit 3fd3dc4

Please sign in to comment.