Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation Pretrained Weights #1046

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_w
task = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights="imagenet",
weights=True,
in_channels=3,
num_classes=2,
loss="ce",
Expand Down
2 changes: 1 addition & 1 deletion conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 6
Expand Down
2 changes: 1 addition & 1 deletion conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion conf/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "deeplabv3+"
backbone: "resnet34"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 2
in_channels: 4
Expand Down
2 changes: 1 addition & 1 deletion conf/spacenet1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
weights: true
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
55 changes: 55 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any, cast

import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
from torchgeo.datasets import LandCoverAI
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import SemanticSegmentationTask


Expand All @@ -34,6 +40,11 @@ def create_model(**kwargs: Any) -> Module:
return SegmentationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


def plot(*args: Any, **kwargs: Any) -> None:
raise ValueError

Expand Down Expand Up @@ -111,6 +122,50 @@ def model_kwargs(self) -> dict[Any, Any]:
"ignore_index": 0,
}

@pytest.fixture(
params=[
weights
for model in list_models()
for weights in get_model_weights(model)
if "resnet" in weights.meta["model"]
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, "url", str(path))
except AttributeError:
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights

def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint

def test_weight_enum(
self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["backbone"] = mocked_weights.meta["model"]
model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
model_kwargs["weights"] = mocked_weights

def test_weight_str(
self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["backbone"] = mocked_weights.meta["model"]
model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
model_kwargs["weights"] = str(mocked_weights)

def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not valid."
Expand Down
30 changes: 25 additions & 5 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Segmentation tasks."""

import os
import warnings
from typing import Any, cast

Expand All @@ -15,9 +16,11 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum

from ..datasets.utils import unbind_samples
from ..models import FCN
from ..models import FCN, get_weight
from . import utils


class SemanticSegmentationTask(LightningModule): # type: ignore[misc]
Expand All @@ -31,17 +34,19 @@ class SemanticSegmentationTask(LightningModule): # type: ignore[misc]

def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
weights = self.hyperparams["weights"]

if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=self.hyperparams["num_classes"],
)
Expand Down Expand Up @@ -80,6 +85,20 @@ def config_task(self) -> None:
f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
)

if self.hyperparams["model"] != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
self.model.encoder.load_state_dict(state_dict)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
self.model.encoder = utils.load_state_dict(
self.model.encoder, state_dict
)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder.load_state_dict(state_dict)

# Freeze backbone
if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
"model"
Expand All @@ -100,8 +119,9 @@ def __init__(self, **kwargs: Any) -> None:
Keyword Args:
model: Name of the segmentation model type to use
backbone: Name of the timm backbone to use
weights: None or "imagenet" to use imagenet pretrained weights in
the backbone
weights: Either a weight enum, the string representation of a weight enum,
True for ImageNet weights, False or None for random weights,
or the path to a saved model state dict.
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
in_channels: Number of channels in input image
num_classes: Number of semantic classes to predict
loss: Name of the loss function, currently supports
Expand Down