Skip to content

Commit

Permalink
MAP metric - fix empty predictions (#594)
Browse files Browse the repository at this point in the history
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored Oct 29, 2021
1 parent e44b51b commit 1c42f66
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594))


## [0.6.0] - 2021-10-28

### Added

- Added audio metrics:
- Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
- Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
- Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/pull/353))
- Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/pull/353))
- Added Information retrieval metrics:
- `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))
- `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577))
- `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576))
- Added NLP metrics:
- `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546))
Expand Down
17 changes: 17 additions & 0 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,23 @@ def test_error_on_wrong_init():
MAP(class_metrics=0)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_preds():
"""Test empty predictions."""

metric = MAP()

metric.update(
[
dict(boxes=torch.Tensor([[]]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
],
[
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def _get_coco_format(
annotations = []
annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong

boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") for box in boxes]
boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if boxes[0].size(1) == 4 else box for box in boxes]
for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
image_boxes = image_boxes.cpu().tolist()
image_labels = image_labels.cpu().tolist()
Expand Down

0 comments on commit 1c42f66

Please sign in to comment.