-
Notifications
You must be signed in to change notification settings - Fork 415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added add_metrics method to MetricCollection #221
Changes from 6 commits
cc91ef3
fb8e815
115aec5
55d2896
5a3094f
c6037da
1263786
3b2fdf9
cd1db7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,46 +94,8 @@ def __init__( | |
postfix: Optional[str] = None | ||
): | ||
super().__init__() | ||
if isinstance(metrics, Metric): | ||
# set compatible with original type expectations | ||
metrics = [metrics] | ||
if isinstance(metrics, Sequence): | ||
# prepare for optional additions | ||
metrics = list(metrics) | ||
remain = [] | ||
for m in additional_metrics: | ||
(metrics if isinstance(m, Metric) else remain).append(m) | ||
|
||
if remain: | ||
rank_zero_warn( | ||
f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." | ||
) | ||
elif additional_metrics: | ||
raise ValueError( | ||
f"You have passes extra arguments {additional_metrics} which are not compatible" | ||
f" with first passed dictionary {metrics} so they will be ignored." | ||
) | ||
|
||
if isinstance(metrics, dict): | ||
# Check all values are metrics | ||
# Make sure that metrics are added in deterministic order | ||
for name in sorted(metrics.keys()): | ||
metric = metrics[name] | ||
if not isinstance(metric, Metric): | ||
raise ValueError( | ||
f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" | ||
) | ||
self[name] = metric | ||
elif isinstance(metrics, Sequence): | ||
for metric in metrics: | ||
if not isinstance(metric, Metric): | ||
raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") | ||
name = metric.__class__.__name__ | ||
if name in self: | ||
raise ValueError(f"Encountered two metrics both named {name}") | ||
self[name] = metric | ||
else: | ||
raise ValueError("Unknown input to MetricCollection.") | ||
self.add_metrics(metrics, *additional_metrics) | ||
|
||
self.prefix = self._check_arg(prefix, 'prefix') | ||
self.postfix = self._check_arg(postfix, 'postfix') | ||
|
@@ -185,6 +147,51 @@ def persistent(self, mode: bool = True) -> None: | |
for _, m in self.items(): | ||
m.persistent(mode) | ||
|
||
def add_metrics(self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be |
||
*additional_metrics: Metric) -> None: | ||
"""Add new metrics to Metric Collection | ||
""" | ||
if isinstance(metrics, Metric): | ||
# set compatible with original type expectations | ||
metrics = [metrics] | ||
if isinstance(metrics, Sequence): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand this is just copying what we had here before, but we should probably have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but iterable does not have len, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are not using len though, we are just directly converting to a list in the next line. |
||
# prepare for optional additions | ||
metrics = list(metrics) | ||
remain = [] | ||
for m in additional_metrics: | ||
(metrics if isinstance(m, Metric) else remain).append(m) | ||
|
||
if remain: | ||
rank_zero_warn( | ||
f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." | ||
) | ||
elif additional_metrics: | ||
raise ValueError( | ||
f"You have passes extra arguments {additional_metrics} which are not compatible" | ||
f" with first passed dictionary {metrics} so they will be ignored." | ||
) | ||
|
||
if isinstance(metrics, dict): | ||
# Check all values are metrics | ||
# Make sure that metrics are added in deterministic order | ||
for name in sorted(metrics.keys()): | ||
metric = metrics[name] | ||
if not isinstance(metric, Metric): | ||
raise ValueError( | ||
f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" | ||
) | ||
self[name] = metric | ||
elif isinstance(metrics, Sequence): | ||
for metric in metrics: | ||
if not isinstance(metric, Metric): | ||
raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") | ||
name = metric.__class__.__name__ | ||
if name in self: | ||
raise ValueError(f"Encountered two metrics both named {name}") | ||
self[name] = metric | ||
else: | ||
raise ValueError("Unknown input to MetricCollection.") | ||
|
||
def _set_name(self, base: str) -> str: | ||
name = base if self.prefix is None else self.prefix + base | ||
name = name if self.postfix is None else name + self.postfix | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be a better name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine to keep the name different. If we say
append
it might imply that it won't work asextend
while here it does.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add_metrics
seems more appropriate for a Collection.