Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added add_metrics method to MetricCollection #221

Merged
merged 9 commits into from
May 4, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `is_differentiable` property to `AUC`, `AUROC`, `CohenKappa` and `AveragePrecision` ([#178](https://github.com/PyTorchLightning/metrics/pull/178))


- Added `add_metrics` method to `MetricCollection` for adding additional metrics after initialization ([#221](https://github.com/PyTorchLightning/metrics/pull/221))


### Changed

- Calling `compute` before `update` will now give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164))
Expand Down
14 changes: 14 additions & 0 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ def test_metric_collection_same_order():
assert k1 == k2


def test_collection_add_metrics():
m1 = DummyMetricSum()
m2 = DummyMetricDiff()

collection = MetricCollection([m1])
collection.add_metrics({'m1_': DummyMetricSum()})
collection.add_metrics(m2)

collection.update(5)
results = collection.compute()
assert results['DummyMetricSum'] == results['m1_'] and results['m1_'] == 5
assert results['DummyMetricDiff'] == -5


def test_collection_check_arg():
assert MetricCollection._check_arg(None, 'prefix') is None
assert MetricCollection._check_arg('sample', 'prefix') == 'sample'
Expand Down
85 changes: 46 additions & 39 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,46 +94,8 @@ def __init__(
postfix: Optional[str] = None
):
super().__init__()
if isinstance(metrics, Metric):
# set compatible with original type expectations
metrics = [metrics]
if isinstance(metrics, Sequence):
# prepare for optional additions
metrics = list(metrics)
remain = []
for m in additional_metrics:
(metrics if isinstance(m, Metric) else remain).append(m)

if remain:
rank_zero_warn(
f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
)
elif additional_metrics:
raise ValueError(
f"You have passes extra arguments {additional_metrics} which are not compatible"
f" with first passed dictionary {metrics} so they will be ignored."
)

if isinstance(metrics, dict):
# Check all values are metrics
# Make sure that metrics are added in deterministic order
for name in sorted(metrics.keys()):
metric = metrics[name]
if not isinstance(metric, Metric):
raise ValueError(
f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`"
)
self[name] = metric
elif isinstance(metrics, Sequence):
for metric in metrics:
if not isinstance(metric, Metric):
raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`")
name = metric.__class__.__name__
if name in self:
raise ValueError(f"Encountered two metrics both named {name}")
self[name] = metric
else:
raise ValueError("Unknown input to MetricCollection.")
self.add_metrics(metrics, *additional_metrics)

self.prefix = self._check_arg(prefix, 'prefix')
self.postfix = self._check_arg(postfix, 'postfix')
Expand Down Expand Up @@ -185,6 +147,51 @@ def persistent(self, mode: bool = True) -> None:
for _, m in self.items():
m.persistent(mode)

def add_metrics(self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def add_metrics(self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
def append(self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],

would it be a better name?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to keep the name different. If we say append it might imply that it won't work as extend while here it does.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add_metrics seems more appropriate for a Collection.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be Mapping[str, Metrics], not Dict + similar a check below.

*additional_metrics: Metric) -> None:
"""Add new metrics to Metric Collection
"""
if isinstance(metrics, Metric):
# set compatible with original type expectations
metrics = [metrics]
if isinstance(metrics, Sequence):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this is just copying what we had here before, but we should probably have Iterable here instead of Sequence as we doin't really care about indexing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but iterable does not have len, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not using len though, we are just directly converting to a list in the next line.

# prepare for optional additions
metrics = list(metrics)
remain = []
for m in additional_metrics:
(metrics if isinstance(m, Metric) else remain).append(m)

if remain:
rank_zero_warn(
f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
)
elif additional_metrics:
raise ValueError(
f"You have passes extra arguments {additional_metrics} which are not compatible"
f" with first passed dictionary {metrics} so they will be ignored."
)

if isinstance(metrics, dict):
# Check all values are metrics
# Make sure that metrics are added in deterministic order
for name in sorted(metrics.keys()):
metric = metrics[name]
if not isinstance(metric, Metric):
raise ValueError(
f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`"
)
self[name] = metric
elif isinstance(metrics, Sequence):
for metric in metrics:
if not isinstance(metric, Metric):
raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`")
name = metric.__class__.__name__
if name in self:
raise ValueError(f"Encountered two metrics both named {name}")
self[name] = metric
else:
raise ValueError("Unknown input to MetricCollection.")

def _set_name(self, base: str) -> str:
name = base if self.prefix is None else self.prefix + base
name = name if self.postfix is None else name + self.postfix
Expand Down