Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch-based mAP #632

Merged
merged 55 commits into from
Dec 5, 2021
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
5732063
First draft
twsl Nov 22, 2021
c22220c
Update torchmetrics/detection/map_new.py
twsl Nov 22, 2021
7b1350a
Update torchmetrics/detection/map_new.py
twsl Nov 22, 2021
f6cb8f5
Remove double score
twsl Nov 22, 2021
35f61e8
Rename box_format arg
twsl Nov 22, 2021
51571b2
Calculate num_class only
twsl Nov 22, 2021
6078a14
Support empty predictions
twsl Nov 23, 2021
45c346d
Merge branch 'master' into map
twsl Nov 23, 2021
e5f3a46
Remove pycocotools from tests
twsl Nov 23, 2021
1d8e346
Fix annotation id evals to false if zero
twsl Nov 24, 2021
5fdec29
Somehow required
twsl Nov 24, 2021
0f2aa57
Simplified loops
twsl Nov 24, 2021
1193cd0
Simplify loops
twsl Nov 25, 2021
9f308a5
Fixed ordering
twsl Nov 26, 2021
013ee20
More refactoring
twsl Nov 26, 2021
67459ef
Improve method descriptions
twsl Nov 26, 2021
335c102
All tests pass
twsl Nov 26, 2021
75a5cfb
Implement _fix_empty_tensor
twsl Nov 26, 2021
08e76f2
Merge branch 'master' into map
twsl Nov 26, 2021
9f16bfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2021
041c1db
Update CHANGELOG.md
twsl Nov 26, 2021
0c8b131
Update torchmetrics/detection/map.py
twsl Nov 26, 2021
76ee752
Update torchmetrics/detection/map.py
twsl Nov 26, 2021
190688a
Update torchmetrics/detection/map.py
twsl Nov 26, 2021
290e4c0
Added more method descriptions
twsl Nov 27, 2021
41486ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2021
9347025
Fix var renaming
twsl Nov 27, 2021
79f8e49
Add description to _calculate
twsl Nov 27, 2021
7b785cd
Fix input validation
twsl Nov 27, 2021
f44a34e
Fix returning metrics if parameters are changed
twsl Nov 29, 2021
9bfb352
Merge branch 'master' into map
justusschock Nov 29, 2021
4aef9e6
Merge branch 'master' into map
Borda Nov 30, 2021
99736d5
Update torchmetrics/detection/map.py
twsl Nov 30, 2021
0d847ce
Update torchmetrics/detection/map.py
twsl Nov 30, 2021
3a9fb1d
Update torchmetrics/detection/map.py
twsl Nov 30, 2021
4c76379
Merge branch 'master' into map
Borda Dec 4, 2021
25b46a8
rename
Borda Dec 4, 2021
6f614ba
rename
Borda Dec 4, 2021
2389495
cls
Borda Dec 4, 2021
ffdf43f
inter func
Borda Dec 4, 2021
a0a380b
Apply suggestions from code review
Borda Dec 4, 2021
e976080
Apply suggestions from code review
Borda Dec 4, 2021
fae4eb9
Fix renamed error messages
twsl Dec 4, 2021
79384ed
Fix ious empty check
twsl Dec 4, 2021
5ea2909
Merge branch 'master' into map
mergify[bot] Dec 4, 2021
30ea32f
Merge branch 'map' of https://github.com/twsl/metrics into map
twsl Dec 4, 2021
f05549d
Remove invalid test case
twsl Dec 4, 2021
ceb82be
Eval as boolean
twsl Dec 4, 2021
d874945
Moved gt match function
twsl Dec 4, 2021
384a10e
mypy
Borda Dec 4, 2021
604527b
mypy
Borda Dec 4, 2021
cd2d683
tuple
Borda Dec 4, 2021
642ce43
mypy
Borda Dec 4, 2021
8fe1808
.
Borda Dec 4, 2021
f20542d
det
Borda Dec 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))


- Migrate MAP metrics from pycocotools to PyTorch ([#632](https://github.com/PyTorchLightning/metrics/pull/632))
Borda marked this conversation as resolved.
Show resolved Hide resolved
- Use `torch.topk` instead of `torch.argsort` in retrieval precision for speedup ([#627](https://github.com/PyTorchLightning/metrics/pull/627))


Expand Down
18 changes: 7 additions & 11 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@

from tests.helpers.testers import MetricTester
from torchmetrics.detection.map import MAP
from torchmetrics.utilities.imports import (
_PYCOCOTOOLS_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCHVISION_GREATER_EQUAL_0_8,
)
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8

Input = namedtuple("Input", ["preds", "target"])

Expand Down Expand Up @@ -59,7 +55,7 @@
), # coco image id 74
dict(
boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]),
scores=torch.Tensor([0.699, 0.423]),
scores=torch.Tensor([0.699]),
twsl marked this conversation as resolved.
Show resolved Hide resolved
labels=torch.IntTensor([5]),
), # coco image id 133
],
Expand Down Expand Up @@ -164,10 +160,10 @@ def _compare_fn(preds, target) -> dict:
}


_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)
_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
class TestMAP(MetricTester):
"""Test the MAP metric for object detection predictions.

Expand All @@ -194,7 +190,7 @@ def test_map(self, ddp):


# noinspection PyTypeChecker
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
MAP() # no error
Expand All @@ -203,7 +199,7 @@ def test_error_on_wrong_init():
MAP(class_metrics=0)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_preds():
"""Test empty predictions."""
metric = MAP()
Expand Down Expand Up @@ -235,7 +231,7 @@ def test_empty_metric():
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""
metric = MAP()
Expand Down
Loading