Skip to content

Commit

Permalink
add support to pretrained encoder weights (#1306)
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley authored May 4, 2023
1 parent 113f385 commit d192207
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
65 changes: 65 additions & 0 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,71 @@ def model_kwargs(self) -> dict[str, Any]:
"learning_rate_schedule_patience": 6,
}

@pytest.fixture(
    params=[
        w
        for model_name in list_models()
        for w in get_model_weights(model_name)
        if "resnet" in w.meta["model"]
    ]
)
def weights(self, request: SubRequest) -> WeightsEnum:
    """Parametrized fixture yielding every weight enum backed by a ResNet model."""
    return request.param

@pytest.fixture
def mocked_weights(
    self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
    """Return *weights* with its download URL redirected to a local file.

    A randomly initialized model matching the enum's metadata is saved to a
    temporary checkpoint so the task can "download" weights without network
    access.
    """
    checkpoint_path = tmp_path / f"{weights}.pth"
    backbone = timm.create_model(
        weights.meta["model"], in_chans=weights.meta["in_chans"]
    )
    torch.save(backbone.state_dict(), checkpoint_path)
    # Some enums expose ``url`` on ``.value``, others on the member itself.
    try:
        monkeypatch.setattr(weights.value, "url", str(checkpoint_path))
    except AttributeError:
        monkeypatch.setattr(weights, "url", str(checkpoint_path))
    monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
    return weights

def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
    """The task should accept a filesystem path to a checkpoint as weights."""
    model_kwargs.update(weights=checkpoint)
    PixelwiseRegressionTask(**model_kwargs)

def test_weight_enum(
    self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """The task should accept a WeightsEnum member directly (mocked download)."""
    meta = mocked_weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=mocked_weights,
    )
    PixelwiseRegressionTask(**model_kwargs)

def test_weight_str(
    self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """The task should accept a weight enum given as its string name (mocked)."""
    meta = mocked_weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(mocked_weights),
    )
    PixelwiseRegressionTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(
    self, model_kwargs: dict[str, Any], weights: WeightsEnum
) -> None:
    """Actually download the weights given as an enum member (slow, network)."""
    meta = weights.meta
    model_kwargs.update(
        backbone=meta["model"], in_channels=meta["in_chans"], weights=weights
    )
    PixelwiseRegressionTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(
    self, model_kwargs: dict[str, Any], weights: WeightsEnum
) -> None:
    """Actually download the weights given as a string name (slow, network)."""
    meta = weights.meta
    model_kwargs.update(
        backbone=meta["model"], in_channels=meta["in_chans"], weights=str(weights)
    )
    PixelwiseRegressionTask(**model_kwargs)

@pytest.mark.parametrize(
"backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"]
)
Expand Down
16 changes: 14 additions & 2 deletions torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,19 @@ class PixelwiseRegressionTask(RegressionTask):

def config_model(self) -> None:
"""Configures the model based on kwargs parameters."""
weights = self.hyperparams["weights"]

if self.hyperparams["model"] == "unet":
self.model = smp.Unet(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=1,
)
elif self.hyperparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=self.hyperparams["backbone"],
encoder_weights=self.hyperparams["weights"],
encoder_weights="imagenet" if weights is True else None,
in_channels=self.hyperparams["in_channels"],
classes=1,
)
Expand All @@ -306,6 +308,16 @@ def config_model(self) -> None:
f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)

if self.hyperparams["model"] != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder.load_state_dict(state_dict)

# Freeze backbone
if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
"model"
Expand Down

0 comments on commit d192207

Please sign in to comment.