Skip to content

Commit

Permalink
MAP: change bool to float32 (#1150)
Browse files Browse the repository at this point in the history
* Change bool to float32.

* add test case

* changelog

Co-authored-by: André Aquilina <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
(cherry picked from commit d3c4c82)
  • Loading branch information
dreaquil authored and Borda committed Oct 21, 2022
1 parent d84736b commit e491f47
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed unintentional downloading of `nltk.punkt` when `lsum` not in `rouge_keys` ([#1258](https://github.com/Lightning-AI/metrics/pull/1258))


- Fixed type casting in `MAP` metric between `bool` and `float32` ([#1150](https://github.com/Lightning-AI/metrics/pull/1150))


## [0.10.0] - 2022-10-04

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def __evaluate_image_gt_no_preds(
return {
"dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device),
"gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device),
"dtScores": torch.zeros(nb_det, dtype=torch.bool, device=self.device),
"dtScores": torch.zeros(nb_det, dtype=torch.float32, device=self.device),
"gtIgnore": gt_ignore,
"dtIgnore": det_ignore,
}
Expand Down
16 changes: 15 additions & 1 deletion tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,31 @@

# Test empty preds case, to ensure bool inputs are properly casted to uint8
# From https://github.com/Lightning-AI/metrics/issues/981
# and https://github.com/Lightning-AI/metrics/issues/1147
_inputs3 = Input(
preds=[
[
dict(
boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
scores=Tensor([0.536]),
labels=IntTensor([0]),
),
],
[
dict(boxes=Tensor([]), scores=Tensor([]), labels=Tensor([])),
],
],
target=[
[
dict(
boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]),
labels=IntTensor([0]),
)
],
[
dict(
boxes=Tensor([[1.0, 2.0, 3.0, 4.0]]),
scores=Tensor([0.8]),
scores=Tensor([0.8]), # target does not have scores
labels=Tensor([1]),
),
],
Expand Down

0 comments on commit e491f47

Please sign in to comment.