Skip to content

Commit

Permalink
clean API interface; working on tests for iou_type SEGM
Browse files Browse the repository at this point in the history
  • Loading branch information
gianscarpe committed Mar 17, 2022
1 parent 35fee46 commit c470675
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 38 deletions.
20 changes: 20 additions & 0 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,26 @@ def test_empty_metric():
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_segm_iou_empty_mask():
"""Test empty ground truths."""
metric = MeanAveragePrecision(iou_type="segm")

metric.update(
[
dict(
masks=torch.randint(0, 1, (1, 10, 10)),
scores=torch.Tensor([0.5]),
labels=torch.IntTensor([4]),
),
],
[
dict(masks=torch.Tensor([]), labels=torch.IntTensor([])),
],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""
Expand Down
60 changes: 22 additions & 38 deletions torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,11 @@ class MeanAveragePrecision(Metric):
If ``class_metrics`` is not a boolean
"""

detection_boxes: List[Tensor]
detections: List[Tensor]
detection_scores: List[Tensor]
detection_labels: List[Tensor]
groundtruth_boxes: List[Tensor]
groundtruths: List[Tensor]
groundtruth_labels: List[Tensor]
groundtruth_masks: List[Tensor]
detection_masks: List[Tensor]

def __init__(
self,
Expand Down Expand Up @@ -303,13 +301,11 @@ def __init__(
raise ValueError("Expected argument `class_metrics` to be a boolean")

self.class_metrics = class_metrics
self.add_state("detection_boxes", default=[], dist_reduce_fx=None)
self.add_state("detections", default=[], dist_reduce_fx=None)
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
self.add_state("detection_labels", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None)
self.add_state("groundtruths", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
self.add_state("detection_masks", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_masks", default=[], dist_reduce_fx=None)

def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore
"""Add detections and ground truth to the metric.
Expand Down Expand Up @@ -354,32 +350,29 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
ValueError:
If any score is not type float and of length 1
"""
_input_validator(preds, target)
_input_validator(preds, target, iou_type=self.iou_type)

for item in preds:
boxes, masks = self._get_safe_item_values(item)
self.detection_boxes.append(boxes)
detections = self._get_safe_item_values(item)
self.detections.append(detections)
self.detection_labels.append(item["labels"])
self.detection_scores.append(item["scores"])
self.detection_masks.append(masks)

for item in target:
boxes, masks = self._get_safe_item_values(item)
self.groundtruth_boxes.append(boxes)
groundtruths = self._get_safe_item_values(item)
self.groundtruths.append(groundtruths)
self.groundtruth_labels.append(item["labels"])
self.groundtruth_masks.append(masks)

def _get_safe_item_values(self, item):
if self.iou_type == "bbox":
boxes = _fix_empty_tensors(item["boxes"])
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
masks = _fix_empty_tensors(torch.Tensor())
elif self.iou_type == "masks":
return boxes
elif self.iou_type == "segm":
masks = _fix_empty_tensors(item["masks"])
boxes = _fix_empty_tensors(torch.Tensor())
return masks
else:
raise Exception(f"IOU type {self.iou_type} is not supported")
return boxes, masks

def _get_classes(self) -> List:
"""Returns a list of unique classes found in ground truth and detection data."""
Expand All @@ -388,18 +381,11 @@ def _get_classes(self) -> List:
return []

def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
return self._compute_iou_impl(id, self.groundtruth_boxes, self.detection_boxes, class_id, max_det, box_iou)
iou_func = box_iou if self.iou_type == "bbox" else segm_iou

# if self.iou_type == "segm":
# return self._compute_iou_impl(id, self.groundtruth_masks, self.detection_masks, class_id, max_det, segm_iou)
# elif self.iou_type == "bbox":
return self._compute_iou_impl(id, class_id, max_det, iou_func)

# else:
# raise Exception(f"IOU type {self.iou_type} is not supported")

def _compute_iou_impl(
self, id: int, ground_truths, detections, class_id: int, max_det: int, compute_iou: Callable
) -> Tensor:
def _compute_iou_impl(self, id: int, class_id: int, max_det: int, compute_iou: Callable) -> Tensor:
"""Computes the Intersection over Union (IoU) for ground truth and detection bounding boxes for the given
image and class.
Expand All @@ -412,8 +398,8 @@ def _compute_iou_impl(
Maximum number of evaluated detection bounding boxes
"""
# if self.iou_type == "bbox":
gt = self.groundtruth_boxes[id]
det = self.detection_boxes[id]
gt = self.groundtruths[id]
det = self.detections[id]

gt_label_mask = self.groundtruth_labels[id] == class_id
det_label_mask = self.detection_labels[id] == class_id
Expand Down Expand Up @@ -452,8 +438,8 @@ def _evaluate_image(
ious:
IoU results for image and class.
"""
gt = self.groundtruth_boxes[id]
det = self.detection_boxes[id]
gt = self.groundtruths[id]
det = self.detections[id]
gt_label_mask = self.groundtruth_labels[id] == class_id
det_label_mask = self.detection_labels[id] == class_id
if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
Expand Down Expand Up @@ -595,7 +581,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
class_ids:
List of label class Ids.
"""
img_ids = range(len(self.groundtruth_boxes))
img_ids = range(len(self.groundtruths))
max_detections = self.max_detection_thresholds[-1]
area_ranges = self.bbox_area_ranges.values()

Expand Down Expand Up @@ -766,13 +752,11 @@ def compute(self) -> dict:
"""

# move everything to CPU, as we are faster here
self.detection_boxes = [box.cpu() for box in self.detection_boxes]
self.detections = [box.cpu() for box in self.detections]
self.detection_labels = [label.cpu() for label in self.detection_labels]
self.detection_scores = [score.cpu() for score in self.detection_scores]
self.groundtruth_boxes = [box.cpu() for box in self.groundtruth_boxes]
self.groundtruths = [box.cpu() for box in self.groundtruths]
self.groundtruth_labels = [label.cpu() for label in self.groundtruth_labels]
self.groundtruth_masks = [box.cpu() for box in self.groundtruth_masks]
self.detection_masks = [label.cpu() for label in self.detection_masks]

classes = self._get_classes()
precisions, recalls = self._calculate(classes)
Expand Down

0 comments on commit c470675

Please sign in to comment.