Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix classwise computation in IoU metric #1924

Merged
merged 22 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Note: we move fast, but we still preserve backward compatibility across one 0.1 version (one feature release).**

## [UnReleased] - 2023-MM-DD

### Added

-


### Changed

-


### Removed

-


### Fixed

- Fixed a bug in the detection intersection-over-union metrics that produced wrong values when `class_metrics=True` ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924))


## [1.1.0] - 2023-08-22

Expand All @@ -21,7 +42,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983))
- Added warning to `ClipScore` if long captions are detected and truncate ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001))
- Added `CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931))
Borda marked this conversation as resolved.
Show resolved Hide resolved
- Added new property `metric_state` to all metrics for users to investigate currently stored tensors in memory ([#2006](https://github.com/Lightning-AI/torchmetrics/pull/2006))


## [1.0.3] - 2023-08-08
Expand Down
48 changes: 10 additions & 38 deletions src/torchmetrics/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ class IntersectionOverUnion(Metric):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.
- labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for
the boxes.

- ``target`` (:class:`~List`): A list consisting of dictionaries each containing the key-values
(each dictionary corresponds to a single image). Parameters that should be provided per dict:
Expand Down Expand Up @@ -85,7 +83,6 @@ class IntersectionOverUnion(Metric):
>>> preds = [
... {
... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
... "scores": torch.tensor([0.236, 0.56]),
... "labels": torch.tensor([4, 5]),
... }
... ]
Expand All @@ -108,13 +105,8 @@ class IntersectionOverUnion(Metric):
higher_is_better: Optional[bool] = True
full_state_update: bool = True

detections: List[Tensor]
detection_scores: List[Tensor]
detection_labels: List[Tensor]
groundtruths: List[Tensor]
groundtruth_labels: List[Tensor]
results: List[Tensor]
labels_eq: List[Tensor]
_iou_type: str = "iou"
_invalid_val: float = 0.0

Expand Down Expand Up @@ -149,13 +141,8 @@ def __init__(
raise ValueError("Expected argument `respect_labels` to be a boolean")
self.respect_labels = respect_labels

self.add_state("detections", default=[], dist_reduce_fx=None)
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
self.add_state("detection_labels", default=[], dist_reduce_fx=None)
self.add_state("groundtruths", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
self.add_state("results", default=[], dist_reduce_fx=None)
self.add_state("labels_eq", default=[], dist_reduce_fx=None)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
Expand Down Expand Up @@ -192,24 +179,16 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]

for p, t in zip(preds, target):
det_boxes = self._get_safe_item_values(p["boxes"])
self.detections.append(det_boxes)
self.detection_labels.append(p["labels"])
self.detection_scores.append(p["scores"])

gt_boxes = self._get_safe_item_values(t["boxes"])
self.groundtruths.append(gt_boxes)
self.groundtruth_labels.append(t["labels"])

label_eq = torch.equal(p["labels"], t["labels"])
# Workaround to persist state, which only works with tensors
self.labels_eq.append(torch.tensor([label_eq], dtype=torch.int, device=self.device))

ious = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val)
if self.respect_labels and not label_eq:
label_diff = p["labels"].unsqueeze(0).T - t["labels"].unsqueeze(0)
labels_not_eq = label_diff != 0.0
ious[labels_not_eq] = self._invalid_val
self.results.append(ious.to(dtype=torch.float, device=self.device))
self.results.append(ious.diag().to(dtype=torch.float, device=self.device))

def _get_safe_item_values(self, boxes: Tensor) -> Tensor:
boxes = _fix_empty_tensors(boxes)
Expand All @@ -225,22 +204,15 @@ def _get_gt_classes(self) -> List:

def compute(self) -> dict:
"""Computes IoU based on inputs passed in to ``update`` previously."""
aggregated_iou = dim_zero_cat(
[self._iou_compute_fn(iou, bool(lbl_eq)) for iou, lbl_eq in zip(self.results, self.labels_eq)]
)
results: Dict[str, Tensor] = {f"{self._iou_type}": aggregated_iou.mean()}
ious = dim_zero_cat(self.results)
labels = dim_zero_cat(self.groundtruth_labels)
results: Dict[str, Tensor] = {f"{self._iou_type}": self._iou_compute_fn(ious, False)}

if self.class_metrics:
class_results: Dict[int, List[Tensor]] = defaultdict(list)
for iou, label in zip(self.results, self.groundtruth_labels):
for cl in self._get_gt_classes():
masked_iou = iou[:, label == cl]
if masked_iou.numel() > 0:
class_results[cl].append(self._iou_compute_fn(masked_iou, False))

results.update(
{f"{self._iou_type}/cl_{cl}": dim_zero_cat(class_results[cl]).mean() for cl in class_results}
)
for cl in self._get_gt_classes():
masked_iou = ious[labels == cl]
results.update({f"{self._iou_type}/cl_{cl}": self._iou_compute_fn(masked_iou, False)})

return results

def plot(
Expand Down
58 changes: 47 additions & 11 deletions src/torchmetrics/functional/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def _ciou_update(
iou = complete_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
return iou
return iou.diag()


def _ciou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor:
if labels_eq:
return iou.diag().mean()
return iou.mean()
def _ciou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor:
if not aggregate:
return iou
return iou.mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device)


def complete_intersection_over_union(
Expand All @@ -62,15 +62,51 @@ def complete_intersection_over_union(
replacement_val:
Value to replace values under the threshold with.
aggregate:
Return the average value instead of the complete IoU matrix.
Return the average value instead of the IoU value of each box pair.

Example::
By default, the IoU is aggregated across all box pairs:

Example:
>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.Tensor([[100, 100, 200, 200]])
>>> target = torch.Tensor([[110, 110, 210, 210]])
>>> preds = torch.tensor(
... [
... [296.55, 93.96, 314.97, 152.79],
... [328.94, 97.05, 342.49, 122.98],
... [356.62, 95.47, 372.33, 147.55],
... ]
... )
>>> target = torch.tensor(
... [
... [300.00, 100.00, 315.00, 150.00],
... [330.00, 100.00, 350.00, 125.00],
... [350.00, 100.00, 375.00, 150.00],
... ]
... )
>>> complete_intersection_over_union(preds, target)
tensor(0.6724)
tensor(0.5790)

Example::
By setting ``aggregate=False``, the IoU score of each prediction/target box pair is returned:

>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.tensor(
... [
... [296.55, 93.96, 314.97, 152.79],
... [328.94, 97.05, 342.49, 122.98],
... [356.62, 95.47, 372.33, 147.55],
... ]
... )
>>> target = torch.tensor(
... [
... [300.00, 100.00, 315.00, 150.00],
... [330.00, 100.00, 350.00, 125.00],
... [350.00, 100.00, 375.00, 150.00],
... ]
... )
>>> complete_intersection_over_union(preds, target, aggregate=False)
tensor([0.6883, 0.4881, 0.5606])

"""
if not _TORCHVISION_GREATER_EQUAL_0_13:
Expand All @@ -80,4 +116,4 @@ def complete_intersection_over_union(
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
)
iou = _ciou_update(preds, target, iou_threshold, replacement_val)
return _ciou_compute(iou) if aggregate else iou
return _ciou_compute(iou, aggregate)
58 changes: 47 additions & 11 deletions src/torchmetrics/functional/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def _diou_update(
iou = distance_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
return iou
return iou.diag()


def _diou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor:
if labels_eq:
return iou.diag().mean()
return iou.mean()
def _diou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor:
if not aggregate:
return iou
return iou.mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device)


def distance_intersection_over_union(
Expand All @@ -62,15 +62,51 @@ def distance_intersection_over_union(
replacement_val:
Value to replace values under the threshold with.
aggregate:
Return the average value instead of the complete IoU matrix.
Return the average value instead of the IoU value of each box pair.

Example::
By default, the IoU is aggregated across all box pairs:

Example:
>>> import torch
>>> from torchmetrics.functional.detection import distance_intersection_over_union
>>> preds = torch.Tensor([[100, 100, 200, 200]])
>>> target = torch.Tensor([[110, 110, 210, 210]])
>>> preds = torch.tensor(
... [
... [296.55, 93.96, 314.97, 152.79],
... [328.94, 97.05, 342.49, 122.98],
... [356.62, 95.47, 372.33, 147.55],
... ]
... )
>>> target = torch.tensor(
... [
... [300.00, 100.00, 315.00, 150.00],
... [330.00, 100.00, 350.00, 125.00],
... [350.00, 100.00, 375.00, 150.00],
... ]
... )
>>> distance_intersection_over_union(preds, target)
tensor(0.6724)
tensor(0.5793)

Example::
By setting ``aggregate=False``, the IoU score of each prediction/target box pair is returned:

>>> import torch
>>> from torchmetrics.functional.detection import distance_intersection_over_union
>>> preds = torch.tensor(
... [
... [296.55, 93.96, 314.97, 152.79],
... [328.94, 97.05, 342.49, 122.98],
... [356.62, 95.47, 372.33, 147.55],
... ]
... )
>>> target = torch.tensor(
... [
... [300.00, 100.00, 315.00, 150.00],
... [330.00, 100.00, 350.00, 125.00],
... [350.00, 100.00, 375.00, 150.00],
... ]
... )
>>> distance_intersection_over_union(preds, target, aggregate=False)
tensor([0.6883, 0.4886, 0.5609])

"""
if not _TORCHVISION_GREATER_EQUAL_0_13:
Expand All @@ -80,4 +116,4 @@ def distance_intersection_over_union(
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
)
iou = _diou_update(preds, target, iou_threshold, replacement_val)
return _diou_compute(iou) if aggregate else iou
return _diou_compute(iou, aggregate)
58 changes: 47 additions & 11 deletions src/torchmetrics/functional/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def _giou_update(
iou = generalized_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
return iou
return iou.diag()


def _giou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor:
if labels_eq:
return iou.diag().mean()
return iou.mean()
def _giou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor:
if not aggregate:
return iou
return iou.mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device)


def generalized_intersection_over_union(
Expand All @@ -62,15 +62,51 @@ def generalized_intersection_over_union(
replacement_val:
Value to replace values under the threshold with.
aggregate:
Return the average value instead of the complete IoU matrix.
Return the average value instead of the IoU value of each box pair.

Example::
By default, the IoU is aggregated across all box pairs:

Example:
>>> import torch
>>> from torchmetrics.functional.detection import generalized_intersection_over_union
>>> preds = torch.Tensor([[100, 100, 200, 200]])
>>> target = torch.Tensor([[110, 110, 210, 210]])
>>> preds = torch.tensor(
... [
... [296.55, 93.96, 314.97, 152.79],
... [328.94, 97.05, 342.49, 122.98],
... [356.62, 95.47, 372.33, 147.55],
... ]
... )
>>> target = torch.tensor(
... [
... [300.00, 100.00, 315.00, 150.00],
... [330.00, 100.00, 350.00, 125.00],
... [350.00, 100.00, 375.00, 150.00],
... ]
... )
>>> generalized_intersection_over_union(preds, target)
tensor(0.6641)
tensor(0.5638)

Example::
By setting ``aggregate=False``, the IoU score of each prediction/target box pair is returned:

>>> import torch
>>> from torchmetrics.functional.detection import generalized_intersection_over_union
>>> preds = torch.tensor(
... [
... [296.55, 93.96, 314.97, 152.79],
... [328.94, 97.05, 342.49, 122.98],
... [356.62, 95.47, 372.33, 147.55],
... ]
... )
>>> target = torch.tensor(
... [
... [300.00, 100.00, 315.00, 150.00],
... [330.00, 100.00, 350.00, 125.00],
... [350.00, 100.00, 375.00, 150.00],
... ]
... )
>>> generalized_intersection_over_union(preds, target, aggregate=False)
tensor([0.6895, 0.4673, 0.5345])

"""
if not _TORCHVISION_GREATER_EQUAL_0_8:
Expand All @@ -80,4 +116,4 @@ def generalized_intersection_over_union(
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
)
iou = _giou_update(preds, target, iou_threshold, replacement_val)
return _giou_compute(iou) if aggregate else iou
return _giou_compute(iou, aggregate)
Loading