Skip to content

Commit

Permalink
simple
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Apr 24, 2022
1 parent 6f64e79 commit 98c517a
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,9 @@ def __calculate_recall_precision_scores(
# different sorting method generates slightly different results.
# mergesort is used to be consistent as Matlab implementation.
# Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0)
if det_scores.is_cuda and det_scores.dtype is torch.bool:
# Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort
inds = torch.argsort(det_scores.to(torch.uint8), descending=True)
else:
inds = torch.argsort(det_scores, descending=True)
dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype
# Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort
inds = torch.argsort(det_scores.to(dtype), descending=True)
det_scores_sorted = det_scores[inds]

det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
Expand Down

0 comments on commit 98c517a

Please sign in to comment.