Skip to content

Commit

Permalink
[Segmentation] Added generalized dice score metric (#1090)
Browse files Browse the repository at this point in the history
* Adding generalized dice score metric
* Apply suggestions from code review

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
7 people authored Apr 23, 2024
1 parent 82ab513 commit 745c471
Show file tree
Hide file tree
Showing 12 changed files with 475 additions and 1 deletion.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `GeneralizedDiceScore` to segmentation package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090))


- Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217))


Expand All @@ -34,7 +37,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

-


### Fixed
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ covers the following domains:
- Multimodal (Image-Text)
- Nominal
- Regression
- Segmentation
- Text

Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`,
Expand Down
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ Or directly from conda

retrieval/*

.. toctree::
:maxdepth: 2
:name: segmentation
:caption: Segmentation
:glob:

segmentation/*

.. toctree::
:maxdepth: 2
:name: text
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
22 changes: 22 additions & 0 deletions docs/source/segmentation/generalized_dice.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Generalized Dice Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

######################
Generalized Dice Score
######################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.GeneralizedDiceScore
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.generalized_dice_score
:noindex:
1 change: 1 addition & 0 deletions src/torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
"confusion_matrix",
"multiclass_confusion_matrix",
"multilabel_confusion_matrix",
"generalized_dice_score",
"dice",
"exact_match",
"multiclass_exact_match",
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score

__all__ = ["generalized_dice_score"]
138 changes: 138 additions & 0 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide


def _generalized_dice_validate_args(
num_classes: int,
include_background: bool,
per_class: bool,
weight_type: Literal["square", "simple", "linear"],
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if not isinstance(per_class, bool):
raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
if weight_type not in ["square", "simple", "linear"]:
raise ValueError(
f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}."
)


def _generalized_dice_update(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool,
weight_type: Literal["square", "simple", "linear"] = "square",
) -> Tensor:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if (preds.bool() != preds).any(): # preds is an index tensor
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
preds, target = _ignore_background(preds, target)

reduce_axis = list(range(2, target.ndim))
intersection = torch.sum(preds * target, dim=reduce_axis)
target_sum = torch.sum(target, dim=reduce_axis)
pred_sum = torch.sum(preds, dim=reduce_axis)
cardinality = target_sum + pred_sum

if weight_type == "simple":
weights = 1.0 / target_sum
elif weight_type == "linear":
weights = torch.ones_like(target_sum)
elif weight_type == "square":
weights = 1.0 / (target_sum**2)
else:
raise ValueError(
f"Expected argument `weight_type` to be one of 'simple', 'linear', 'square', but got {weight_type}."
)

w_shape = weights.shape
weights_flatten = weights.flatten()
infs = torch.isinf(weights_flatten)
weights_flatten[infs] = 0
w_max = torch.max(weights, 0).values.repeat(w_shape[0], 1).T.flatten()
weights_flatten[infs] = w_max[infs]
weights = weights_flatten.reshape(w_shape)

numerator = 2.0 * intersection * weights
denominator = cardinality * weights
return numerator, denominator # type:ignore[return-value]


def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor:
"""Compute the generalized dice score."""
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator)


def generalized_dice_score(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = True,
per_class: bool = False,
weight_type: Literal["square", "simple", "linear"] = "square",
) -> Tensor:
"""Compute the Generalized Dice Score for semantic segmentation.
Args:
preds: Predictions from model
target: Ground truth values
num_classes: Number of classes
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``
Returns:
The Generalized Dice Score
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.segmentation import generalized_dice_score
>>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> generalized_dice_score(preds, target, num_classes=5)
tensor([0.4830, 0.4935, 0.5044, 0.4880])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True)
tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type)
return _generalized_dice_compute(numerator, denominator, per_class)
7 changes: 7 additions & 0 deletions src/torchmetrics/functional/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from torchmetrics.utilities.imports import _SCIPY_AVAILABLE


def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Ignore the background class in the computation."""
preds = preds[:, 1:] if preds.shape[1] > 1 else preds
target = target[:, 1:] if target.shape[1] > 1 else target
return preds, target


def check_if_binarized(x: Tensor) -> None:
"""Check if the input is binarized.
Expand Down
17 changes: 17 additions & 0 deletions src/torchmetrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore

__all__ = ["GeneralizedDiceScore"]
Loading

0 comments on commit 745c471

Please sign in to comment.