diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index ec812e74ef1..36076119a07 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -22,70 +22,76 @@ //"--shm-size=4gb", ], // Set *default* container specific settings.json values on container create. - "settings": { - "editor.formatOnSave": true, - "editor.rulers": [ - 120 - ], - "python.pythonPath": "/usr/local/bin/python", - "python.defaultInterpreterPath": "/usr/local/bin/python", - "python.languageServer": "Pylance", - "python.autoComplete.addBrackets": true, - "python.analysis.autoImportCompletions": true, - "python.analysis.completeFunctionParens": true, - "python.analysis.autoSearchPaths": true, - "python.analysis.useImportHeuristic": true, - "python.sortImports": true, - "python.sortImports.args": [ - "--settings-path=${workspaceFolder}/pyproject.toml", - ], - "python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8", - "python.formatting.blackPath": "/usr/local/py-utils/bin/black", - "python.formatting.provider": "black", - "python.formatting.blackArgs": [ - "--config=${workspaceFolder}/pyproject.toml" - ], - "python.linting.banditPath": "/usr/local/py-utils/bin/bandit", - "python.linting.flake8Path": "/usr/local/py-utils/bin/flake8", - "python.linting.mypyPath": "/usr/local/py-utils/bin/mypy", - "python.linting.pycodestylePath": "/usr/local/py-utils/bin/pycodestyle", - "python.linting.pydocstylePath": "/usr/local/py-utils/bin/pydocstyle", - "python.linting.pylintPath": "/usr/local/py-utils/bin/pylint", - "python.linting.enabled": true, - "python.linting.pylintEnabled": false, - "python.linting.flake8Enabled": true, - "python.linting.flake8Args": [ - "--config=${workspaceFolder}/setup.cfg", - "--verbose" - ], - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "restructuredtext.confPath": "${workspaceFolder}/docs/source", - 
"restructuredtext.builtDocumentationPath": "${workspaceFolder}/docs/build", - "restructuredtext.languageServer.disabled": false, - "[python]": { - "editor.codeActionsOnSave": { - "source.organizeImports": true, - } + "customizations": { + "vscode": { + "settings": { + "editor.formatOnSave": true, + "editor.rulers": [ + 120 + ], + "files.exclude": { + "**/__pycache__": true + }, + "python.pythonPath": "/usr/local/bin/python", + "python.defaultInterpreterPath": "/usr/local/bin/python", + "python.languageServer": "Pylance", + "python.autoComplete.addBrackets": true, + "python.analysis.autoImportCompletions": true, + "python.analysis.completeFunctionParens": true, + "python.analysis.autoSearchPaths": true, + "python.analysis.useImportHeuristic": true, + "python.sortImports": true, + "isort.args": [ + "--settings-path=${workspaceFolder}/pyproject.toml" + ], + "python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8", + "python.formatting.blackPath": "/usr/local/py-utils/bin/black", + "python.formatting.provider": "black", + "python.formatting.blackArgs": [ + "--config=${workspaceFolder}/pyproject.toml" + ], + "python.linting.banditPath": "/usr/local/py-utils/bin/bandit", + "python.linting.flake8Path": "/usr/local/py-utils/bin/flake8", + "python.linting.mypyPath": "/usr/local/py-utils/bin/mypy", + "python.linting.pycodestylePath": "/usr/local/py-utils/bin/pycodestyle", + "python.linting.pydocstylePath": "/usr/local/py-utils/bin/pydocstyle", + "python.linting.pylintPath": "/usr/local/py-utils/bin/pylint", + "python.linting.enabled": true, + "python.linting.pylintEnabled": false, + "python.linting.flake8Enabled": true, + "python.linting.flake8Args": [ + "--config=${workspaceFolder}/setup.cfg", + "--verbose" + ], + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "esbonio.sphinx.confDir": "${workspaceFolder}/docs/source", + "esbonio.sphinx.buildDir": "${workspaceFolder}/docs/build", 
+ "[python]": { + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + } + }, + // Add the IDs of extensions you want installed when the container is created. + "extensions": [ + "ms-python.python", + "ms-python.vscode-pylance", + "visualstudioexptteam.vscodeintellicode", + "kevinrose.vsc-python-indent", + "littlefoxteam.vscode-python-test-adapter", + "hbenl.vscode-test-explorer", + "medo64.render-crlf", + "shardulm94.trailing-spaces", + "njqdev.vscode-python-typehint", + "lextudio.restructuredtext", + "trond-snekvik.simple-rst" + ] } }, - // Add the IDs of extensions you want installed when the container is created. - "extensions": [ - "ms-python.python", - "ms-python.vscode-pylance", - "visualstudioexptteam.vscodeintellicode", - "kevinrose.vsc-python-indent", - "littlefoxteam.vscode-python-test-adapter", - "hbenl.vscode-test-explorer", - "medo64.render-crlf", - "shardulm94.trailing-spaces", - "njqdev.vscode-python-typehint", - "lextudio.restructuredtext", - "trond-snekvik.simple-rst", - ], // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. 
diff --git a/CHANGELOG.md b/CHANGELOG.md index 381bbcbff85..ac501b08826 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -92,6 +92,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `PrecisionAtFixedRecall` metric to classification package ([#1683](https://github.com/Lightning-AI/torchmetrics/pull/1683)) +- Added multiple metrics to detection package ([#1284](https://github.com/Lightning-AI/metrics/pull/1284)) + * `IntersectionOverUnion` + * `GeneralizedIntersectionOverUnion` + * `CompleteIntersectionOverUnion` + * `DistanceIntersectionOverUnion` + + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) diff --git a/docs/source/detection/complete_intersection_over_union.rst b/docs/source/detection/complete_intersection_over_union.rst new file mode 100644 index 00000000000..f8cadd18088 --- /dev/null +++ b/docs/source/detection/complete_intersection_over_union.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Complete Intersection Over Union (cIoU) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg + :tags: Detection + +####################################### +Complete Intersection Over Union (cIoU) +####################################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.detection.ciou.CompleteIntersectionOverUnion + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.detection.ciou.complete_intersection_over_union + :noindex: diff --git a/docs/source/detection/distance_intersection_over_union.rst b/docs/source/detection/distance_intersection_over_union.rst new file mode 100644 index 00000000000..f75c8cb6f29 --- /dev/null +++ b/docs/source/detection/distance_intersection_over_union.rst @@ -0,0 +1,21 @@ +.. 
customcarditem:: + :header: Distance Intersection Over Union (dIoU) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg + :tags: Detection + +####################################### +Distance Intersection Over Union (dIoU) +####################################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.detection.diou.DistanceIntersectionOverUnion + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.detection.diou.distance_intersection_over_union + :noindex: diff --git a/docs/source/detection/generalized_intersection_over_union.rst b/docs/source/detection/generalized_intersection_over_union.rst new file mode 100644 index 00000000000..3a0a22b5221 --- /dev/null +++ b/docs/source/detection/generalized_intersection_over_union.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Generalized Intersection Over Union (gIoU) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg + :tags: Detection + +########################################## +Generalized Intersection Over Union (gIoU) +########################################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.detection.giou.GeneralizedIntersectionOverUnion + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.detection.giou.generalized_intersection_over_union + :noindex: diff --git a/docs/source/detection/intersection_over_union.rst b/docs/source/detection/intersection_over_union.rst new file mode 100644 index 00000000000..bb15d9d4270 --- /dev/null +++ b/docs/source/detection/intersection_over_union.rst @@ -0,0 +1,21 @@ +.. 
customcarditem:: + :header: Intersection Over Union (IoU) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg + :tags: Detection + +############################# +Intersection Over Union (IoU) +############################# + +Module Interface +________________ + +.. autoclass:: torchmetrics.detection.iou.IntersectionOverUnion + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.detection.iou.intersection_over_union + :noindex: diff --git a/src/torchmetrics/detection/__init__.py b/src/torchmetrics/detection/__init__.py index 790aa9ed3ed..e589bca3cd6 100644 --- a/src/torchmetrics/detection/__init__.py +++ b/src/torchmetrics/detection/__init__.py @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.detection.panoptic_qualities import ModifiedPanopticQuality, PanopticQuality -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import ( + _TORCHVISION_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_8, + _TORCHVISION_GREATER_EQUAL_0_13, +) __all__ = ["ModifiedPanopticQuality", "PanopticQuality"] if _TORCHVISION_GREATER_EQUAL_0_8: - from torchmetrics.detection.mean_ap import MeanAveragePrecision # noqa: F401 + from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion + from torchmetrics.detection.iou import IntersectionOverUnion + from torchmetrics.detection.mean_ap import MeanAveragePrecision - __all__.append("MeanAveragePrecision") + __all__ += ["MeanAveragePrecision", "GeneralizedIntersectionOverUnion", "IntersectionOverUnion"] + +if _TORCHVISION_GREATER_EQUAL_0_13: + from torchmetrics.detection.ciou import CompleteIntersectionOverUnion + from torchmetrics.detection.diou import DistanceIntersectionOverUnion + + __all__ += ["CompleteIntersectionOverUnion", "DistanceIntersectionOverUnion"] diff 
--git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py new file mode 100644 index 00000000000..f967a449b8c --- /dev/null +++ b/src/torchmetrics/detection/ciou.py @@ -0,0 +1,185 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.detection.iou import IntersectionOverUnion +from torchmetrics.functional.detection.ciou import _ciou_compute, _ciou_update +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _TORCHVISION_GREATER_EQUAL_0_13: + __doctest_skip__ = ["CompleteIntersectionOverUnion", "CompleteIntersectionOverUnion.plot"] +elif not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CompleteIntersectionOverUnion.plot"] + + +class CompleteIntersectionOverUnion(IntersectionOverUnion): + r"""Computes Complete Intersection Over Union (CIoU). + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection + boxes of the format specified in the constructor.
+ By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for + the boxes. + + - ``target`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth + classes for the boxes. + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``ciou_dict``: A dictionary containing the following key-values: + + - ciou: (:class:`~torch.Tensor`) + - ciou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True`` + + Args: + box_format: + Input format of given boxes. Supported formats are ``['xyxy', 'xywh', 'cxcywh']``. + iou_threshold: + Optional IoU threshold for evaluation. If set to `None` the threshold is ignored. + class_metrics: + Option to enable per-class metrics for IoU. Has a performance impact. + kwargs: + Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.detection import CompleteIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + ...
"labels": torch.tensor([5]), + ... } + ... ] + >>> metric = CompleteIntersectionOverUnion() + >>> metric(preds, target) + {'ciou': tensor(-0.5694)} + + Raises: + ModuleNotFoundError: + If torchvision is not installed with version 0.13.0 or newer. + + """ + _iou_type: str = "ciou" + _invalid_val: float = -2.0 # unsure, min val could be just -1.5 as well + + def __init__( + self, + box_format: str = "xyxy", + iou_threshold: Optional[float] = None, + class_metrics: bool = False, + **kwargs: Any, + ) -> None: + if not _TORCHVISION_GREATER_EQUAL_0_13: + raise ModuleNotFoundError( + f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed." + " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + ) + super().__init__(box_format, iou_threshold, class_metrics, **kwargs) + + @staticmethod + def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: + return _ciou_update(*args, **kwargs) + + @staticmethod + def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor: + return _ciou_compute(*args, **kwargs) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting single value + >>> import torch + >>> from torchmetrics.detection import CompleteIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ...
"scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = CompleteIntersectionOverUnion() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.detection import CompleteIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = lambda : [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = CompleteIntersectionOverUnion() + >>> vals = [] + >>> for _ in range(20): + ... vals.append(metric(preds, target())) + >>> fig_, ax_ = metric.plot(vals) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py new file mode 100644 index 00000000000..f30ef869ec2 --- /dev/null +++ b/src/torchmetrics/detection/diou.py @@ -0,0 +1,185 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.detection.iou import IntersectionOverUnion +from torchmetrics.functional.detection.diou import _diou_compute, _diou_update +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _TORCHVISION_GREATER_EQUAL_0_13: + __doctest_skip__ = ["DistanceIntersectionOverUnion", "DistanceIntersectionOverUnion.plot"] +elif not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["DistanceIntersectionOverUnion.plot"] + + +class DistanceIntersectionOverUnion(IntersectionOverUnion): + r"""Computes Distance Intersection Over Union (DIoU). + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for + the boxes. + + - ``target`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
+ - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth + classes for the boxes. + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``diou_dict``: A dictionary containing the following key-values: + + - diou: (:class:`~torch.Tensor`) + - diou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True`` + + Args: + box_format: + Input format of given boxes. Supported formats are ``['xyxy', 'xywh', 'cxcywh']``. + iou_threshold: + Optional IoU threshold for evaluation. If set to `None` the threshold is ignored. + class_metrics: + Option to enable per-class metrics for IoU. Has a performance impact. + kwargs: + Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.detection import DistanceIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = DistanceIntersectionOverUnion() + >>> metric(preds, target) + {'diou': tensor(-0.0694)} + + Raises: + ModuleNotFoundError: + If torchvision is not installed with version 0.13.0 or newer. + + """ + _iou_type: str = "diou" + _invalid_val: float = -1.0 + + def __init__( + self, + box_format: str = "xyxy", + iou_threshold: Optional[float] = None, + class_metrics: bool = False, + **kwargs: Any, + ) -> None: + if not _TORCHVISION_GREATER_EQUAL_0_13: + raise ModuleNotFoundError( + f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed." + " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
+ ) + super().__init__(box_format, iou_threshold, class_metrics, **kwargs) + + @staticmethod + def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: + return _diou_update(*args, **kwargs) + + @staticmethod + def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor: + return _diou_compute(*args, **kwargs) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting single value + >>> import torch + >>> from torchmetrics.detection import DistanceIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = DistanceIntersectionOverUnion() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.detection import DistanceIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = lambda : [ + ... { + ...
"boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = DistanceIntersectionOverUnion() + >>> vals = [] + >>> for _ in range(20): + ... vals.append(metric(preds, target())) + >>> fig_, ax_ = metric.plot(vals) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py new file mode 100644 index 00000000000..e06660ea751 --- /dev/null +++ b/src/torchmetrics/detection/giou.py @@ -0,0 +1,179 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.detection.iou import IntersectionOverUnion +from torchmetrics.functional.detection.giou import _giou_compute, _giou_update +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _TORCHVISION_GREATER_EQUAL_0_8: + __doctest_skip__ = ["GeneralizedIntersectionOverUnion", "GeneralizedIntersectionOverUnion.plot"] +elif not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["GeneralizedIntersectionOverUnion.plot"] + + +class GeneralizedIntersectionOverUnion(IntersectionOverUnion): + r"""Compute Generalized Intersection Over Union (GIoU).
+ + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for + the boxes. + + - ``target`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth + classes for the boxes. + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``giou_dict``: A dictionary containing the following key-values: + + - giou: (:class:`~torch.Tensor`) + - giou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True`` + + Args: + box_format: + Input format of given boxes. Supported formats are ``['xyxy', 'xywh', 'cxcywh']``. + iou_threshold: + Optional IoU threshold for evaluation. If set to `None` the threshold is ignored. + class_metrics: + Option to enable per-class metrics for IoU. Has a performance impact. + kwargs: + Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+ + Example: + >>> import torch + >>> from torchmetrics.detection import GeneralizedIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = GeneralizedIntersectionOverUnion() + >>> metric(preds, target) + {'giou': tensor(-0.0694)} + + Raises: + ModuleNotFoundError: + If torchvision is not installed with version 0.8.0 or newer. + """ + _iou_type: str = "giou" + _invalid_val: float = -1.0 + + def __init__( + self, + box_format: str = "xyxy", + iou_threshold: Optional[float] = None, + class_metrics: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(box_format, iou_threshold, class_metrics, **kwargs) + + @staticmethod + def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: + return _giou_update(*args, **kwargs) + + @staticmethod + def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor: + return _giou_compute(*args, **kwargs) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting single value + >>> import torch + >>> from torchmetrics.detection import GeneralizedIntersectionOverUnion + >>> preds = [ + ... { + ...
"boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = GeneralizedIntersectionOverUnion() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.detection import GeneralizedIntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + ... "scores": torch.tensor([0.236, 0.56]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = lambda : [ + ... { + ... "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)), + ... "labels": torch.tensor([5]), + ... } + ... ] + >>> metric = GeneralizedIntersectionOverUnion() + >>> vals = [] + >>> for _ in range(20): + ... vals.append(metric(preds, target())) + >>> fig_, ax_ = metric.plot(vals) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py new file mode 100644 index 00000000000..c86787992f3 --- /dev/null +++ b/src/torchmetrics/detection/helpers.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Sequence + +from torch import Tensor + + +def _input_validator( + preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]], iou_type: str = "bbox" +) -> None: + """Ensure the correct input format of `preds` and `targets`.""" + if iou_type == "bbox": + item_val_name = "boxes" + elif iou_type == "segm": + item_val_name = "masks" + else: + raise Exception(f"IOU type {iou_type} is not supported") + + if not isinstance(preds, Sequence): + raise ValueError(f"Expected argument `preds` to be of type Sequence, but got {preds}") + if not isinstance(targets, Sequence): + raise ValueError(f"Expected argument `target` to be of type Sequence, but got {targets}") + if len(preds) != len(targets): + raise ValueError( + f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}" + ) + + for k in [item_val_name, "scores", "labels"]: + if any(k not in p for p in preds): + raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key") + + for k in [item_val_name, "labels"]: + if any(k not in p for p in targets): + raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key") + + if any(type(pred[item_val_name]) is not Tensor for pred in preds): + raise ValueError(f"Expected all {item_val_name} in `preds` to be of type Tensor") + if any(type(pred["scores"]) is not Tensor for pred in preds): + raise ValueError("Expected all scores in `preds` to be of type Tensor") + if any(type(pred["labels"]) is not Tensor for pred in preds): + raise ValueError("Expected all labels in `preds` to be of type Tensor") + if any(type(target[item_val_name]) is not Tensor for target in targets): + raise ValueError(f"Expected all {item_val_name} in `target` to be of type Tensor") + if any(type(target["labels"]) is not Tensor for target in targets): + raise ValueError("Expected all labels in 
`target` to be of type Tensor") + + for i, item in enumerate(targets): + if item[item_val_name].size(0) != item["labels"].size(0): + raise ValueError( + f"Input {item_val_name} and labels of sample {i} in targets have a" + f" different length (expected {item[item_val_name].size(0)} labels, got {item['labels'].size(0)})" + ) + for i, item in enumerate(preds): + if not (item[item_val_name].size(0) == item["labels"].size(0) == item["scores"].size(0)): + raise ValueError( + f"Input {item_val_name}, labels and scores of sample {i} in predictions have a" + f" different length (expected {item[item_val_name].size(0)} labels and scores," + f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})" + ) + + +def _fix_empty_tensors(boxes: Tensor) -> Tensor: + """Empty tensors can cause problems in DDP mode, this methods corrects them.""" + if boxes.numel() == 0 and boxes.ndim == 1: + return boxes.unsqueeze(0) + return boxes diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py new file mode 100644 index 00000000000..7114b735459 --- /dev/null +++ b/src/torchmetrics/detection/iou.py @@ -0,0 +1,308 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class IntersectionOverUnion(Metric):
    r"""Computes Intersection Over Union (IoU).

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values
      (each dictionary corresponds to a single image). Parameters that should be provided per dict:

        - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection
          boxes of the format specified in the constructor.
          By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
        - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes.
        - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for
          the boxes.

    - ``target`` (:class:`~List`) A list consisting of dictionaries each containing the key-values
      (each dictionary corresponds to a single image). Parameters that should be provided per dict:

        - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth
          boxes of the format specified in the constructor.
          By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
        - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth
          classes for the boxes.

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``iou_dict``: A dictionary containing the following key-values:

        - iou: (:class:`~torch.Tensor`)
        - iou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``

    Args:
        box_format:
            Input format of given boxes. Supported formats are ``[`xyxy`, `xywh`, `cxcywh`]``.
        iou_threshold:
            Optional IoU threshold for evaluation. If set to `None` the threshold is ignored.
        class_metrics:
            Option to enable per-class metrics for IoU. Has a performance impact.
        respect_labels:
            Replace IoU values with the `invalid_val` if the labels do not match.
        kwargs:
            Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> from torchmetrics.detection import IntersectionOverUnion
        >>> preds = [
        ...    {
        ...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
        ...        "scores": torch.tensor([0.236, 0.56]),
        ...        "labels": torch.tensor([4, 5]),
        ...    }
        ... ]
        >>> target = [
        ...    {
        ...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
        ...        "labels": torch.tensor([5]),
        ...    }
        ... ]
        >>> metric = IntersectionOverUnion()
        >>> metric(preds, target)
        {'iou': tensor(0.4307)}

    Raises:
        ModuleNotFoundError:
            If torchvision is not installed with version 0.8.0 or newer.
        ValueError:
            If ``box_format`` is not one of the supported formats, or if ``class_metrics`` or
            ``respect_labels`` is not a boolean.

    """
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = True
    full_state_update: bool = True

    detections: List[Tensor]
    detection_scores: List[Tensor]
    detection_labels: List[Tensor]
    groundtruths: List[Tensor]
    groundtruth_labels: List[Tensor]
    results: List[Tensor]
    labels_eq: List[Tensor]
    # Key used in the result dict and in error messages; subclasses override it.
    _iou_type: str = "iou"
    # Value written into IoU entries that are thresholded out or whose labels do not match.
    _invalid_val: float = 0.0

    def __init__(
        self,
        box_format: str = "xyxy",
        iou_threshold: Optional[float] = None,
        class_metrics: bool = False,
        respect_labels: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        if not _TORCHVISION_GREATER_EQUAL_0_8:
            raise ModuleNotFoundError(
                f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.8.0 or newer is installed."
                " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
            )

        allowed_box_formats = ("xyxy", "xywh", "cxcywh")
        if box_format not in allowed_box_formats:
            raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}")

        self.box_format = box_format
        self.iou_threshold = iou_threshold

        if not isinstance(class_metrics, bool):
            raise ValueError("Expected argument `class_metrics` to be a boolean")
        self.class_metrics = class_metrics

        if not isinstance(respect_labels, bool):
            raise ValueError("Expected argument `respect_labels` to be a boolean")
        self.respect_labels = respect_labels

        self.add_state("detections", default=[], dist_reduce_fx=None)
        self.add_state("detection_scores", default=[], dist_reduce_fx=None)
        self.add_state("detection_labels", default=[], dist_reduce_fx=None)
        self.add_state("groundtruths", default=[], dist_reduce_fx=None)
        self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
        self.add_state("results", default=[], dist_reduce_fx=None)
        self.add_state("labels_eq", default=[], dist_reduce_fx=None)

    @staticmethod
    def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
        """Forward to the functional ``_iou_update``; subclasses swap in their own IoU variant."""
        return _iou_update(*args, **kwargs)

    @staticmethod
    def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor:
        """Forward to the functional ``_iou_compute``; subclasses swap in their own IoU variant."""
        return _iou_compute(*args, **kwargs)

    def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
        """Update state with predictions and targets.

        Raises:
            ValueError:
                If ``preds`` or ``target`` is not a sequence
            ValueError:
                If ``preds`` and ``target`` are not of the same length
            ValueError:
                If a dict in ``preds`` lacks ``boxes``, ``scores`` or ``labels``, or a dict in ``target``
                lacks ``boxes`` or ``labels``
            ValueError:
                If any of those values is not a tensor
            ValueError:
                If any of ``preds.boxes``, ``preds.scores``
                and ``preds.labels`` are not of the same length
            ValueError:
                If any of ``target.boxes`` and ``target.labels`` are not of the same length
        """
        _input_validator(preds, target)

        for p, t in zip(preds, target):
            det_boxes = self._get_safe_item_values(p["boxes"])
            self.detections.append(det_boxes)
            self.detection_labels.append(p["labels"])
            self.detection_scores.append(p["scores"])

            gt_boxes = self._get_safe_item_values(t["boxes"])
            self.groundtruths.append(gt_boxes)
            self.groundtruth_labels.append(t["labels"])

            label_eq = torch.equal(p["labels"], t["labels"])
            # Workaround to persist state, which only works with tensors
            self.labels_eq.append(torch.tensor([label_eq], dtype=torch.int, device=self.device))

            ious = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val)
            if self.respect_labels and not label_eq:
                # Broadcast a (num_preds, 1) column against a (1, num_targets) row: non-zero entries mark
                # prediction/target pairs whose class labels differ.
                label_diff = p["labels"].unsqueeze(0).T - t["labels"].unsqueeze(0)
                labels_not_eq = label_diff != 0.0
                ious[labels_not_eq] = self._invalid_val
            self.results.append(ious.to(dtype=torch.float, device=self.device))

    def _get_safe_item_values(self, boxes: Tensor) -> Tensor:
        """Normalise empty tensors and convert non-empty boxes from ``self.box_format`` to ``xyxy``."""
        boxes = _fix_empty_tensors(boxes)
        if boxes.numel() > 0:
            boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
        return boxes

    def _get_gt_classes(self) -> List:
        """Returns a list of the unique classes found in the accumulated ground truth labels."""
        if len(self.groundtruth_labels) > 0:
            return torch.cat(self.groundtruth_labels).unique().tolist()
        return []

    def compute(self) -> dict:
        """Computes IoU based on inputs passed in to ``update`` previously."""
        aggregated_iou = dim_zero_cat(
            [self._iou_compute_fn(iou, bool(lbl_eq)) for iou, lbl_eq in zip(self.results, self.labels_eq)]
        )
        results: Dict[str, Tensor] = {f"{self._iou_type}": aggregated_iou.mean()}

        if self.class_metrics:
            # The set of ground truth classes does not change while iterating, so build it once instead of
            # concatenating and de-duplicating all labels again for every image.
            gt_classes = self._get_gt_classes()
            class_results: Dict[int, List[Tensor]] = defaultdict(list)
            for iou, label in zip(self.results, self.groundtruth_labels):
                for cl in gt_classes:
                    masked_iou = iou[:, label == cl]
                    if masked_iou.numel() > 0:
                        class_results[cl].append(self._iou_compute_fn(masked_iou, False))

            results.update(
                {f"{self._iou_type}/cl_{cl}": dim_zero_cat(class_results[cl]).mean() for cl in class_results}
            )
        return results

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting single value
            >>> import torch
            >>> from torchmetrics.detection import IntersectionOverUnion
            >>> preds = [
            ...    {
            ...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
            ...        "scores": torch.tensor([0.236, 0.56]),
            ...        "labels": torch.tensor([4, 5]),
            ...    }
            ... ]
            >>> target = [
            ...    {
            ...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
            ...        "labels": torch.tensor([5]),
            ...    }
            ... ]
            >>> metric = IntersectionOverUnion()
            >>> metric.update(preds, target)
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.detection import IntersectionOverUnion
            >>> preds = [
            ...    {
            ...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
            ...        "scores": torch.tensor([0.236, 0.56]),
            ...        "labels": torch.tensor([4, 5]),
            ...    }
            ... ]
            >>> target = lambda : [
            ...    {
            ...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
            ...        "labels": torch.tensor([5]),
            ...    }
            ... ]
            >>> metric = IntersectionOverUnion()
            >>> vals = []
            >>> for _ in range(20):
            ...     vals.append(metric(preds, target()))
            >>> fig_, ax_ = metric.plot(vals)
        """
        return self._plot(val, ax)
Sequence): - raise ValueError("Expected argument `preds` to be of type Sequence") - if not isinstance(targets, Sequence): - raise ValueError("Expected argument `target` to be of type Sequence") - if len(preds) != len(targets): - raise ValueError("Expected argument `preds` and `target` to have the same length") - iou_attribute = "boxes" if iou_type == "bbox" else "masks" - - for k in [iou_attribute, "scores", "labels"]: - if any(k not in p for p in preds): - raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key") - - for k in [iou_attribute, "labels"]: - if any(k not in p for p in targets): - raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key") - - if any(type(pred[iou_attribute]) is not Tensor for pred in preds): - raise ValueError(f"Expected all {iou_attribute} in `preds` to be of type Tensor") - if any(type(pred["scores"]) is not Tensor for pred in preds): - raise ValueError("Expected all scores in `preds` to be of type Tensor") - if any(type(pred["labels"]) is not Tensor for pred in preds): - raise ValueError("Expected all labels in `preds` to be of type Tensor") - if any(type(target[iou_attribute]) is not Tensor for target in targets): - raise ValueError(f"Expected all {iou_attribute} in `target` to be of type Tensor") - if any(type(target["labels"]) is not Tensor for target in targets): - raise ValueError("Expected all labels in `target` to be of type Tensor") - - for i, item in enumerate(targets): - if item[iou_attribute].size(0) != item["labels"].size(0): - raise ValueError( - f"Input {iou_attribute} and labels of sample {i} in targets have a" - f" different length (expected {item[iou_attribute].size(0)} labels, got {item['labels'].size(0)})" - ) - for i, item in enumerate(preds): - if not (item[iou_attribute].size(0) == item["labels"].size(0) == item["scores"].size(0)): - raise ValueError( - f"Input {iou_attribute}, labels and scores of sample {i} in predictions have a" - f" different length (expected 
{item[iou_attribute].size(0)} labels and scores," - f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})" - ) - - -def _fix_empty_tensors(boxes: Tensor) -> Tensor: - """Empty tensors can cause problems in DDP mode, this methods corrects them.""" - if boxes.numel() == 0 and boxes.ndim == 1: - return boxes.unsqueeze(0) - return boxes - - class MeanAveragePrecision(Metric): r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions. @@ -350,7 +298,7 @@ class MeanAveragePrecision(Metric): 'mar_small': tensor(-1.)} """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: Optional[bool] = True full_state_update: bool = True plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/functional/detection/__init__.py b/src/torchmetrics/functional/detection/__init__.py index 7ec83290faa..85a2d12e39c 100644 --- a/src/torchmetrics/functional/detection/__init__.py +++ b/src/torchmetrics/functional/detection/__init__.py @@ -11,6 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality
from torchmetrics.utilities.imports import (
    _TORCHVISION_AVAILABLE,
    _TORCHVISION_GREATER_EQUAL_0_8,
    _TORCHVISION_GREATER_EQUAL_0_13,
)

# Always-available public API; the torchvision-backed functions are added below when importable.
__all__ = ["modified_panoptic_quality", "panoptic_quality"]

# IoU and GIoU are gated on torchvision >= 0.8.
if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8:
    from torchmetrics.functional.detection.giou import generalized_intersection_over_union  # noqa: F401
    from torchmetrics.functional.detection.iou import intersection_over_union  # noqa: F401

    __all__ += ["generalized_intersection_over_union", "intersection_over_union"]

# CIoU and DIoU are gated on torchvision >= 0.13.
if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13:
    from torchmetrics.functional.detection.ciou import complete_intersection_over_union  # noqa: F401
    from torchmetrics.functional.detection.diou import distance_intersection_over_union  # noqa: F401

    __all__ += ["complete_intersection_over_union", "distance_intersection_over_union"]
+from typing import Optional + +import torch + +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 + +if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13: + from torchvision.ops import complete_box_iou +else: + complete_box_iou = None + __doctest_skip__ = ["complete_intersection_over_union"] + +__doctest_requires__ = {("complete_intersection_over_union",): ["torchvision"]} + + +def _ciou_update( + preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 +) -> torch.Tensor: + iou = complete_box_iou(preds, target) + if iou_threshold is not None: + iou[iou < iou_threshold] = replacement_val + return iou + + +def _ciou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: + if labels_eq: + return iou.diag().mean() + return iou.mean() + + +def complete_intersection_over_union( + preds: torch.Tensor, + target: torch.Tensor, + iou_threshold: Optional[float] = None, + replacement_val: float = 0, + aggregate: bool = True, +) -> torch.Tensor: + r"""Compute `Complete Intersection over Union `_ between two sets of boxes. + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2. + + Args: + preds: + The input tensor containing the predicted bounding boxes. + target: + The tensor containing the ground truth. + iou_threshold: + Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. + replacement_val: + Value to replace values under the threshold with. + aggregate: + Return the average value instead of the complete IoU matrix. 
+ + Example: + >>> import torch + >>> from torchmetrics.functional.detection import complete_intersection_over_union + >>> preds = torch.Tensor([[100, 100, 200, 200]]) + >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> complete_intersection_over_union(preds, target) + tensor(0.6724) + """ + if not _TORCHVISION_GREATER_EQUAL_0_13: + raise ModuleNotFoundError( + f"`{complete_intersection_over_union.__name__}` requires that `torchvision` version 0.13.0 or newer" + " is installed." + " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + ) + iou = _ciou_update(preds, target, iou_threshold, replacement_val) + return _ciou_compute(iou) if aggregate else iou diff --git a/src/torchmetrics/functional/detection/diou.py b/src/torchmetrics/functional/detection/diou.py new file mode 100644 index 00000000000..8cb641a00ec --- /dev/null +++ b/src/torchmetrics/functional/detection/diou.py @@ -0,0 +1,82 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional + +import torch + +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 + +if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13: + from torchvision.ops import distance_box_iou +else: + distance_box_iou = None + __doctest_skip__ = ["distance_intersection_over_union"] + +__doctest_requires__ = {("distance_intersection_over_union",): ["torchvision"]} + + +def _diou_update( + preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 +) -> torch.Tensor: + iou = distance_box_iou(preds, target) + if iou_threshold is not None: + iou[iou < iou_threshold] = replacement_val + return iou + + +def _diou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: + if labels_eq: + return iou.diag().mean() + return iou.mean() + + +def distance_intersection_over_union( + preds: torch.Tensor, + target: torch.Tensor, + iou_threshold: Optional[float] = None, + replacement_val: float = 0, + aggregate: bool = True, +) -> torch.Tensor: + r"""Compute `Distance Intersection over Union `_ between two sets of boxes. + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2. + + Args: + preds: + The input tensor containing the predicted bounding boxes. + target: + The tensor containing the ground truth. + iou_threshold: + Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. + replacement_val: + Value to replace values under the threshold with. + aggregate: + Return the average value instead of the complete IoU matrix. 
+ + Example: + >>> import torch + >>> from torchmetrics.functional.detection import distance_intersection_over_union + >>> preds = torch.Tensor([[100, 100, 200, 200]]) + >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> distance_intersection_over_union(preds, target) + tensor(0.6724) + """ + if not _TORCHVISION_GREATER_EQUAL_0_13: + raise ModuleNotFoundError( + f"`{distance_intersection_over_union.__name__}` requires that `torchvision` version 0.13.0 or newer" + " is installed." + " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." + ) + iou = _diou_update(preds, target, iou_threshold, replacement_val) + return _diou_compute(iou) if aggregate else iou diff --git a/src/torchmetrics/functional/detection/giou.py b/src/torchmetrics/functional/detection/giou.py new file mode 100644 index 00000000000..a46ab8ef99b --- /dev/null +++ b/src/torchmetrics/functional/detection/giou.py @@ -0,0 +1,82 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional + +import torch + +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 + +if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8: + from torchvision.ops import generalized_box_iou +else: + generalized_box_iou = None + __doctest_skip__ = ["generalized_intersection_over_union"] + +__doctest_requires__ = {("generalized_intersection_over_union",): ["torchvision"]} + + +def _giou_update( + preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 +) -> torch.Tensor: + iou = generalized_box_iou(preds, target) + if iou_threshold is not None: + iou[iou < iou_threshold] = replacement_val + return iou + + +def _giou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: + if labels_eq: + return iou.diag().mean() + return iou.mean() + + +def generalized_intersection_over_union( + preds: torch.Tensor, + target: torch.Tensor, + iou_threshold: Optional[float] = None, + replacement_val: float = 0, + aggregate: bool = True, +) -> torch.Tensor: + r"""Compute `Generalized Intersection over Union `_ between two sets of boxes. + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2. + + Args: + preds: + The input tensor containing the predicted bounding boxes. + target: + The tensor containing the ground truth. + iou_threshold: + Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. + replacement_val: + Value to replace values under the threshold with. + aggregate: + Return the average value instead of the complete IoU matrix. 
+ + Example: + >>> import torch + >>> from torchmetrics.functional.detection import generalized_intersection_over_union + >>> preds = torch.Tensor([[100, 100, 200, 200]]) + >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> generalized_intersection_over_union(preds, target) + tensor(0.6641) + """ + if not _TORCHVISION_GREATER_EQUAL_0_8: + raise ModuleNotFoundError( + f"`{generalized_intersection_over_union.__name__}` requires that `torchvision` version 0.8.0 or newer" + " is installed." + " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + ) + iou = _giou_update(preds, target, iou_threshold, replacement_val) + return _giou_compute(iou) if aggregate else iou diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py new file mode 100644 index 00000000000..0c474e63a32 --- /dev/null +++ b/src/torchmetrics/functional/detection/iou.py @@ -0,0 +1,81 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional + +import torch + +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 + +if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8: + from torchvision.ops import box_iou +else: + box_iou = None + __doctest_skip__ = ["intersection_over_union"] + +__doctest_requires__ = {("intersection_over_union",): ["torchvision"]} + + +def _iou_update( + preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 +) -> torch.Tensor: + iou = box_iou(preds, target) + if iou_threshold is not None: + iou[iou < iou_threshold] = replacement_val + return iou + + +def _iou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: + if labels_eq: + return iou.diag().mean() + return iou.mean() + + +def intersection_over_union( + preds: torch.Tensor, + target: torch.Tensor, + iou_threshold: Optional[float] = None, + replacement_val: float = 0, + aggregate: bool = True, +) -> torch.Tensor: + r"""Compute Intersection over Union between two sets of boxes. + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2. + + Args: + preds: + The input tensor containing the predicted bounding boxes. + target: + The tensor containing the ground truth. + iou_threshold: + Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. + replacement_val: + Value to replace values under the threshold with. + aggregate: + Return the average value instead of the complete IoU matrix. 
+ + Example: + >>> import torch + >>> from torchmetrics.functional.detection import intersection_over_union + >>> preds = torch.Tensor([[100, 100, 200, 200]]) + >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> intersection_over_union(preds, target) + tensor(0.6807) + """ + if not _TORCHVISION_GREATER_EQUAL_0_8: + raise ModuleNotFoundError( + f"`{intersection_over_union.__name__}` requires that `torchvision` version 0.8.0 or newer is installed." + " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + ) + iou = _iou_update(preds, target, iou_threshold, replacement_val) + return _iou_compute(iou) if aggregate else iou diff --git a/tests/unittests/detection/base_iou_test.py b/tests/unittests/detection/base_iou_test.py new file mode 100644 index 00000000000..660ee4a6b96 --- /dev/null +++ b/tests/unittests/detection/base_iou_test.py @@ -0,0 +1,266 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from abc import ABC
from collections import namedtuple
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Type

import pytest
import torch
from torch import IntTensor, Tensor

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8

Input = namedtuple("Input", ["preds", "target"])


@dataclass
class TestCaseData:
    """Test data sample: the inputs of a test case together with the expected result."""

    # The class name matches pytest's default ``Test*`` collection pattern although it is a plain data
    # container; opt out explicitly so pytest does not try (and warn about failing) to collect it.
    # Not annotated on purpose: un-annotated class attributes are not turned into dataclass fields.
    __test__ = False

    data: Input
    result: Any


# Plain box tensors, used by the functional test.
_preds = torch.Tensor(
    [
        [296.55, 93.96, 314.97, 152.79],
        [328.94, 97.05, 342.49, 122.98],
        [356.62, 95.47, 372.33, 147.55],
    ]
)
_target = torch.Tensor(
    [
        [300.00, 100.00, 315.00, 150.00],
        [330.00, 100.00, 350.00, 125.00],
        [350.00, 100.00, 375.00, 150.00],
    ]
)

# Batches of per-image dictionaries, used by the class based tests.
_inputs = Input(
    preds=[
        [
            {
                "boxes": Tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
                "scores": Tensor([0.236, 0.56]),
                "labels": IntTensor([4, 5]),
            }
        ],
        [
            {
                "boxes": Tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
                "scores": Tensor([0.236, 0.56]),
                "labels": IntTensor([4, 5]),
            }
        ],
        [
            {
                "boxes": Tensor([[328.94, 97.05, 342.49, 122.98]]),
                "scores": Tensor([0.456]),
                "labels": IntTensor([4]),
            },
            {
                "boxes": Tensor([[356.62, 95.47, 372.33, 147.55]]),
                "scores": Tensor([0.791]),
                "labels": IntTensor([4]),
            },
        ],
        [
            {
                "boxes": Tensor([[328.94, 97.05, 342.49, 122.98]]),
                "scores": Tensor([0.456]),
                "labels": IntTensor([5]),
            },
            {
                "boxes": Tensor([[356.62, 95.47, 372.33, 147.55]]),
                "scores": Tensor([0.791]),
                "labels": IntTensor([5]),
            },
        ],
    ],
    target=[
        [
            {
                "boxes": Tensor([[300.00, 100.00, 315.00, 150.00]]),
                "labels": IntTensor([5]),
            }
        ],
        [
            {
                "boxes": Tensor([[300.00, 100.00, 315.00, 150.00]]),
                "labels": IntTensor([5]),
            }
        ],
        [
            {
                "boxes": Tensor([[330.00, 100.00, 350.00, 125.00]]),
                "labels": IntTensor([4]),
            },
            {
                "boxes": Tensor([[350.00, 100.00, 375.00, 150.00]]),
                "labels": IntTensor([4]),
            },
        ],
        [
            {
                "boxes": Tensor([[330.00, 100.00, 350.00, 125.00]]),
                "labels": IntTensor([5]),
            },
            {
                "boxes": Tensor([[350.00, 100.00, 375.00, 150.00]]),
                "labels": IntTensor([4]),
            },
        ],
    ],
)

_box_inputs = Input(preds=_preds, target=_target)

_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)


def compare_fn(preds: Any, target: Any, result: Any):
    """Mock compare function by returning additional parameter results directly."""
    return result


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision>=0.8.0 is installed")
@pytest.mark.parametrize("compute_on_cpu", [True, False])
@pytest.mark.parametrize("ddp", [False, True])
class BaseTestIntersectionOverUnion(ABC):
    """Base Test the Intersection over Union metric for object detection predictions.

    Concrete test classes are expected to override ``data``, ``metric_class`` and ``metric_fn``. The ``data``
    mapping of a concrete class must provide the keys ``"iou_variant"``, ``"iou_variant_respect"`` and
    ``"fn_iou_variant"`` because the tests below look them up by name.
    """

    data: Dict[str, TestCaseData] = {
        "iou_variant": TestCaseData(data=_inputs, result={"iou": torch.Tensor([0])}),
        "fn_iou_variant": TestCaseData(data=_box_inputs, result=None),
    }
    # ``metric_class`` holds the metric *class* (it is instantiated by the tests), not an instance.
    metric_class: Type[Metric]
    metric_fn: Callable[[Tensor, Tensor, bool, float], Tensor]

    def test_iou_variant(self, compute_on_cpu: bool, ddp: bool):
        """Test modular implementation for correctness."""
        key = "iou_variant"

        self.run_class_metric_test(  # type: ignore
            ddp=ddp,
            preds=self.data[key].data.preds,
            target=self.data[key].data.target,
            metric_class=self.metric_class,
            reference_metric=partial(compare_fn, result=self.data[key].result),
            dist_sync_on_step=False,
            check_batch=False,
            metric_args={"compute_on_cpu": compute_on_cpu},
        )

    def test_iou_variant_dont_respect_labels(self, compute_on_cpu: bool, ddp: bool):
        """Test modular implementation for correctness while ignoring labels."""
        # This key is only defined by the concrete subclasses, not by the base ``data`` mapping above.
        key = "iou_variant_respect"

        self.run_class_metric_test(  # type: ignore
            ddp=ddp,
            preds=self.data[key].data.preds,
            target=self.data[key].data.target,
            metric_class=self.metric_class,
            reference_metric=partial(compare_fn, result=self.data[key].result),
            dist_sync_on_step=False,
            check_batch=False,
            metric_args={"compute_on_cpu": compute_on_cpu, "respect_labels": False},
        )

    def test_fn(self, compute_on_cpu: bool, ddp: bool):
        """Test functional implementation for correctness."""
        key = "fn_iou_variant"
        self.run_functional_metric_test(
            self.data[key].data.preds[0].unsqueeze(0),  # pass as batch, otherwise it attempts to pass element wise
            self.data[key].data.target[0].unsqueeze(0),
            # ``metric_fn`` is stored as a class attribute, so attribute access on ``self`` yields a bound
            # method; ``__func__`` recovers the plain function.
            self.metric_fn.__func__,
            partial(compare_fn, result=self.data[key].result),
        )

    def test_error_on_wrong_input(self, compute_on_cpu: bool, ddp: bool):
        """Test class input validation."""
        metric = self.metric_class()

        metric.update([], [])  # no error

        with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"):
            metric.update(Tensor(), [])  # type: ignore

        with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"):
            metric.update([], Tensor())  # type: ignore

        with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
            metric.update([{}], [{}, {}])

        # Missing-key checks: every entry that *is* present is a real (empty) tensor instance, so the only
        # thing wrong with each input is the key under test.
        with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"):
            metric.update(
                [{"scores": Tensor(), "labels": IntTensor()}],
                [{"boxes": Tensor(), "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `scores` key"):
            metric.update(
                [{"boxes": Tensor(), "labels": IntTensor()}],
                [{"boxes": Tensor(), "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"):
            metric.update(
                [{"boxes": Tensor(), "scores": Tensor()}],
                [{"boxes": Tensor(), "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"):
            metric.update(
                [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}],
                [{"labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"):
            metric.update(
                [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}],
                [{"boxes": IntTensor()}],
            )

        # Wrong-type checks: exactly one entry is a list instead of a tensor.
        with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"):
            metric.update(
                [{"boxes": [], "scores": Tensor(), "labels": IntTensor()}],
                [{"boxes": Tensor(), "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type Tensor"):
            metric.update(
                [{"boxes": Tensor(), "scores": [], "labels": IntTensor()}],
                [{"boxes": Tensor(), "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"):
            metric.update(
                [{"boxes": Tensor(), "scores": Tensor(), "labels": []}],
                [{"boxes": Tensor(), "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"):
            metric.update(
                [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}],
                [{"boxes": [], "labels": IntTensor()}],
            )

        with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"):
            metric.update(
                [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}],
                [{"boxes": Tensor(), "labels": []}],
            )


# ---------------------------------------------------------------------------
# Patch boundary: what follows belongs to the next new file in this diff,
#   tests/unittests/detection/test_ciou.py
#   (new file mode 100644, index 00000000000..9e94bb7e894)
# ---------------------------------------------------------------------------
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Type

import pytest
import torch
from torch import Tensor

from torchmetrics import Metric
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
from torchmetrics.functional.detection.ciou import complete_intersection_over_union
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs
from unittests.helpers.testers import MetricTester

# Reference value for the class based test with default arguments (key "iou_variant").
ciou = torch.Tensor(
    [
        [-0.2669985],
    ]
)
# Reference value for the class based test run with ``respect_labels=False`` (key "iou_variant_respect").
ciou_dontrespect = torch.Tensor(
    [
        [0.6078202],
    ]
)
# Reference matrix for the functional test (key "fn_iou_variant").
box_ciou = torch.Tensor(
    [
        [0.6883, -0.2072, -0.3352],
        [-0.2217, 0.4881, -0.1913],
        [-0.3971, -0.1543, 0.5606],
    ]
)


_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13)


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision>=0.13.0 is installed")
class TestCompleteIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion):
    """Test the Complete Intersection over Union metric for object detection predictions."""

    data: Dict[str, TestCaseData] = {
        "iou_variant": TestCaseData(data=_inputs, result={CompleteIntersectionOverUnion._iou_type: ciou}),
        "iou_variant_respect": TestCaseData(
            data=_inputs, result={CompleteIntersectionOverUnion._iou_type: ciou_dontrespect}
        ),
        "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_ciou),
    }
    # The attribute holds the metric class itself (the base tests instantiate it), hence ``Type[Metric]``.
    metric_class: Type[Metric] = CompleteIntersectionOverUnion
    metric_fn: Callable[[Tensor, Tensor, bool, float], Tensor] = complete_intersection_over_union


# ---------------------------------------------------------------------------
# Patch boundary: what follows belongs to the next new file in this diff,
#   tests/unittests/detection/test_diou.py
#   (new file mode 100644, index 00000000000..e6f75b42f82)
# ---------------------------------------------------------------------------
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict

import pytest
import torch
from torch import Tensor

from torchmetrics.detection.diou import DistanceIntersectionOverUnion
from torchmetrics.functional.detection.diou import distance_intersection_over_union
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs
from unittests.helpers.testers import MetricTester

# Reference value for the class based test with default arguments (key "iou_variant").
diou = torch.Tensor(
    [
        [0.06653749],
    ]
)
# Reference value for the class based test run with ``respect_labels=False`` (key "iou_variant_respect").
diou_dontrespect = torch.Tensor(
    [
        [0.6080749],
    ]
)
# Reference matrix for the functional test (key "fn_iou_variant").
box_diou = torch.Tensor(
    [
        [0.6883, -0.2043, -0.3351],
        [-0.2214, 0.4886, -0.1913],
        [-0.3971, -0.1510, 0.5609],
    ]
)

_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13)


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision>=0.13.0 is installed")
class TestDistanceIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion):
    """Test the Distance Intersection over Union metric for object detection predictions."""

    data: Dict[str, TestCaseData] = {
        "iou_variant": TestCaseData(data=_inputs, result={DistanceIntersectionOverUnion._iou_type: diou}),
        "iou_variant_respect": TestCaseData(
            data=_inputs, result={DistanceIntersectionOverUnion._iou_type: diou_dontrespect}
        ),
        "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_diou),
    }
    metric_class = DistanceIntersectionOverUnion
    metric_fn: Callable[[Tensor, Tensor, bool, float], Tensor] = distance_intersection_over_union


# ---------------------------------------------------------------------------
# Patch boundary: what follows belongs to the next new file in this diff,
#   tests/unittests/detection/test_giou.py
#   (new file mode 100644, index 00000000000..e118548f5b9)
# ---------------------------------------------------------------------------
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict

import pytest
import torch
from torch import Tensor

from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion
from torchmetrics.functional.detection.giou import generalized_intersection_over_union
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs
from unittests.helpers.testers import MetricTester

# Reference value for the class based test with default arguments (key "iou_variant").
giou = torch.Tensor(
    [
        [0.05507809],
    ]
)
# Reference value for the class based test run with ``respect_labels=False`` (key "iou_variant_respect").
giou_dontrespect = torch.Tensor(
    [
        [0.59242314],
    ]
)
# Reference matrix for the functional test (key "fn_iou_variant").
box_giou = torch.Tensor(
    [
        [0.6895, -0.4964, -0.4944],
        [-0.5105, 0.4673, -0.3434],
        [-0.6024, -0.4021, 0.5345],
    ]
)

_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision>=0.8.0 is installed")
class TestGeneralizedIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion):
    """Test the Generalized Intersection over Union metric for object detection predictions."""

    data: Dict[str, TestCaseData] = {
        "iou_variant": TestCaseData(data=_inputs, result={GeneralizedIntersectionOverUnion._iou_type: giou}),
        "iou_variant_respect": TestCaseData(
            data=_inputs, result={GeneralizedIntersectionOverUnion._iou_type: giou_dontrespect}
        ),
        "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_giou),
    }
    metric_class = GeneralizedIntersectionOverUnion
    metric_fn: Callable[[Tensor, Tensor, bool, float], Tensor] = generalized_intersection_over_union


# ---------------------------------------------------------------------------
# Patch boundary: what follows belongs to the next new file in this diff,
#   tests/unittests/detection/test_iou.py
#   (new file mode 100644, index 00000000000..444aaebe3c0)
# ---------------------------------------------------------------------------
# Copyright The PyTorch Lightning team.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict + +import pytest +import torch +from torch import Tensor + +from torchmetrics.detection.iou import IntersectionOverUnion +from torchmetrics.functional.detection.iou import intersection_over_union +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs +from unittests.helpers.testers import MetricTester + +iou = torch.Tensor( + [ + [0.40733114], + ] +) +iou_dontrespect = torch.Tensor( + [ + [0.6165285], + ] +) +box_iou = torch.Tensor( + [ + [0.6898, 0.0000, 0.0000], + [0.0000, 0.5086, 0.0000], + [0.0000, 0.0000, 0.5654], + ] +) + +_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) + + +@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") +class TestIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion): + """Test the Intersection over Union metric for object detection predictions.""" + + data: Dict[str, TestCaseData] = { + "iou_variant": TestCaseData(data=_inputs, result={IntersectionOverUnion._iou_type: iou}), + "iou_variant_respect": TestCaseData(data=_inputs, result={IntersectionOverUnion._iou_type: iou_dontrespect}), + "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_iou), + } + metric_class = IntersectionOverUnion + metric_fn: 
Callable[[Tensor, Tensor, bool, float], Tensor] = intersection_over_union diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 9a071e23768..54298ca6037 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -28,37 +28,40 @@ Input = namedtuple("Input", ["preds", "target"]) -with open(_SAMPLE_DETECTION_SEGMENTATION) as fp: - inputs_json = json.load(fp) -_mask_unsqueeze_bool = lambda m: Tensor(mask.decode(m)).unsqueeze(0).bool() -_masks_stack_bool = lambda ms: Tensor(np.stack([mask.decode(m) for m in ms])).bool() - -_inputs_masks = Input( - preds=[ - [ - { - "masks": _mask_unsqueeze_bool(inputs_json["preds"][0]), - "scores": Tensor([0.236]), - "labels": IntTensor([4]), - }, - { - "masks": _masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]), - "scores": Tensor([0.318, 0.726]), - "labels": IntTensor([3, 2]), - }, # 73 +def _create_inputs_masks() -> Input: + with open(_SAMPLE_DETECTION_SEGMENTATION) as fp: + inputs_json = json.load(fp) + + _mask_unsqueeze_bool = lambda m: Tensor(mask.decode(m)).unsqueeze(0).bool() + _masks_stack_bool = lambda ms: Tensor(np.stack([mask.decode(m) for m in ms])).bool() + + return Input( + preds=[ + [ + { + "masks": _mask_unsqueeze_bool(inputs_json["preds"][0]), + "scores": Tensor([0.236]), + "labels": IntTensor([4]), + }, + { + "masks": _masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]), + "scores": Tensor([0.318, 0.726]), + "labels": IntTensor([3, 2]), + }, # 73 + ], ], - ], - target=[ - [ - {"masks": _mask_unsqueeze_bool(inputs_json["targets"][0]), "labels": IntTensor([4])}, # 42 - { - "masks": _masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]), - "labels": IntTensor([2, 2]), - }, # 73 + target=[ + [ + {"masks": _mask_unsqueeze_bool(inputs_json["targets"][0]), "labels": IntTensor([4])}, # 42 + { + "masks": _masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]), + "labels": 
IntTensor([2, 2]), + }, # 73 + ], ], - ], -) + ) + _inputs = Input( preds=[ @@ -357,6 +360,7 @@ def test_map_bbox(self, compute_on_cpu, ddp): @pytest.mark.parametrize("ddp", [False]) def test_map_segm(self, compute_on_cpu, ddp): """Test modular implementation for correctness.""" + _inputs_masks = _create_inputs_masks() self.run_class_metric_test( ddp=ddp, preds=_inputs_masks.preds, diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index ece1d97e95c..cdc039a1a2e 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -118,6 +118,7 @@ def _class_test( """ assert len(preds) == len(target) num_batches = len(preds) + assert num_batches % world_size == 0, "Number of batches must be divisible by world_size" if not metric_args: metric_args = {}