
True update() for MetricCollection #203

Closed
IgorHoholko opened this issue Apr 27, 2021 · 8 comments · Fixed by #221
Assignees
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@IgorHoholko
Contributor

🚀 Feature

The MetricCollection class overrides the update method of its parent nn.ModuleDict class. Would it be possible to also expose the parent's update on MetricCollection?

 def update_(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
        if isinstance(metrics, (list, tuple)):
            ...  # handle List[Metric] / Tuple[Metric]
        else:
            return super().update(metrics)

Example usage:

metrics = MetricCollection({metric_name: getattr(metrics_obj, metric_name)
                            for metric_name in self.metrics})

self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
self.test_metrics.update({metric_name: getattr(metrics_obj, metric_name)
                          for metric_name in ('auroc', 'auroc_macro', 'auroc_weighted')})
@IgorHoholko IgorHoholko added enhancement New feature or request help wanted Extra attention is needed labels Apr 27, 2021
@SkafteNicki
Member

You can use super to call the parent method. A basic example:

import torch
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, ConfusionMatrix
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metrics = MetricCollection([
    Accuracy(),
    Precision(num_classes=3, average='macro'),
    Recall(num_classes=3, average='macro')
])
print(metrics)
>>>MetricCollection(
>>>  (Accuracy): Accuracy()
>>>  (Precision): Precision()
>>>  (Recall): Recall()
>>>)
# call super
super(type(metrics), metrics).update({'confmat': ConfusionMatrix(num_classes=3)})
print(metrics)
>>>MetricCollection(
>>>  (Accuracy): Accuracy()
>>>  (Precision): Precision()
>>>  (Recall): Recall()
>>>  (confmat): ConfusionMatrix()
>>>)
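The mechanics of this workaround can be shown without torchmetrics at all: `super(Cls, obj).method(...)` resolves `method` on the parent class, bypassing the subclass override. A minimal sketch with a plain `dict` standing in for `nn.ModuleDict` (the `Collection` class here is illustrative, not the torchmetrics API):

```python
class Collection(dict):
    """A dict subclass that, like MetricCollection, repurposes update()."""

    def update(self, *args, **kwargs):
        # The subclass gives update() a new meaning (here a no-op placeholder),
        # shadowing dict.update, which normally inserts key/value pairs.
        pass


c = Collection(a=1)
c.update({"b": 2})                     # calls the override: nothing is inserted
assert "b" not in c

super(Collection, c).update({"b": 2})  # calls dict.update directly
assert c["b"] == 2
```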

@IgorHoholko
Contributor Author

@SkafteNicki Makes sense. Thanks!

@SkafteNicki
Member

@IgorHoholko you are welcome. Closing the issue. Feel free to re-open if necessary :]

@maximsch2
Contributor

Should we do something like

def add_metrics(self, metrics):
   super().update(metrics)

?
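The proposal amounts to a thin, explicitly named wrapper around the parent's update, so callers never need the `super(type(metrics), metrics)` incantation. A minimal sketch of the idea, again with a plain `dict` standing in for `nn.ModuleDict` (class and metric names are illustrative, not the torchmetrics API):

```python
class Collection(dict):
    """Stand-in for MetricCollection: update() has metric semantics."""

    def update(self, *args, **kwargs):
        pass  # metric-state update, not key insertion

    def add_metrics(self, metrics):
        # Delegate to the parent's update() to insert new entries,
        # keeping the metric-semantics update() untouched.
        super().update(metrics)


c = Collection()
c.add_metrics({"accuracy": "acc-metric"})
assert "accuracy" in c
```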

@IgorHoholko
Contributor Author

@maximsch2 It would definitely be helpful.

@IgorHoholko
Contributor Author

IgorHoholko commented Apr 27, 2021

metrics.add_metrics() looks nicer and is more intuitive than super(type(metrics), metrics).update() :)

@SkafteNicki
Member

fine by me :]
@IgorHoholko want to send a PR?

@SkafteNicki SkafteNicki reopened this Apr 27, 2021
@IgorHoholko
Contributor Author

@SkafteNicki Yes, a bit later.
