diff --git a/CHANGELOG.md b/CHANGELOG.md index dcd2e714997..eaab96afb0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,6 +120,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608)) +- Fixed restrictive check in `PearsonCorrCoef` ([#1649](https://github.com/Lightning-AI/metrics/pull/1649)) + + - Fixed integration with `jsonargparse` and `LightningCLI` ([#1651](https://github.com/Lightning-AI/metrics/pull/1651)) diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py index a847d01082e..609ddf88f30 100644 --- a/src/torchmetrics/functional/regression/utils.py +++ b/src/torchmetrics/functional/regression/utils.py @@ -21,7 +21,9 @@ def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: f"Expected both predictions and target to be either 1- or 2-dimensional tensors," f" but got {target.ndim} and {preds.ndim}." ) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[1]): + cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1) + cond2 = num_outputs > 1 and num_outputs != preds.shape[1] + if cond1 or cond2: raise ValueError( f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" f" and {preds.shape[1]}." diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 8476eb7bbfd..8581f929577 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -129,6 +129,12 @@ def test_error_on_different_shape(): metric(torch.randn(100, 5), torch.randn(100, 5)) +def test_1d_input_allowed(): + """Check that both input of the form [N,] and [N,1] is allowed with default num_outputs argument.""" + assert isinstance(pearson_corrcoef(torch.randn(10, 1), torch.randn(10, 1)), torch.Tensor) + assert isinstance(pearson_corrcoef(torch.randn(10), torch.randn(10)), torch.Tensor) + + @pytest.mark.parametrize("shapes", [(5,), (1, 5), (2, 5)]) def test_final_aggregation_function(shapes): """Test that final aggregation function can take various shapes of input."""