Skip to content

Commit

Permalink
Feature/modified panoptic quality (#1627)
Browse files Browse the repository at this point in the history
Co-authored-by: SkafteNicki <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
4 people authored Mar 21, 2023
1 parent 18d3dd8 commit b785d28
Show file tree
Hide file tree
Showing 12 changed files with 602 additions and 37 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485))


- Added `ModifiedPanopticQuality` metric to detection package ([#1627](https://github.com/Lightning-AI/metrics/pull/1627))


### Changed

- Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
Expand Down
23 changes: 23 additions & 0 deletions docs/source/detection/modified_panoptic_quality.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Modified Panoptic Quality
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Detection

#########################
Modified Panoptic Quality
#########################

.. include:: ../links.rst

Module Interface
________________

.. autoclass:: torchmetrics.ModifiedPanopticQuality
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.modified_panoptic_quality
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,4 @@
.. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance
.. _Demographic parity: http://www.fairmlbook.org/
.. _Equal opportunity: https://proceedings.neurips.cc/paper/2016/hash/9d2682367c3935defcb1f9e247a97c0d-Abstract.html
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
3 changes: 2 additions & 1 deletion src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.detection import PanopticQuality # noqa: E402
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality # noqa: E402
from torchmetrics.image import ( # noqa: E402
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
Expand Down Expand Up @@ -152,6 +152,7 @@
"MetricTracker",
"MinMaxMetric",
"MinMetric",
"ModifiedPanopticQuality",
"MultioutputWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"PanopticQuality",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
if _TORCHVISION_GREATER_EQUAL_0_8:
from torchmetrics.detection.mean_ap import MeanAveragePrecision # noqa: F401

from torchmetrics.detection.modified_panoptic_quality import ModifiedPanopticQuality # noqa: F401
from torchmetrics.detection.panoptic_quality import PanopticQuality # noqa: F401
210 changes: 210 additions & 0 deletions src/torchmetrics/detection/modified_panoptic_quality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Collection, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.functional.detection._panoptic_quality_common import (
_get_category_id_to_continuous_id,
_get_void_color,
_panoptic_quality_compute,
_panoptic_quality_update,
_parse_categories,
_prepocess_inputs,
_validate_inputs,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ModifiedPanopticQuality.plot"]


class ModifiedPanopticQuality(Metric):
r"""Compute `Modified Panoptic Quality`_ for panoptic segmentations.
The metric was introduced in `Seamless Scene Segmentation paper`_, and is an adaptation of the original
`Panoptic Quality`_ where the metric for a stuff class is computed as
.. math::
PQ^{\dagger}_c = \frac{IOU_c}{|S_c|}
where IOU_c is the sum of the intersection over union of all matching segments for a given class, and \|S_c| is
the overall number of segments in the ground truth for that class.
.. note:
Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
computation.
Args:
things:
Set of ``category_id`` for countable things.
stuffs:
Set of ``category_id`` for uncountable stuffs.
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
Raises:
ValueError:
If ``things``, ``stuffs`` have at least one common ``category_id``.
TypeError:
If ``things``, ``stuffs`` contain non-integer ``category_id``.
Example:
>>> from torch import tensor
>>> from torchmetrics import ModifiedPanopticQuality
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> pq_modified = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> pq_modified(preds, target)
tensor(0.7667, dtype=torch.float64)
"""
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False

iou_sum: Tensor
true_positives: Tensor
false_positives: Tensor
false_negatives: Tensor

def __init__(
self,
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

things, stuffs = _parse_categories(things, stuffs)
self.things = things
self.stuffs = stuffs
self.void_color = _get_void_color(things, stuffs)
self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
self.allow_unknown_preds_category = allow_unknown_preds_category

# per category intermediate metrics
n_categories = len(things) + len(stuffs)
self.add_state("iou_sum", default=torch.zeros(n_categories, dtype=torch.double), dist_reduce_fx="sum")
self.add_state("true_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_negatives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
r"""Update state with predictions and targets.
Args:
preds: panoptic detection of shape ``[batch, *spatial_dims, 2]`` containing
the pair ``(category_id, instance_id)`` for each point.
If the ``category_id`` refer to a stuff, the instance_id is ignored.
target: ground truth of shape ``[batch, *spatial_dims, 2]`` containing
the pair ``(category_id, instance_id)`` for each pixel of the image.
If the ``category_id`` refer to a stuff, the instance_id is ignored.
Raises:
TypeError:
If ``preds`` or ``target`` is not an ``torch.Tensor``.
ValueError:
If ``preds`` and ``target`` have different shape.
ValueError:
If ``preds`` has less than 3 dimensions.
ValueError:
If the final dimension of ``preds`` has size != 2.
"""
_validate_inputs(preds, target)
flatten_preds = _prepocess_inputs(
self.things, self.stuffs, preds, self.void_color, self.allow_unknown_preds_category
)
flatten_target = _prepocess_inputs(self.things, self.stuffs, target, self.void_color, True)
iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
flatten_preds,
flatten_target,
self.cat_id_to_continuous_id,
self.void_color,
modified_metric_stuffs=self.stuffs,
)
self.iou_sum += iou_sum
self.true_positives += true_positives
self.false_positives += false_positives
self.false_negatives += false_negatives

def compute(self) -> Tensor:
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure object and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> from torch import tensor
>>> from torchmetrics import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> vals = []
>>> for _ in range(20):
... vals.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(vals)
"""
return self._plot(val, ax)
56 changes: 28 additions & 28 deletions src/torchmetrics/detection/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,37 +48,37 @@ class PanopticQuality(Metric):
computation.
Args:
things:
Set of ``category_id`` for countable things.
stuffs:
Set of ``category_id`` for uncountable stuffs.
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
things:
Set of ``category_id`` for countable things.
stuffs:
Set of ``category_id`` for uncountable stuffs.
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
Raises:
ValueError:
If ``things``, ``stuffs`` have at least one common ``category_id``.
TypeError:
If ``things``, ``stuffs`` contain non-integer ``category_id``.
Example:ty
>>> from torch import tensor
>>> from torchmetrics import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)
ValueError:
If ``things``, ``stuffs`` have at least one common ``category_id``.
TypeError:
If ``things``, ``stuffs`` contain non-integer ``category_id``.
Example:
>>> from torch import tensor
>>> from torchmetrics import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)
"""
is_differentiable: bool = False
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torchmetrics.functional.classification.roc import roc
from torchmetrics.functional.classification.specificity import specificity
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality
from torchmetrics.functional.detection.panoptic_quality import panoptic_quality
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
Expand Down Expand Up @@ -139,6 +140,7 @@
"mean_squared_error",
"mean_squared_log_error",
"minkowski_distance",
"modified_panoptic_quality",
"multiscale_structural_similarity_index_measure",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality # noqa: F401
from torchmetrics.functional.detection.panoptic_quality import panoptic_quality # noqa: F401
Loading

0 comments on commit b785d28

Please sign in to comment.