Integrate multiple Retrieval metrics together for efficiency when group number is large #889

Closed
ZeguanXiao opened this issue Mar 17, 2022 · 2 comments

ZeguanXiao commented Mar 17, 2022

🚀 Feature

A way to wrap multiple Retrieval metrics together in order to speed up metric computation.

Motivation

When I want to compute k metrics on the same (indexes, preds, target) data, TorchMetrics groups the indexes k times. When the number of groups is very large, this takes a very long time.

Pitch

A function or class that takes a list of metric names to achieve this.
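
A rough, hypothetical sketch of what such a wrapper could look like, built on the existing functional retrieval metrics; the multi_retrieval name and the grouping logic are illustrative only, not an existing TorchMetrics API:

import torch
from torchmetrics.functional import (
    retrieval_average_precision,
    retrieval_fall_out,
    retrieval_normalized_dcg,
)

# map user-facing metric names to functional implementations (illustrative subset)
METRIC_FNS = {
    "map": retrieval_average_precision,
    "fall_out": retrieval_fall_out,
    "ndcg": retrieval_normalized_dcg,
}

def multi_retrieval(metric_names, preds, target, indexes):
    """Compute several retrieval metrics while grouping the indexes only once."""
    results = {name: [] for name in metric_names}
    for group in indexes.unique():
        mask = indexes == group
        for name in metric_names:
            results[name].append(METRIC_FNS[name](preds[mask], target[mask]))
    # average the per-group scores, as the class-based retrieval metrics do by default
    return {name: torch.stack(scores).mean() for name, scores in results.items()}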

Alternatives

Additional context

@ZeguanXiao added the enhancement (New feature or request) label Mar 17, 2022
@github-actions

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @ZeguanXiao,
I am happy to report that this issue should already be solved on master. In PR #709 we introduced the concept of compute groups, which automatically groups together computations from metrics that share the same underlying metric state (like all retrieval metrics). If I run the following example, which uses 3 different retrieval metrics:

import time

import torch
from torchmetrics import MetricCollection, RetrievalFallOut, RetrievalMAP, RetrievalNormalizedDCG

# 100 predictions spread over (at most) 10 query groups, with binary relevance targets
indexes = torch.randint(10, (100,))
preds = torch.rand(100)
target = torch.randint(10, (100,)).bool()

N = 10000
for cg in [False, True]:
    # cg=False: each metric groups its own state separately ("old" behaviour)
    # cg=True: metrics sharing the same state are updated as one compute group
    m = MetricCollection(
        RetrievalFallOut(), RetrievalMAP(), RetrievalNormalizedDCG(), compute_groups=cg
    )

    start = time.time()
    for _ in range(N):
        m.update(preds, target, indexes=indexes)
    print(f"{'New' if cg else 'Old'} metric collection: {time.time() - start}")

I get

Old metric collection: 1.651695966720581
New metric collection: 0.5793600082397461

meaning that with compute groups enabled (which will be on by default) it is roughly 3 times faster, which corresponds to only one metric's state actually being updated.
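The resulting grouping can also be inspected. A minimal sketch, assuming a torchmetrics version where MetricCollection exposes the compute_groups property (the printed dict below is illustrative):

m = MetricCollection(
    RetrievalFallOut(), RetrievalMAP(), RetrievalNormalizedDCG(), compute_groups=True
)
m.update(preds, target, indexes=indexes)
# after the first update, metrics sharing the same state are merged into one group, e.g.
# {0: ['RetrievalFallOut', 'RetrievalMAP', 'RetrievalNormalizedDCG']}
print(m.compute_groups)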
Closing this issue.

@Borda added this to the v0.8 milestone May 5, 2022