Skip to content

Commit

Permalink
Fix restrictive input check in pearson corr coef (#1649)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Mar 28, 2023
1 parent 4d5147a commit be0a9e6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608))


- Fixed restrictive check in `PearsonCorrCoef` ([#1649](https://github.com/Lightning-AI/metrics/pull/1649))


- Fixed integration with `jsonargparse` and `LightningCLI` ([#1651](https://github.com/Lightning-AI/metrics/pull/1651))


Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/functional/regression/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs:
f"Expected both predictions and target to be either 1- or 2-dimensional tensors,"
f" but got {target.ndim} and {preds.ndim}."
)
if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[1]):
cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1)
cond2 = num_outputs > 1 and num_outputs != preds.shape[1]
if cond1 or cond2:
raise ValueError(
f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
f" and {preds.shape[1]}."
Expand Down
6 changes: 6 additions & 0 deletions tests/unittests/regression/test_pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ def test_error_on_different_shape():
metric(torch.randn(100, 5), torch.randn(100, 5))


def test_1d_input_allowed():
"""Check that both input of the form [N,] and [N,1] is allowed with default num_outputs argument."""
assert isinstance(pearson_corrcoef(torch.randn(10, 1), torch.randn(10, 1)), torch.Tensor)
assert isinstance(pearson_corrcoef(torch.randn(10), torch.randn(10)), torch.Tensor)


@pytest.mark.parametrize("shapes", [(5,), (1, 5), (2, 5)])
def test_final_aggregation_function(shapes):
"""Test that final aggregation function can take various shapes of input."""
Expand Down

0 comments on commit be0a9e6

Please sign in to comment.