Commit 4f3cab9
Raise exception for invalid kwargs in Metric base class (#1427)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Fixes #1426
EPronovost authored Jan 12, 2023
1 parent 7b505ff commit 4f3cab9
Showing 6 changed files with 19 additions and 8 deletions.
CHANGELOG.md (3 additions, 0 deletions)

@@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
 
 
+- Raise exception for invalid kwargs in Metric base class ([#1427](https://github.com/Lightning-AI/metrics/pull/1427))
+
+
 ### Deprecated
 
 -
src/torchmetrics/metric.py (4 additions, 0 deletions)

@@ -119,6 +119,10 @@ def __init__(
f"Expected keyword argument `sync_on_compute` to be a `bool` but got {self.sync_on_compute}"
)

if kwargs:
kwargs_ = [f"`{a}`" for a in sorted(kwargs)]
raise ValueError(f"Unexpected keyword arguments: {', '.join(kwargs_)}")

# initialize
self._update_signature = inspect.signature(self.update)
self.update: Callable = self._wrap_update(self.update) # type: ignore
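For downstream code, the practical effect of this hunk is that typos in base-class keyword arguments now fail fast instead of being silently ignored. A minimal sketch, not part of the commit (the class and state names are illustrative):

import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    """Toy metric that forwards extra kwargs to the Metric base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # unknown kwargs now raise ValueError here
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total


MyAccuracy(compute_on_cpu=True)  # recognized base-class kwarg: fine
try:
    MyAccuracy(compute_on_cpus=True)  # typo'd kwarg, previously ignored silently
except ValueError as err:
    print(err)  # Unexpected keyword arguments: `compute_on_cpus`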
tests/integrations/test_lightning.py (1 addition, 1 deletion)

@@ -76,7 +76,7 @@ def __init__(self):
         for stage in ["train", "val", "test"]:
             acc = BinaryAccuracy()
             acc.reset = mock.Mock(side_effect=acc.reset)
-            ap = BinaryAveragePrecision(num_classes=1, pos_label=1)
+            ap = BinaryAveragePrecision()
             ap.reset = mock.Mock(side_effect=ap.reset)
             self.add_module(f"acc_{stage}", acc)
             self.add_module(f"ap_{stage}", ap)
tests/unittests/bases/test_metric.py (6 additions, 0 deletions)

@@ -44,6 +44,12 @@ def test_error_on_wrong_input():
     with pytest.raises(ValueError, match="Expected keyword argument `compute_on_cpu` to be an `bool` bu.*"):
         DummyMetric(compute_on_cpu=None)
 
+    with pytest.raises(ValueError, match="Unexpected keyword arguments: `foo`"):
+        DummyMetric(foo=True)
+
+    with pytest.raises(ValueError, match="Unexpected keyword arguments: `bar`, `foo`"):
+        DummyMetric(foo=True, bar=42)
+
 
 def test_inherit():
     """Test that a metric that inherits can be instantiated."""
tests/unittests/wrappers/test_minmax.py (1 addition, 1 deletion)

@@ -118,7 +118,7 @@ def test_no_base_metric() -> None:

 def test_no_scalar_compute() -> None:
     """Test that an error is raised if the wrapped base metric returns a non-scalar on compute."""
-    min_max_nsm = MinMaxMetric(BinaryConfusionMatrix(num_classes=2))
+    min_max_nsm = MinMaxMetric(BinaryConfusionMatrix())
 
     with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a scalar .*"):
         min_max_nsm.compute()
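This test change and the test_lightning change above follow from the same new check: in the task-specific classification API, binary metrics such as BinaryAveragePrecision and BinaryConfusionMatrix no longer define `num_classes` or `pos_label`, so those arguments fall through to `Metric.__init__` and would now raise. A hedged sketch of the failure mode:

from torchmetrics.classification import BinaryConfusionMatrix

try:
    BinaryConfusionMatrix(num_classes=2)  # stale multiclass-style kwarg
except ValueError as err:
    print(err)  # presumably: Unexpected keyword arguments: `num_classes`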
tests/unittests/wrappers/test_multioutput.py (4 additions, 6 deletions)

@@ -83,23 +83,21 @@ def _multi_target_sk_accuracy(preds, target, num_outputs):


 @pytest.mark.parametrize(
-    "base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs",
+    "base_metric_class, compare_metric, preds, target, num_outputs",
     [
         (
             R2Score,
             _multi_target_sk_r2score,
             _multi_target_regression_inputs.preds,
             _multi_target_regression_inputs.target,
             num_targets,
-            {},
         ),
         (
-            MulticlassAccuracy,
+            partial(MulticlassAccuracy, num_classes=NUM_CLASSES, average="micro"),
             partial(_multi_target_sk_accuracy, num_outputs=2),
             _multi_target_classification_inputs.preds,
             _multi_target_classification_inputs.target,
             num_targets,
-            dict(num_classes=NUM_CLASSES, average="micro"),
         ),
     ],
 )
@@ -109,7 +107,7 @@ class TestMultioutputWrapper(MetricTester):
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
     def test_multioutput_wrapper(
-        self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step
+        self, base_metric_class, compare_metric, preds, target, num_outputs, ddp, dist_sync_on_step
     ):
         """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for
         both classification and regression metrics."""
@@ -120,5 +118,5 @@ def test_multioutput_wrapper(
             _MultioutputMetric,
             compare_metric,
             dist_sync_on_step,
-            metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs),
+            metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class),
         )
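With the `metric_kwargs` column gone, per-metric constructor arguments are instead bound into `base_metric_class` with `functools.partial`, as the updated parametrization shows. A rough sketch of the same pattern in user code, assuming the MultioutputWrapper API with illustrative shapes and values:

from functools import partial

import torch
from torchmetrics import MultioutputWrapper
from torchmetrics.classification import MulticlassAccuracy

# Bind constructor arguments up front instead of threading a separate kwargs dict.
metric_cls = partial(MulticlassAccuracy, num_classes=3, average="micro")
wrapper = MultioutputWrapper(metric_cls(), num_outputs=2)

preds = torch.rand(8, 3, 2)        # (batch, classes, outputs) probabilities
target = torch.randint(3, (8, 2))  # (batch, outputs) class indices
wrapper.update(preds, target)
print(wrapper.compute())           # one accuracy value per output column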
