Vectorize _find_best_gt_match in MeanAveragePrecision (#1259)
* refactor: Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match`
* Put back removed condition
* Cast best_gt_matches to int
* Further vectorize (one missing)
* Partially vectorize gt_matches calc
* Fix corner cases with missing gt and pred
stancld authored Oct 12, 2022
1 parent 173e88e commit eacfd2f
Showing 2 changed files with 22 additions and 30 deletions.
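For orientation, the change replaces a per-detection Python loop with a single batched computation over all IoU thresholds and detections. Below is a minimal, self-contained sketch of the idea with made-up tensor names and values (it is not the library's code, and it omits the bookkeeping for ground truths already matched at a given threshold, which the real code tracks via `gt_matches`):

import torch

nb_iou_thrs, nb_det, nb_gt = 2, 3, 4
ious = torch.tensor([[0.80, 0.10, 0.30, 0.05],
                     [0.20, 0.60, 0.40, 0.10],
                     [0.05, 0.15, 0.55, 0.90]])         # IoU of every (detection, ground truth) pair
gt_ignore = torch.tensor([False, False, True, False])   # ground truths to skip (e.g. outside the area range)
thrs = torch.tensor([0.5, 0.75])                        # IoU thresholds

# Loop-style matching, roughly the shape of the old per-detection `_find_best_gt_match` call site.
best_loop = torch.full((nb_iou_thrs, nb_det), -1)
for t_idx, t in enumerate(thrs):
    for d_idx in range(nb_det):
        cand = ious[d_idx] * ~gt_ignore                 # zero out ignored ground truths
        m = int(cand.argmax())
        if cand[m] > t:
            best_loop[t_idx, d_idx] = m

# Vectorized matching in the spirit of the new `_find_best_gt_matches`: one max over the
# ground-truth axis, then one thresholded select per IoU threshold.
masked = ious.unsqueeze(0) * ~gt_ignore                 # shape [1, nb_det, nb_gt]
vals, idxs = masked.max(dim=-1)                         # best candidate ground truth per detection
best_vec = torch.where(vals > thrs[:, None], idxs, torch.tensor(-1))

assert torch.equal(best_loop, best_vec)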
CHANGELOG.md (3 additions, 0 deletions)
@@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))


- Changed minimum Pytorch version to be 1.8 ([#1263](https://github.com/Lightning-AI/metrics/pull/1263))


src/torchmetrics/detection/mean_ap.py (19 additions, 30 deletions)
@@ -579,10 +579,6 @@ def _evaluate_image(
        det = [det[i] for i in det_label_mask]
        if len(gt) == 0 and len(det) == 0:
            return None
        if isinstance(det, dict):
            det = [det]
        if isinstance(gt, dict):
            gt = [gt]

        areas = compute_area(gt, iou_type=self.iou_type).to(self.device)

@@ -604,23 +600,23 @@
        # load computed ious
        ious = ious[idx, class_id][:, gtind] if len(ious[idx, class_id]) > 0 else ious[idx, class_id]

        nb_iou_thrs = len(self.iou_thresholds)
        nb_gt = len(gt)
        nb_det = len(det)
        gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device)
        det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
        gt_ignore = ignore_area_sorted
        det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)

        iou_thresholds = torch.tensor(self.iou_thresholds, device=self.device)

        if torch.numel(ious) > 0:
            for idx_iou, t in enumerate(self.iou_thresholds):
                for idx_det, _ in enumerate(det):
                    m = MeanAveragePrecision._find_best_gt_match(t, gt_matches, idx_iou, gt_ignore, ious, idx_det)
                    if m == -1:
                        continue
                    det_ignore[idx_iou, idx_det] = gt_ignore[m]
                    det_matches[idx_iou, idx_det] = 1
                    gt_matches[idx_iou, m] = 1
            best_matches = self._find_best_gt_matches(iou_thresholds, gt_matches, gt_ignore, ious)
            _zero_tensor = torch.tensor(0, dtype=torch.bool, device=self.device)
            _one_tensor = torch.tensor(1, dtype=torch.bool, device=self.device)
            det_ignore = torch.where(
                best_matches != -1, gt_ignore[best_matches.clamp(max=gt_ignore.shape[0] - 1)], _zero_tensor
            )
            det_matches = torch.where(best_matches != -1, _one_tensor, _zero_tensor)
            for idx in range(nb_iou_thrs):
                gt_matches[idx, best_matches[idx].clamp(0, max=gt_matches.shape[1] - 1).unique()] = 1

        # set unmatched detections outside of area range to ignore
        det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
@@ -639,33 +635,26 @@ def _evaluate_image(
        }

    @staticmethod
    def _find_best_gt_match(
        thr: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
    ) -> int:
        """Return id of best ground truth match with current detection.
    def _find_best_gt_matches(thr: Tensor, gt_matches: Tensor, gt_ignore: Tensor, ious: Tensor) -> Tensor:
        """Return matrix of indices of best ground truth match with current detection.
        Args:
            thr:
                Current threshold value.
            gt_matches:
                Tensor showing if a ground truth matches for threshold ``t`` exists.
            idx_iou:
                Id of threshold ``t``.
            gt_ignore:
                Tensor showing if ground truth should be ignored.
            ious:
                IoUs for all combinations of detection and ground truth.
            idx_det:
                Id of current detection.
        """
        previously_matched = gt_matches[idx_iou]
        # Remove previously matched or ignored gts
        remove_mask = previously_matched | gt_ignore
        gt_ious = ious[idx_det] * ~remove_mask
        match_idx = gt_ious.argmax().item()
        if gt_ious[match_idx] > thr:
            return match_idx
        return -1
        remove_mask = gt_matches | gt_ignore
        gt_ious = torch.einsum("cw,dw->cdw", ~remove_mask, ious).max(-1).values
        best_gt_matches = gt_ious.where(
            gt_ious > thr.unsqueeze(-1), torch.tensor(-1, dtype=gt_ious.dtype, device=gt_ious.device)
        )
        return best_gt_matches.long()

    def _summarize(
        self,
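As a rough illustration of how the matrix returned by the vectorized matcher is consumed in `_evaluate_image` above, here is a small self-contained sketch with made-up values (the real code keeps every tensor on `self.device` and clamps the indices to guard the empty-ground-truth corner case):

import torch

nb_iou_thrs, nb_gt = 2, 4
# Best-matching ground-truth index per (IoU threshold, detection); -1 means "no match above threshold".
best_matches = torch.tensor([[1, -1, 0],
                             [-1, -1, 0]])
gt_ignore = torch.tensor([False, True, False, False])   # per-ground-truth ignore flags

matched = best_matches != -1
# A matched detection inherits the ignore flag of its ground truth; unmatched detections stay False here.
det_ignore = torch.where(matched, gt_ignore[best_matches.clamp(min=0)], torch.tensor(False))
det_matches = matched.clone()

# Mark every ground-truth index that received a match at each threshold.
gt_matches = torch.zeros(nb_iou_thrs, nb_gt, dtype=torch.bool)
for idx in range(nb_iou_thrs):
    gt_matches[idx, best_matches[idx][matched[idx]]] = True

print(det_matches)   # candidate true positives per threshold
print(det_ignore)    # detections to drop from the precision/recall bookkeeping
print(gt_matches)    # ground truths already taken per threshold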
