Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix classwise computation in IoU metric #1924

Merged
merged 22 commits into from
Aug 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -32,8 +32,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)


- Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924))


- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)


## [1.1.0] - 2023-08-22

### Added
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
export FREEZE_REQUIREMENTS=1
# assume you have installed need packages
export SPHINX_MOCK_REQUIREMENTS=1
export SPHINX_FETCH_ASSETS=0

clean:
# clean all temp runs
14 changes: 8 additions & 6 deletions src/torchmetrics/detection/ciou.py
Original file line number Diff line number Diff line change
@@ -37,8 +37,6 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.

@@ -48,14 +46,14 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground
truth boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed ground truth
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``ciou_dict``: A dictionary containing the following key-values:

- ciou: (:class:`~torch.Tensor`)
- ciou: (:class:`~torch.Tensor`) with overall ciou value over all classes and samples.
- ciou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``

Args:
@@ -65,6 +63,9 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
class_metrics:
Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels:
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
between all pairs of boxes.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

@@ -86,7 +87,7 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric(preds, target)
{'ciou': tensor(-0.5694)}
{'ciou': tensor(0.8611)}

Raises:
ModuleNotFoundError:
@@ -105,14 +106,15 @@ def __init__(
box_format: str = "xyxy",
iou_threshold: Optional[float] = None,
class_metrics: bool = False,
respect_labels: bool = True,
**kwargs: Any,
) -> None:
if not _TORCHVISION_GREATER_EQUAL_0_13:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
12 changes: 7 additions & 5 deletions src/torchmetrics/detection/diou.py
Original file line number Diff line number Diff line change
@@ -37,8 +37,6 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.

@@ -55,7 +53,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):

- ``diou_dict``: A dictionary containing the following key-values:

- diou: (:class:`~torch.Tensor`)
- diou: (:class:`~torch.Tensor`) with overall diou value over all classes and samples.
- diou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``

Args:
@@ -65,6 +63,9 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
class_metrics:
Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels:
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
between all pairs of boxes.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

@@ -86,7 +87,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> metric(preds, target)
{'diou': tensor(-0.0694)}
{'diou': tensor(0.8611)}

Raises:
ModuleNotFoundError:
@@ -105,14 +106,15 @@ def __init__(
box_format: str = "xyxy",
iou_threshold: Optional[float] = None,
class_metrics: bool = False,
respect_labels: bool = True,
**kwargs: Any,
) -> None:
if not _TORCHVISION_GREATER_EQUAL_0_13:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
12 changes: 7 additions & 5 deletions src/torchmetrics/detection/giou.py
Original file line number Diff line number Diff line change
@@ -37,8 +37,6 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.

@@ -55,7 +53,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):

- ``giou_dict``: A dictionary containing the following key-values:

- giou: (:class:`~torch.Tensor`)
- giou: (:class:`~torch.Tensor`) with overall giou value over all classes and samples.
- giou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class metrics=True``

Args:
@@ -65,6 +63,9 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
class_metrics:
Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels:
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
between all pairs of boxes.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

@@ -86,7 +87,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> metric(preds, target)
{'giou': tensor(-0.0694)}
{'giou': tensor(0.8613)}

Raises:
ModuleNotFoundError:
@@ -105,9 +106,10 @@ def __init__(
box_format: str = "xyxy",
iou_threshold: Optional[float] = None,
class_metrics: bool = False,
respect_labels: bool = True,
**kwargs: Any,
) -> None:
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
7 changes: 5 additions & 2 deletions src/torchmetrics/detection/helpers.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ def _input_validator(
preds: Sequence[Dict[str, Tensor]],
targets: Sequence[Dict[str, Tensor]],
iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox",
ignore_score: bool = False,
) -> None:
"""Ensure the correct input format of `preds` and `targets`."""
if isinstance(iou_type, str):
@@ -39,7 +40,7 @@ def _input_validator(
f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}"
)

for k in [*item_val_name, "scores", "labels"]:
for k in [*item_val_name, "labels"] + (["scores"] if not ignore_score else []):
if any(k not in p for p in preds):
raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")

@@ -50,7 +51,7 @@ def _input_validator(
for ivn in item_val_name:
if any(type(pred[ivn]) is not Tensor for pred in preds):
raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor")
if any(type(pred["scores"]) is not Tensor for pred in preds):
if not ignore_score and any(type(pred["scores"]) is not Tensor for pred in preds):
raise ValueError("Expected all scores in `preds` to be of type Tensor")
if any(type(pred["labels"]) is not Tensor for pred in preds):
raise ValueError("Expected all labels in `preds` to be of type Tensor")
@@ -67,6 +68,8 @@ def _input_validator(
f"Input '{ivn}' and labels of sample {i} in targets have a"
f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})"
)
if ignore_score:
return
for i, item in enumerate(preds):
for ivn in item_val_name:
if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)):
Loading