diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py
index 7459fa8a953..0595914ebac 100644
--- a/torchgeo/trainers/segmentation.py
+++ b/torchgeo/trainers/segmentation.py
@@ -3,6 +3,7 @@
 
 """Segmentation tasks."""
 
+import os
 import warnings
 from typing import Any, Dict, cast
 
@@ -15,9 +16,11 @@
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torchmetrics import MetricCollection
 from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
+from torchvision.models._api import WeightsEnum
 
 from ..datasets.utils import unbind_samples
-from ..models import FCN
+from ..models import FCN, get_weight
+from . import utils
 
 
 class SemanticSegmentationTask(pl.LightningModule):
@@ -31,17 +34,19 @@ class SemanticSegmentationTask(pl.LightningModule):
 
     def config_task(self) -> None:
         """Configures the task based on kwargs parameters passed to the constructor."""
+        weights = self.hyperparams["weights"]
+
         if self.hyperparams["model"] == "unet":
             self.model = smp.Unet(
                 encoder_name=self.hyperparams["backbone"],
-                encoder_weights=self.hyperparams["weights"],
+                encoder_weights="imagenet" if weights is True else None,
                 in_channels=self.hyperparams["in_channels"],
                 classes=self.hyperparams["num_classes"],
             )
         elif self.hyperparams["model"] == "deeplabv3+":
             self.model = smp.DeepLabV3Plus(
                 encoder_name=self.hyperparams["backbone"],
-                encoder_weights=self.hyperparams["weights"],
+                encoder_weights="imagenet" if weights is True else None,
                 in_channels=self.hyperparams["in_channels"],
                 classes=self.hyperparams["num_classes"],
             )
@@ -74,6 +79,16 @@ def config_task(self) -> None:
                 f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
             )
 
+        if self.hyperparams["model"] != "fcn":
+            if weights and weights is not True:
+                if isinstance(weights, WeightsEnum):
+                    state_dict = weights.get_state_dict(progress=True)
+                elif os.path.exists(weights):
+                    _, state_dict = utils.extract_backbone(weights)
+                else:
+                    state_dict = get_weight(weights).get_state_dict(progress=True)
+                self.model.encoder = utils.load_state_dict(self.model, state_dict)
+
     def __init__(self, **kwargs: Any) -> None:
         """Initialize the LightningModule with a model and loss function.
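
With this change, the `weights` hyperparameter accepts several forms: `True` (use the ImageNet encoder weights shipped by segmentation-models-pytorch), a torchvision `WeightsEnum`, a path to a checkpoint on disk, or a string name resolved through `get_weight`. The sketch below is a hedged usage example, not part of this diff: the enum value `ResNet18_Weights.SENTINEL2_ALL_MOCO`, the checkpoint path, and the extra keyword arguments (`loss`, `ignore_index`) are illustrative assumptions.

```python
# Illustrative sketch of the new ``weights`` forms; kwargs and enum values are assumptions.
from torchgeo.models import ResNet18_Weights  # assumed weight enum
from torchgeo.trainers import SemanticSegmentationTask

# weights=True -> encoder initialized with ImageNet weights via smp
task_imagenet = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=True,
    in_channels=3,
    num_classes=5,
    loss="ce",
    ignore_index=None,
)

# weights=WeightsEnum -> state dict fetched with weights.get_state_dict()
task_enum = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,  # assumed enum value
    in_channels=13,
    num_classes=5,
    loss="ce",
    ignore_index=None,
)

# weights=<path on disk> -> backbone extracted with utils.extract_backbone()
task_ckpt = SemanticSegmentationTask(
    model="deeplabv3+",
    backbone="resnet18",
    weights="path/to/checkpoint.ckpt",  # placeholder path
    in_channels=3,
    num_classes=5,
    loss="ce",
    ignore_index=None,
)
```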