diff --git a/pyproject.toml b/pyproject.toml index 5107822c1c4..c167af3458b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ dependencies = [ # torch 1.12+ required by torchvision "torch>=1.12,<3", # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics - "torchmetrics>=0.10,<0.12", + "torchmetrics>=0.10,<2", # torchvision 0.13+ required for torchvision.models._api.WeightsEnum "torchvision>=0.13,<0.16", ] diff --git a/requirements/required.txt b/requirements/required.txt index fd8788f1d67..94a1a016552 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -17,5 +17,5 @@ segmentation-models-pytorch==0.3.3 shapely==2.0.1 timm==0.9.2 torch==2.0.1 -torchmetrics==0.11.4 +torchmetrics==1.0.0 torchvision==0.15.2 diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index dfa2bd924a8..ed58bf945a2 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -12,6 +12,7 @@ from lightning.pytorch import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau +from torchmetrics import MetricCollection from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models import resnet as R from torchvision.models.detection.backbone_utils import resnet_fpn_backbone @@ -187,8 +188,9 @@ def __init__(self, **kwargs: Any) -> None: self.config_task() - self.val_metrics = MeanAveragePrecision() - self.test_metrics = MeanAveragePrecision() + metrics = MetricCollection([MeanAveragePrecision()]) + self.val_metrics = metrics.clone(prefix="val_") + self.test_metrics = metrics.clone(prefix="test_") def forward(self, *args: Any, **kwargs: Any) -> Any: """Forward pass of the model. @@ -273,8 +275,11 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: def on_validation_epoch_end(self) -> None: """Logs epoch level validation metrics.""" metrics = self.val_metrics.compute() - renamed_metrics = {f"val_{i}": metrics[i] for i in metrics.keys()} - self.log_dict(renamed_metrics) + + # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 + metrics.pop("val_classes", None) + + self.log_dict(metrics) self.val_metrics.reset() def test_step(self, *args: Any, **kwargs: Any) -> None: @@ -297,8 +302,11 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: def on_test_epoch_end(self) -> None: """Logs epoch level test metrics.""" metrics = self.test_metrics.compute() - renamed_metrics = {f"test_{i}": metrics[i] for i in metrics.keys()} - self.log_dict(renamed_metrics) + + # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 + metrics.pop("test_classes", None) + + self.log_dict(metrics) self.test_metrics.reset() def predict_step(self, *args: Any, **kwargs: Any) -> list[dict[str, Tensor]]: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index af20fb8687a..2107371cf53 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -175,7 +175,7 @@ class and used with 'ce' loss MulticlassAccuracy( num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, - mdmc_average="global", + multidim_average="global", average="micro", ), MulticlassJaccardIndex(