Binned implementation of PrecisionRecallCurve is extremely slow for large batch sizes #1492

Closed · Callidior opened this issue Feb 9, 2023 · 1 comment · Fixed by #1493
Labels: bug / fix, help wanted, Priority Critical

Callidior (Contributor):
🐛 Bug

Since the classification metrics refactor in torchmetrics 0.10, the update step of PrecisionRecallCurve and metrics based on it (such as AveragePrecision) is extremely slow and memory-consuming when using the binned implementation with many samples. This is, for example, the case in semantic segmentation scenarios, where every pixel is a sample.

The binned implementation is intended to be both faster and more memory-efficient than computing the exact precision-recall curve. This is currently not the case.

To Reproduce

import torch
from torchmetrics.classification import BinaryAveragePrecision

targets = torch.randint(0, 2, (16, 512, 512), device="cuda")
preds = torch.rand((16, 512, 512), device="cuda")

metric_exact = BinaryAveragePrecision().cuda()
metric_binned = BinaryAveragePrecision(thresholds=200).cuda()

with torch.inference_mode():
    start_exact_evt = torch.cuda.Event(enable_timing=True)
    update_exact_evt = torch.cuda.Event(enable_timing=True)
    compute_exact_evt = torch.cuda.Event(enable_timing=True)
    start_binned_evt = torch.cuda.Event(enable_timing=True)
    update_binned_evt = torch.cuda.Event(enable_timing=True)
    compute_binned_evt = torch.cuda.Event(enable_timing=True)
    
    start_exact_evt.record()
    metric_exact.update(preds, targets)
    update_exact_evt.record()
    ap_exact = metric_exact.compute()
    compute_exact_evt.record()
    
    start_binned_evt.record()
    metric_binned.update(preds, targets)
    update_binned_evt.record()
    ap_binned = metric_binned.compute()
    compute_binned_evt.record()

torch.cuda.synchronize()

print("AP (exact):", ap_exact)
print("AP (binned):", ap_binned)
print("Time (exact): {} ms + {} ms".format(
    start_exact_evt.elapsed_time(update_exact_evt),
    update_exact_evt.elapsed_time(compute_exact_evt)
))
print("Time (binned): {} ms + {} ms".format(
    start_binned_evt.elapsed_time(update_binned_evt),
    update_binned_evt.elapsed_time(compute_binned_evt)
))

Output (using an RTX 3090):

AP (exact): tensor(0.5008, device='cuda:0')
AP (binned): tensor(0.5008, device='cuda:0')
Time (exact): 1.8867520093917847 ms + 4.845024108886719 ms
Time (binned): 10816.78515625 ms + 0.002047999994829297 ms

Binned average precision was 1600x slower than exact average precision.
The memory consumption of exact AP on the GPU was 260 MB, while binned AP consumed 20,000 MB.
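
As a side note, the peak-memory figures can be checked via the CUDA allocator statistics; a minimal sketch, reusing preds, targets, and metric_binned from above (exact numbers vary with PyTorch version and allocator state):

torch.cuda.reset_peak_memory_stats()
with torch.inference_mode():
    metric_binned.update(preds, targets)
torch.cuda.synchronize()
print("Peak GPU memory during binned update: {:.0f} MB".format(
    torch.cuda.max_memory_allocated() / 2**20
))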

Expected behavior

The binned implementation of PrecisionRecallCurve, i.e., with thresholds != None, should be faster and more memory-efficient.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.11.1 (pip)
  • Python & PyTorch Version (e.g., 1.0): Python 3.7.9, PyTorch 1.12.1
  • Any other relevant information such as OS (e.g., Linux): Linux

Additional context

The current implementation of _binary_precision_recall_curve_update compares all predictions with all thresholds at once, which is very memory-consuming:

len_t = len(thresholds)
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()  # num_samples x num_thresholds
unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t)
return bins.reshape(len_t, 2, 2)

Computing preds_t and unique_mapping is quite fast, but both consume large amounts of memory (number of samples × number of thresholds × 8 bytes each, since they are int64 tensors). _bincount, on the other hand, does not consume much memory but takes a lot of time.
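
For the reproduction above (16 × 512 × 512 samples and 200 thresholds), a back-of-the-envelope estimate already puts each of these intermediates in the gigabyte range, which is roughly consistent with the ~20,000 MB observed:

num_samples = 16 * 512 * 512      # ~4.2 million samples, one per pixel
num_thresholds = 200
bytes_per_element = 8             # int64
size_gib = num_samples * num_thresholds * bytes_per_element / 2**30
print(f"{size_gib:.1f} GiB per intermediate tensor")  # ~6.2 GiB; preds_t and unique_mapping each need this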

Iterating over the thresholds one by one is significantly faster (0.6 sec instead of 10 sec) and consumes only 150 MB of memory:

# One 2x2 confusion matrix per threshold; looping avoids the huge
# (num_samples x num_thresholds) intermediate tensors.
confmat = thresholds.new_empty((len(thresholds), 2, 2), dtype=torch.int64)
for i in range(len(thresholds)):
    preds_t = (preds >= thresholds[i]).long()
    unique_mapping = preds_t + 2 * targets  # 0=TN, 1=FP, 2=FN, 3=TP
    bins = torch.bincount(unique_mapping, minlength=4)
    confmat[i] = bins.reshape(2, 2)

Using the implementation of BinnedPrecisionRecallCurve from torchmetrics 0.9 is even faster (0.1 sec) and uses only 60 MB:

confmat = thresholds.new_empty((len(thresholds), 2, 2), dtype=torch.int64)
targets_t = targets == 1
for i in range(len(thresholds)):
    preds_t = preds >= thresholds[i]
    confmat[i, 1, 1] = (targets_t & preds_t).sum()     # true positives
    confmat[i, 0, 1] = ((~targets_t) & preds_t).sum()  # false positives
    confmat[i, 1, 0] = (targets_t & (~preds_t)).sum()  # false negatives
    confmat[i, 0, 0] = len(preds_t) - confmat[i, 0, 1] - confmat[i, 1, 0] - confmat[i, 1, 1]  # true negatives
Callidior added the bug / fix and help wanted labels on Feb 9, 2023

Callidior (Contributor, Author):

I conducted some further experiments with respect to the number of samples, comparing the implementation of the BinaryAveragePrecision update step from torchmetrics 0.11 with that of torchmetrics 0.9:

#Samples     0.11 update time    0.9 update time    0.11 memory    0.9 memory
100          0.2 ms              23.8 ms             < 1 MB         < 1 MB
1,000        1.1 ms              24.4 ms             20 MB          < 1 MB
25,000       24.0 ms             25.8 ms             200 MB         < 1 MB
100,000      115.3 ms            25.5 ms             722 MB         2 MB
1,000,000    1,012.1 ms          26.7 ms             7,156 MB       2 MB

For small numbers of samples, the new implementation is in fact faster, but never more memory-efficient.
However, its time consumption scales linearly with the number of samples, while the old implementation's runtime is almost constant with respect to the number of samples (it scales with the number of thresholds instead; I used 200 here). This makes it difficult to switch between implementations depending on the number of samples.
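
Purely for illustration, a size-based dispatch between the two update strategies could look roughly like the sketch below. The 50,000-sample cutoff is an arbitrary assumption (the real crossover depends on both the number of samples and the number of thresholds), the inputs are assumed to be flattened 1-D tensors with targets in {0, 1}, and this is not how torchmetrics is actually implemented:

import torch

def binned_confmat(preds, target, thresholds, sample_cutoff=50_000):
    # Illustrative only: pick an update strategy based on the number of samples.
    len_t = len(thresholds)
    if preds.numel() <= sample_cutoff:
        # 0.11-style: one big broadcast; fast for few samples, but needs
        # O(num_samples * num_thresholds) memory for the intermediates.
        preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()
        unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
        return torch.bincount(unique_mapping.flatten(), minlength=4 * len_t).reshape(len_t, 2, 2)
    # Loop-style: one small bincount per threshold; near-constant memory,
    # runtime scales with the number of thresholds instead.
    confmat = thresholds.new_empty((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        bins = torch.bincount((preds >= thresholds[i]).long() + 2 * target, minlength=4)
        confmat[i] = bins.reshape(2, 2)
    return confmat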

I personally would be in favor of switching back to the old implementation. It is not the fastest in all scenarios, but it is the most memory-efficient and does not get terribly slow for large numbers of samples.
