Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return also classes for MAP metric #1419

Merged
merged 9 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .azure/gpu-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ jobs:
ls -lh $(HF_CACHE_DIR) # show what was restored...
displayName: 'Show HF cache'

- bash: python -m pytest torchmetrics --cov=torchmetrics --timeout=120 --durations=50
- bash: python -m pytest torchmetrics --cov=torchmetrics --timeout=150 --durations=50
workingDirectory: src
displayName: 'DocTesting'

Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328))

- Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))

### Changed

Expand Down
8 changes: 5 additions & 3 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __delattr__(self, key: str) -> None:
class MAPMetricResults(BaseMetricResults):
"""Class to wrap the final mAP results."""

__slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large")
__slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large", "classes")


class MARMetricResults(BaseMetricResults):
Expand Down Expand Up @@ -248,6 +248,7 @@ class MeanAveragePrecision(Metric):
- map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds)
- map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled)
- mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled)
- classes (:class:`~torch.Tensor`)

For an example on how to use this metric check the `torchmetrics examples
<https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py>`_
Expand Down Expand Up @@ -332,7 +333,8 @@ class MeanAveragePrecision(Metric):
>>> metric.update(preds, target)
>>> from pprint import pprint
>>> pprint(metric.compute())
{'map': tensor(0.6000),
{'classes': tensor(0, dtype=torch.int32),
'map': tensor(0.6000),
'map_50': tensor(1.),
'map_75': tensor(1.),
'map_large': tensor(0.6000),
Expand Down Expand Up @@ -923,5 +925,5 @@ def compute(self) -> dict:
metrics.update(mar_val)
metrics.map_per_class = map_per_class_values
metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values

metrics.classes = torch.tensor(classes, dtype=torch.int)
return metrics
2 changes: 2 additions & 0 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def _compare_fn(preds, target) -> dict:
"mar_large": Tensor([0.633]),
"map_per_class": Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.556]),
"mar_100_per_class": Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.580]),
"classes": Tensor([0, 1, 2, 3, 4, 49]),
}


Expand Down Expand Up @@ -317,6 +318,7 @@ def _compare_fn_segm(preds, target) -> dict:
"mar_large": Tensor([0.35]),
"map_per_class": Tensor([0.4039604, -1.0, 0.3]),
"mar_100_per_class": Tensor([0.4, -1.0, 0.3]),
"classes": Tensor([2, 3, 4]),
}


Expand Down