diff --git a/CHANGELOG.md b/CHANGELOG.md index 862f9a3c4c4..8feedfaaf28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed metric calculation with unequal batch sizes ([#220](https://github.com/PyTorchLightning/metrics/pull/220)) +- Fixed metric concatenation for list states for zero-dim input ([#229](https://github.com/PyTorchLightning/metrics/pull/229)) + + ## [0.3.1] - 2021-04-21 - Cleaning remaining inconsistency and fix PL develop integration ( diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index fdfd56b7e16..ab97e150d96 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -59,10 +59,10 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Te f" Got preds: {preds.dtype} and target: {target.dtype}." ) _check_same_shape(preds, target) - + preds = preds.squeeze() + target = target.squeeze() if preds.ndim > 1 or target.ndim > 1: raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') - return preds, target diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 002d06395cd..ed00ab1478a 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -21,16 +21,17 @@ METRIC_EPS = 1e-6 -def dim_zero_cat(x): +def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: x = x if isinstance(x, (list, tuple)) else [x] + x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] return torch.cat(x, dim=0) -def dim_zero_sum(x): +def dim_zero_sum(x: Tensor) -> Tensor: return torch.sum(x, dim=0) -def dim_zero_mean(x): +def dim_zero_mean(x: Tensor) -> Tensor: return torch.mean(x, dim=0)