Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Panoptic Quality implementation #1527

Merged
merged 18 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 39 additions & 28 deletions src/torchmetrics/detection/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Set
from typing import Any, Collection

import torch
from torch import Tensor
Expand All @@ -22,8 +22,8 @@
_get_void_color,
_panoptic_quality_compute,
_panoptic_quality_update,
_prepocess_image,
_validate_categories,
_parse_categories,
_prepocess_inputs,
_validate_inputs,
)
from torchmetrics.metric import Metric
Expand All @@ -37,38 +37,47 @@ class PanopticQuality(Metric):

where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives,
the number of true postitives, false positives and false negatives. This metric is inspired by the PQ
implementati on of panopticapi, a standard implementation for the PQ metric for object detection.
implementation of panopticapi, a standard implementation for the PQ metric for panoptic segmentation.

.. note:
Metric is currently experimental
Metric is currently experimental.
Borda marked this conversation as resolved.
Show resolved Hide resolved

.. note:
Borda marked this conversation as resolved.
Show resolved Hide resolved
Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
computation.

Args:
things:
Set of ``category_id`` for countable things.
stuffs:
Set of ``category_id`` for uncountable stuffs.
allow_unknown_preds_category:
Bool indication if unknown categories in preds is allowed
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.


Raises:
ValueError:
If ``things``, ``stuffs`` share the same ``category_id``.
If ``things``, ``stuffs`` have at least one common ``category_id``.
TypeError:
If ``things``, ``stuffs`` contain non-integer ``category_id``.

Example:
>>> from torch import tensor
>>> preds = tensor([[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]])
>>> target = tensor([[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]])
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)

"""
is_differentiable: bool = False
higher_is_better: bool = True
Expand All @@ -81,8 +90,8 @@ class PanopticQuality(Metric):

def __init__(
self,
things: Set[int],
stuffs: Set[int],
things: Collection[int],
marcocaccin marked this conversation as resolved.
Show resolved Hide resolved
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
**kwargs: Any,
):
Expand All @@ -91,7 +100,7 @@ def __init__(
# todo: better testing for correctness of metric
warnings.warn("This is experimental version and are actively working on its stability.")
Borda marked this conversation as resolved.
Show resolved Hide resolved

_validate_categories(things, stuffs)
things, stuffs = _parse_categories(things, stuffs)
self.things = things
self.stuffs = stuffs
self.void_color = _get_void_color(things, stuffs)
Expand All @@ -109,27 +118,29 @@ def update(self, preds: Tensor, target: Tensor) -> None:
r"""Update state with predictions and targets.

Args:
preds: panoptic detection of shape ``[height, width, 2]`` containing
the pair ``(category_id, instance_id)`` for each pixel of the image.
preds: panoptic detection of shape ``[batch, *spatial_dims, 2]`` containing
the pair ``(category_id, instance_id)`` for each point.
If the ``category_id`` refer to a stuff, the instance_id is ignored.

target: ground truth of shape ``[height, width, 2]`` containing
target: ground truth of shape ``[batch, *spatial_dims, 2]`` containing
the pair ``(category_id, instance_id)`` for each pixel of the image.
If the ``category_id`` refer to a stuff, the instance_id is ignored.

Raises:
TypeError:
If ``preds`` or ``target`` is not an ``torch.Tensor``
If ``preds`` or ``target`` is not an ``torch.Tensor``.
ValueError:
If ``preds`` and ``target`` have different shape.
ValueError:
If ``preds`` or ``target`` has different shape.
If ``preds`` has less than 3 dimensions.
ValueError:
If ``preds`` is not a 3D tensor where the final dimension have size 2
If the final dimension of ``preds`` has size != 2.
"""
_validate_inputs(preds, target)
flatten_preds = _prepocess_image(
flatten_preds = _prepocess_inputs(
self.things, self.stuffs, preds, self.void_color, self.allow_unknown_preds_category
)
flatten_target = _prepocess_image(self.things, self.stuffs, target, self.void_color, True)
flatten_target = _prepocess_inputs(self.things, self.stuffs, target, self.void_color, True)
iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
flatten_preds, flatten_target, self.cat_id_to_continuous_id, self.void_color
)
Expand Down
Loading