🚀 Feature
A `prefix` and `postfix` parameter to change the metric names of the `ClasswiseWrapper` class.
Motivation
I came up with this feature while implementing multiple classification metrics in my LightningModule, which I then want to log to Weights & Biases via `log_dict`.
As of right now, this results in metric names like "val/multiclassaccuracy_0". However, for aesthetic and sorting/grouping reasons, I would like the metrics to be named "val/accuracy/class_0" or similar. I know I could log each item manually, but that would require so much boilerplate that the reason to use torchmetrics would be gone.
Pitch
Currently, the output names of the wrapper are derived from the class name of the wrapped metric.
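A minimal pure-Python sketch of that naming scheme, inferred from the keys in the error message below (`multiclassaccuracy_0`, `multiclassaccuracy_1`); it is not the torchmetrics source. The helper `classwise_names_current` is hypothetical, and the assumption that `labels` replace the class index is labeled in the code.

```python
from typing import List, Optional, Sequence


def classwise_names_current(
    metric_class_name: str,
    num_classes: int,
    labels: Optional[Sequence[str]] = None,
) -> List[str]:
    """Hypothetical helper mimicking the current scheme: '<lowercased class name>_<index or label>'."""
    base = metric_class_name.lower()
    # Assumption: when labels are given, they replace the numeric class index.
    suffixes = labels if labels is not None else range(num_classes)
    return [f"{base}_{suffix}" for suffix in suffixes]


print(classwise_names_current("MulticlassAccuracy", 2))
# ['multiclassaccuracy_0', 'multiclassaccuracy_1']
```

The class name is the only naming input here, which is why every wrapped `MulticlassAccuracy` ends up under the same `multiclassaccuracy_` stem.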
With a `prefix` (and/or a `postfix`) parameter, it would be possible to customize the output names of the metric, which is relevant e.g. for logging via `log_dict` in a LightningModule.
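A sketch of how the proposed parameters could build the keys, assuming each key is simply `prefix + label + postfix` and that omitting both keeps today's `<classname>_` names for backward compatibility. `classwise_names_proposed` and its parameters are hypothetical, and plain floats stand in for tensors so the example runs without torch.

```python
from typing import Dict, Optional, Sequence


def classwise_names_proposed(
    values: Sequence[float],
    metric_class_name: str,
    prefix: Optional[str] = None,   # hypothetical new parameter
    postfix: Optional[str] = None,  # hypothetical new parameter
    labels: Optional[Sequence[str]] = None,
) -> Dict[str, float]:
    """Map per-class values to keys of the form prefix + label + postfix."""
    if prefix is None and postfix is None:
        # Neither given: fall back to the current "<classname>_" stem.
        prefix, postfix = f"{metric_class_name.lower()}_", ""
    prefix = prefix or ""
    postfix = postfix or ""
    suffixes = list(labels) if labels is not None else [str(i) for i in range(len(values))]
    if len(suffixes) != len(values):
        raise ValueError("expected exactly one label per class")
    return {f"{prefix}{suffix}{postfix}": value for suffix, value in zip(suffixes, values)}


print(classwise_names_proposed([0.6, 0.6], "MulticlassAccuracy", prefix="accuracy/class_"))
# {'accuracy/class_0': 0.6, 'accuracy/class_1': 0.6}
```

Keeping the old names as the default means existing dashboards would not change unless a user opts in; with a `val/` prefix added on top (for example by the logging code), the keys would read `val/accuracy/class_0`.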
Alternatives
A `name` parameter that would replace `self.metric.__class__.__name__` if set. This would be easier to implement but would not give users as much control over the names as they may need.
Use the dict key as the name when the wrapper is passed to a `MetricCollection`. This could be an additional feature. It needs to be discussed which should take priority: the `prefix`/`postfix` parameters or the key name in the `MetricCollection` dict.
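To make the open priority question concrete, here is a sketch of one possible resolution order for the key stem: an explicit `prefix` first, then the `MetricCollection` dict key, then the lower-cased class name. Both the function `resolve_prefix` and this ordering are assumptions for illustration, not a decided design.

```python
from typing import Optional


def resolve_prefix(
    metric_class_name: str,
    prefix: Optional[str] = None,          # proposed explicit parameter (hypothetical)
    collection_key: Optional[str] = None,  # key of the wrapper in a MetricCollection dict
) -> str:
    # Assumed priority: explicit prefix > dict key > class name.
    if prefix is not None:
        return prefix
    if collection_key is not None:
        return f"{collection_key}_"
    return f"{metric_class_name.lower()}_"


for kwargs in ({}, {"collection_key": "accuracy"}, {"prefix": "acc/", "collection_key": "accuracy"}):
    print(resolve_prefix("MulticlassAccuracy", **kwargs) + "0")
# multiclassaccuracy_0
# accuracy_0
# acc/0
```

Letting the explicit parameter win means a user can always override whatever key the collection happens to use.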
Important Note
I think Lightning currently can't handle `ClasswiseWrapper` when logging via `self.log_dict(self.metrics, on_epoch=True)`, since it raises the following error: `ValueError: The ".compute()" return of the metric logged as 'val/accuracy' must be a tensor. Found {'multiclassaccuracy_0': tensor(0.6000), 'multiclassaccuracy_1': tensor(0.6000)}`
The reason for this is probably that the `ClasswiseWrapper` counts as a single item when `MetricCollection.items()` is called:
```text
/home/tobias/Repositories/eSPA-Python/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:474 in metrics

  471
  472     for _, result_metric in self.valid_items():
  473         # extract forward_cache or computed from the _ResultMetric
❱ 474         value = self._get_cache(result_metric, on_step)
  475         if not isinstance(value, Tensor):
  476             continue
  477

/home/tobias/Repositories/eSPA-Python/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:445 in _get_cache

  442
  443     if cache is not None:
  444         if not isinstance(cache, Tensor):
❱ 445             raise ValueError(
  446                 f"The `.compute()` return of the metric logged as {result_metric.met
  447                 f" Found {cache}"
  448             )
```
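The error message shows that the wrapper's `.compute()` returned a dict of per-class tensors where Lightning expects a single tensor per logged key. Until names can be configured, one possible workaround (a sketch, not tested against Lightning) is to compute the values yourself and flatten nested dicts into slash-separated keys before logging. `flatten_metrics` is a hypothetical helper, and plain floats stand in for tensors.

```python
from collections import abc
from typing import Any, Dict, Mapping


def flatten_metrics(results: Mapping[str, Any], sep: str = "/") -> Dict[str, Any]:
    """Flatten nested metric outputs so that every key maps to a single value."""
    flat: Dict[str, Any] = {}
    for key, value in results.items():
        if isinstance(value, abc.Mapping):
            # Recurse into nested dicts and join the keys with the separator.
            for sub_key, sub_value in flatten_metrics(value, sep).items():
                flat[f"{key}{sep}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


print(flatten_metrics({"val/accuracy": {"class_0": 0.6, "class_1": 0.6}, "val/loss": 0.3}))
# {'val/accuracy/class_0': 0.6, 'val/accuracy/class_1': 0.6, 'val/loss': 0.3}
```

In a LightningModule this could be used at epoch end, e.g. `self.log_dict(flatten_metrics({"val/accuracy": wrapper.compute()}))` inside `on_validation_epoch_end` followed by `wrapper.reset()`, where `wrapper` stands for the `ClasswiseWrapper` instance. Because plain values rather than Metric objects are logged, Lightning would no longer compute and reset the metric automatically, which is exactly the boilerplate the proposed parameters aim to avoid.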
Additional context
I already opened a PR for this: #1866