
Implement pre and postfix for Classwise Wrapper #1867

Closed
relativityhd opened this issue Jun 29, 2023 · 1 comment · Fixed by #1866
Labels
enhancement New feature or request

Comments

@relativityhd
Contributor

🚀 Feature

Add `prefix` and `postfix` parameters to customize the metric names produced by the `ClasswiseWrapper` class.

Motivation

I came up with this feature while implementing multiple classification metrics for my Lightning module, which I then want to log to Weights & Biases via `log_dict`:

In the `__init__` method of my LightningModule:

def __init__(self, ...):
    ...
    macro_metrics = MetricCollection(
        {
            "accuracy": MulticlassAccuracy(num_classes=M, average="macro"),
            "precision": MulticlassPrecision(num_classes=M, average="macro"),
            "recall": MulticlassRecall(num_classes=M, average="macro"),
            "f1": MulticlassF1Score(num_classes=M, average="macro"),
            "jaccard": MulticlassJaccardIndex(num_classes=M, average="macro"),
        },
        prefix="macro-",
    )

    micro_metrics = MetricCollection(
        {
            "accuracy": MulticlassAccuracy(num_classes=M, average="micro"),
            "precision": MulticlassPrecision(num_classes=M, average="micro"),
            "recall": MulticlassRecall(num_classes=M, average="micro"),
            "f1": MulticlassF1Score(num_classes=M, average="micro"),
            "jaccard": MulticlassJaccardIndex(num_classes=M, average="micro"),
        },
        prefix="micro-",
    )

    classwise_metrics = MetricCollection(
        {
            "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=M, average=None), labels=self.labels),
            "precision": ClasswiseWrapper(MulticlassPrecision(num_classes=M, average=None), labels=self.labels),
            "recall": ClasswiseWrapper(MulticlassRecall(num_classes=M, average=None), labels=self.labels),
            "f1": ClasswiseWrapper(MulticlassF1Score(num_classes=M, average=None), labels=self.labels),
            "jaccard": ClasswiseWrapper(MulticlassJaccardIndex(num_classes=M, average=None), labels=self.labels),
        }
    )
    self.metrics = MetricCollection([macro_metrics, micro_metrics, classwise_metrics], prefix="val/")

In the validation step:

def validation_step(self, ...):
    ...
    self.metrics.update(preds, target)
    self.log_dict(self.metrics, on_epoch=True)

As of right now, this results in metric names like "val/multiclassaccuracy_0". However, for aesthetic and sorting/grouping reasons I would like the metrics named "val/accuracy/class_0" or similar. I know I could log each item manually, but that would require so much boilerplate that the reason to use torchmetrics would be gone.
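Until such parameters exist, one hypothetical workaround (the helper name is my own, not a torchmetrics API) is to compute the collection and rename the keys by hand before logging:

```python
def rename_classwise_keys(results: dict, mapping: dict) -> dict:
    """Rename computed metric keys before logging, e.g.
    'val/multiclassaccuracy_0' -> 'val/accuracy/class_0'.

    `mapping` maps old key substrings to their replacements.
    """
    renamed = {}
    for key, value in results.items():
        for old, new in mapping.items():
            key = key.replace(old, new)
        renamed[key] = value
    return renamed
```

Inside the LightningModule one could then call `self.log_dict(rename_classwise_keys(self.metrics.compute(), {"multiclassaccuracy_": "accuracy/class_"}))`, but exactly this kind of boilerplate is what the proposed parameters would avoid.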

Pitch

Currently, the output keys of the wrapper are derived from the class name of the wrapped metric:

>>> import torch
>>> from torchmetrics import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...    MulticlassAccuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.3333),
'multiclassaccuracy_fish': tensor(0.6667),
'multiclassaccuracy_dog': tensor(0.)}

With a prefix (and/or a postfix) it would be possible to customize the output name of the metric, which would be relevant for e.g. logging via log_dict of a Lightning Module:

>>> import torch
>>> from torchmetrics import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...    MulticlassAccuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"],
...    prefix="accuracy-"
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'accuracy-horse': tensor(0.3333),
'accuracy-fish': tensor(0.6667),
'accuracy-dog': tensor(0.)}

Alternatives

  1. A `name` parameter which would replace `self.metric.__class__.__name__` if set. This would be an easier implementation but wouldn't give the user the amount of customization they may need.
  2. Use the dict key as the name when the wrapper is passed to a `MetricCollection`. This could be an additional feature. It would need to be discussed which should take priority: the prefix/postfix parameters or the key name in the `MetricCollection` dict.

Additional context

I already opened a PR for this: #1866


Important Note

I think that Lightning currently can't handle `ClasswiseWrapper` when logging via `self.log_dict(self.metrics, on_epoch=True)`, since it raises an error during argument validation:

ValueError: The ".compute()" return of the metric logged as 'val/accuracy' must be a tensor. Found {'multiclassaccuracy_0': tensor(0.6000), 'multiclassaccuracy_1': tensor(0.6000)}

The reason is probably that the `ClasswiseWrapper` counts as a single item when `MetricCollection.items()` is called:

│ /home/tobias/Repositories/eSPA-Python/.venv/lib/python3.10/site-packages/lightning/pytorch/train │
│ er/connectors/logger_connector/result.py:474 in metrics                                          │
│                                                                                                  │
│   471 │   │                                                                                      │
│   472 │   │   for _, result_metric in self.valid_items():                                        │
│   473 │   │   │   # extract forward_cache or computed from the _ResultMetric                     │
│ ❱ 474 │   │   │   value = self._get_cache(result_metric, on_step)                                │
│   475 │   │   │   if not isinstance(value, Tensor):                                              │
│   476 │   │   │   │   continue                                                                   │
│   477                                                                                            │

│ /home/tobias/Repositories/eSPA-Python/.venv/lib/python3.10/site-packages/lightning/pytorch/train │
│ er/connectors/logger_connector/result.py:445 in _get_cache                                       │
│                                                                                                  │
│   442 │   │                                                                                      │
│   443 │   │   if cache is not None:                                                              │
│   444 │   │   │   if not isinstance(cache, Tensor):                                              │
│ ❱ 445 │   │   │   │   raise ValueError(                                                          │
│   446 │   │   │   │   │   f"The `.compute()` return of the metric logged as {result_metric.met   │
│   447 │   │   │   │   │   f" Found {cache}"                                                      │
│   448 │   │   │   │   )                                                                          │
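Until this is handled on the Lightning side, a possible workaround (a sketch under the assumption that the classwise results arrive as a nested dict, not an official API) is to flatten the computed results yourself so `log_dict` only ever sees scalar values:

```python
def flatten_results(results: dict) -> dict:
    """Flatten one level of nesting, e.g. the per-class dict returned by a
    ClasswiseWrapper inside a MetricCollection, so that every value passed
    to Lightning's log_dict is a scalar it can validate."""
    flat = {}
    for key, value in results.items():
        if isinstance(value, dict):  # classwise results: merge the sub-keys
            flat.update(value)
        else:
            flat[key] = value
    return flat
```

In `validation_step` one would then only call `self.metrics.update(preds, target)`, and in `on_validation_epoch_end` log `self.log_dict(flatten_results(self.metrics.compute()))` followed by `self.metrics.reset()`.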
@relativityhd relativityhd added the enhancement New feature or request label Jun 29, 2023
@github-actions
Hi! thanks for your contribution!, great first issue!
