From 9652899df989c40243e7d985b2d2e6ce87b09bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20K=C3=B6nig?= Date: Sat, 25 Nov 2023 18:36:42 +0100 Subject: [PATCH] Fix `dim_zero_cat` reduction (#2226) Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 3 +++ src/torchmetrics/metric.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e640fe410e1..ff083f29174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234)) +- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226)) + + ## [1.2.0] - 2023-09-22 ### Added diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index f574320dc44..8e8b4dbe337 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -409,7 +409,10 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: elif reduce_fn == dim_zero_min: reduced = torch.min(global_state, local_state) elif reduce_fn == dim_zero_cat: - reduced = global_state + local_state + if isinstance(global_state, Tensor): + reduced = torch.cat([global_state, local_state]) + else: + reduced = global_state + local_state elif reduce_fn is None and isinstance(global_state, Tensor): reduced = torch.stack([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, list):