Skip to content

Commit

Permalink
Fix MAP metric for empty cases (#624)
Browse files Browse the repository at this point in the history
* add detection map code example

* update setup

* simplify named tuple

* Update tm_examples/detection_map.py

* Update tm_examples/detection_map.py

* Update tm_examples/detection_map.py

* Update tm_examples/detection_map.py

* Update tm_examples/detection_map.py

* Update tm_examples/detection_map.py

* add some more comments

* add some more comments

* add example hint in metric docstring

* fix evaluation for empty metric

* Update torchmetrics/detection/map.py

* fix deepsource stuff

* update changelog

* fix ddp issue in multi GPU setup for empty boxes

* update doc

* simplify empty tensors fix

* fix mypy

* fix failing unittests

Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 3aef7be)
  • Loading branch information
tkupek authored and Borda committed Dec 5, 2021
1 parent 224a6d7 commit c66083d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594))
- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594), [#624](https://github.com/PyTorchLightning/metrics/pull/624))


- Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606))
Expand Down
20 changes: 16 additions & 4 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class TestMAP(MetricTester):
@pytest.mark.parametrize("ddp", [False, True])
def test_map(self, ddp):
"""Test modular implementation for correctness."""

self.run_class_metric_test(
ddp=ddp,
preds=_inputs.preds,
Expand All @@ -198,7 +197,6 @@ def test_map(self, ddp):
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""

MAP() # no error

with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"):
Expand All @@ -208,7 +206,6 @@ def test_error_on_wrong_init():
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_preds():
"""Test empty predictions."""

metric = MAP()

metric.update(
Expand All @@ -219,13 +216,28 @@ def test_empty_preds():
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)

metric.update(
[
dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
],
[
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_metric():
"""Test empty metric."""
metric = MAP()
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""

metric = MAP()

metric.update([], []) # no error
Expand Down
31 changes: 21 additions & 10 deletions torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore

def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> None:
"""Ensure the correct input format of `preds` and `targets`"""

if not isinstance(preds, Sequence):
raise ValueError("Expected argument `preds` to be of type List")
if not isinstance(targets, Sequence):
Expand Down Expand Up @@ -139,6 +138,13 @@ def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[st
)


def _fix_empty_tensors(boxes: torch.Tensor) -> torch.Tensor:
"""Empty tensors can cause problems in DDP mode, this methods corrects them."""
if boxes.numel() == 0 and boxes.ndim == 1:
return boxes.unsqueeze(0)
return boxes


class MAP(Metric):
r"""
Computes the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)\
Expand Down Expand Up @@ -273,12 +279,12 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
_input_validator(preds, target)

for item in preds:
self.detection_boxes.append(item["boxes"])
self.detection_boxes.append(_fix_empty_tensors(item["boxes"]))
self.detection_scores.append(item["scores"])
self.detection_labels.append(item["labels"])

for item in target:
self.groundtruth_boxes.append(item["boxes"])
self.groundtruth_boxes.append(_fix_empty_tensors(item["boxes"]))
self.groundtruth_labels.append(item["labels"])

def compute(self) -> dict:
Expand Down Expand Up @@ -325,7 +331,7 @@ def compute(self) -> dict:
if self.class_metrics:
map_per_class_list = []
mar_100_per_class_list = []
for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist():
for class_id in self._get_classes():
coco_eval.params.catIds = [class_id]
with _hide_prints():
coco_eval.evaluate()
Expand Down Expand Up @@ -363,12 +369,14 @@ def _get_coco_format(
Format is defined at https://cocodataset.org/#format-data
"""

images = []
annotations = []
annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong

boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if box.size(1) == 4 else box for box in boxes]
boxes = [
box_convert(box, in_fmt="xyxy", out_fmt="xywh") if box.ndim > 1 and box.size(1) == 4 else box
for box in boxes
]
for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
image_boxes = image_boxes.cpu().tolist()
image_labels = image_labels.cpu().tolist()
Expand Down Expand Up @@ -405,8 +413,11 @@ def _get_coco_format(
annotations.append(annotation)
annotation_id += 1

classes = [
{"id": i, "name": str(i)}
for i in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
]
classes = [{"id": i, "name": str(i)} for i in self._get_classes()]
return {"images": images, "annotations": annotations, "categories": classes}

def _get_classes(self) -> list:
"""Get list of unique classes depending on groundtruth_labels and detection_labels."""
if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0:
return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
return []

0 comments on commit c66083d

Please sign in to comment.