Skip to content

Commit

Permalink
add pretrained weights loading for the segmentation encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley authored and calebrob6 committed Feb 23, 2023
1 parent e81af42 commit 86a0716
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Segmentation tasks."""

import os
import warnings
from typing import Any, Dict, cast

Expand All @@ -15,9 +16,11 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum

from ..datasets.utils import unbind_samples
from ..models import FCN
from ..models import FCN, get_weight
from . import utils


class SemanticSegmentationTask(pl.LightningModule):
Expand All @@ -31,17 +34,19 @@ class SemanticSegmentationTask(pl.LightningModule):

def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
weights = self.hyperparams["weights"]

if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
Expand Down Expand Up @@ -74,6 +79,16 @@ def config_task(self) -> None:
f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
)

if self.hyperparams["model"] != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder = utils.load_state_dict(self.model, state_dict)

def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.
Expand Down

0 comments on commit 86a0716

Please sign in to comment.