Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix load_state_dict for all timm models #1084

Merged
merged 10 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
Expand All @@ -18,7 +19,7 @@

from torchgeo.datamodules import ChesapeakeCVPRDataModule, MisconfigurationException
from torchgeo.datasets import ChesapeakeCVPR
from torchgeo.models import ResNet18_Weights
from torchgeo.models import get_model_weights, list_models
from torchgeo.samplers import GridGeoSampler
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
Expand Down Expand Up @@ -97,34 +98,74 @@ def test_trainer(

@pytest.fixture
def model_kwargs(self) -> Dict[str, Any]:
return {"backbone": "resnet18", "weights": None, "in_channels": 3}
return {
"backbone": "resnet18",
"in_channels": 13,
"loss": "ce",
"num_classes": 10,
"weights": None,
}

@pytest.fixture(
    params=[
        member
        for name in list_models()
        for member in get_model_weights(name)
    ]
)
def weights(self, request: SubRequest) -> WeightsEnum:
    """Return the weights entry selected by the current fixture parameter.

    The fixture is parametrized over every entry that ``get_model_weights``
    yields for every model name reported by ``list_models``.
    """
    return request.param

@pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights

def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
BYOLTask(**model_kwargs)
with pytest.warns(UserWarning):
BYOLTask(**model_kwargs)

def test_weight_enum(
    self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Construct a BYOLTask with the weights passed as an enum member."""
    # Backbone name and input channel count come from the weights' metadata.
    meta = mocked_weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=mocked_weights,
    )
    BYOLTask(**model_kwargs)

def test_weight_str(
    self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Construct a BYOLTask with the weights passed as their string form."""
    # Backbone name and input channel count come from the weights' metadata.
    meta = mocked_weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(mocked_weights),
    )
    BYOLTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(
    self, model_kwargs: Dict[str, Any], weights: WeightsEnum
) -> None:
    """Construct a BYOLTask from the un-mocked ``weights`` fixture (enum form)."""
    # Backbone name and input channel count come from the weights' metadata.
    meta = weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=weights,
    )
    BYOLTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(
    self, model_kwargs: Dict[str, Any], weights: WeightsEnum
) -> None:
    """Construct a BYOLTask from the un-mocked ``weights`` fixture (string form)."""
    # Backbone name and input channel count come from the weights' metadata.
    meta = weights.meta
    model_kwargs.update(
        backbone=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(weights),
    )
    BYOLTask(**model_kwargs)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictBYOLDataModule(
root="tests/data/chesapeake/cvpr",
Expand Down
42 changes: 38 additions & 4 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
Expand All @@ -24,7 +25,7 @@
UCMercedDataModule,
)
from torchgeo.datasets import BigEarthNet, EuroSAT
from torchgeo.models import ResNet18_Weights
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask

from .test_utils import ClassificationTestModel
Expand Down Expand Up @@ -110,11 +111,22 @@ def model_kwargs(self) -> Dict[str, Any]:
"weights": None,
}

@pytest.fixture(
    params=[
        member
        for name in list_models()
        for member in get_model_weights(name)
    ]
)
def weights(self, request: SubRequest) -> WeightsEnum:
    """Return the weights entry selected by the current fixture parameter.

    The fixture is parametrized over every entry that ``get_model_weights``
    yields for every model name reported by ``list_models``.
    """
    return request.param

@pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
def mocked_weights(
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
Expand All @@ -128,17 +140,39 @@ def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> Non
def test_weight_enum(
    self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Construct a ClassificationTask from an enum member, expecting a UserWarning."""
    # Model name and input channel count come from the weights' metadata.
    meta = mocked_weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=mocked_weights,
    )
    with pytest.warns(UserWarning):
        ClassificationTask(**model_kwargs)

def test_weight_str(
    self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Construct a ClassificationTask from string-form weights, expecting a UserWarning."""
    # Model name and input channel count come from the weights' metadata.
    meta = mocked_weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(mocked_weights),
    )
    with pytest.warns(UserWarning):
        ClassificationTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(
    self, model_kwargs: Dict[str, Any], weights: WeightsEnum
) -> None:
    """Construct a ClassificationTask from the un-mocked ``weights`` fixture (enum form)."""
    # Model name and input channel count come from the weights' metadata.
    meta = weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=weights,
    )
    ClassificationTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(
    self, model_kwargs: Dict[str, Any], weights: WeightsEnum
) -> None:
    """Construct a ClassificationTask from the un-mocked ``weights`` fixture (string form)."""
    # Model name and input channel count come from the weights' metadata.
    meta = weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(weights),
    )
    ClassificationTask(**model_kwargs)

def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
Expand Down
42 changes: 38 additions & 4 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
Expand All @@ -20,7 +21,7 @@
TropicalCycloneDataModule,
)
from torchgeo.datasets import TropicalCyclone
from torchgeo.models import ResNet18_Weights
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import RegressionTask

from .test_utils import RegressionTestModel
Expand Down Expand Up @@ -86,11 +87,22 @@ def model_kwargs(self) -> Dict[str, Any]:
"in_channels": 3,
}

@pytest.fixture(
    params=[
        member
        for name in list_models()
        for member in get_model_weights(name)
    ]
)
def weights(self, request: SubRequest) -> WeightsEnum:
    """Return the weights entry selected by the current fixture parameter.

    The fixture is parametrized over every entry that ``get_model_weights``
    yields for every model name reported by ``list_models``.
    """
    return request.param

@pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
Expand All @@ -104,17 +116,39 @@ def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> Non
def test_weight_enum(
    self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Construct a RegressionTask from an enum member, expecting a UserWarning."""
    # Model name and input channel count come from the weights' metadata.
    meta = mocked_weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=mocked_weights,
    )
    with pytest.warns(UserWarning):
        RegressionTask(**model_kwargs)

def test_weight_str(
    self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
    """Construct a RegressionTask from string-form weights, expecting a UserWarning."""
    # Model name and input channel count come from the weights' metadata.
    meta = mocked_weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(mocked_weights),
    )
    with pytest.warns(UserWarning):
        RegressionTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(
    self, model_kwargs: Dict[str, Any], weights: WeightsEnum
) -> None:
    """Construct a RegressionTask from the un-mocked ``weights`` fixture (enum form)."""
    # Model name and input channel count come from the weights' metadata.
    meta = weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=weights,
    )
    RegressionTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(
    self, model_kwargs: Dict[str, Any], weights: WeightsEnum
) -> None:
    """Construct a RegressionTask from the un-mocked ``weights`` fixture (string form)."""
    # Model name and input channel count come from the weights' metadata.
    meta = weights.meta
    model_kwargs.update(
        model=meta["model"],
        in_channels=meta["in_chans"],
        weights=str(weights),
    )
    RegressionTask(**model_kwargs)

def test_no_rgb(
self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool
) -> None:
Expand Down
9 changes: 9 additions & 0 deletions tests/trainers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from typing import Any, cast

import pytest
import timm
import torch
import torch.nn as nn
from torch.nn.modules import Module

from torchgeo.trainers.utils import (
_get_input_layer_name_and_module,
extract_backbone,
load_state_dict,
reinit_initial_conv_layer,
Expand Down Expand Up @@ -115,3 +117,10 @@ def test_reinit_initial_conv_layer() -> None:
assert in_channels == 4
assert k1 == 3 and k2 == 3
assert new_conv_layer.stride[0] == 2


def test_get_input_layer_name_and_module() -> None:
    """Check the input-layer lookup on a timm ``resnet18``.

    The helper is expected to report the layer ``conv1``, a 2-D convolution
    with three input channels.
    """
    resnet = timm.create_model("resnet18")
    layer_name, layer = _get_input_layer_name_and_module(resnet)
    assert layer_name == "conv1"
    assert isinstance(layer, nn.Conv2d)
    assert layer.in_channels == 3
2 changes: 1 addition & 1 deletion torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def config_task(self) -> None:
state_dict = get_weight(weights).get_state_dict(progress=True)
backbone = utils.load_state_dict(backbone, state_dict)

self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256))
self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224))

def __init__(self, **kwargs: Any) -> None:
"""Initialize a LightningModule for pre-training a model with BYOL.
Expand Down
36 changes: 29 additions & 7 deletions torchgeo/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ def extract_backbone(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]:
return name, state_dict


def _get_input_layer_name_and_module(model: Module) -> Tuple[str, Module]:
    """Retrieve the input layer name and module from a timm model.

    Starting at ``model``, repeatedly descend into the first child reported by
    ``named_children()`` until a module without children is reached.

    Args:
        model: timm model

    Returns:
        a tuple of the dot-joined child names leading to the input layer and
        the input layer itself; if ``model`` has no children at all, the
        result is ``("", model)``
    """
    keys = []
    # Start from the model itself so that a child-less model yields a defined
    # result instead of leaving ``module`` unbound.
    module = model
    while True:
        # Only the first child matters; avoid materialising the full list.
        first_child = next(iter(module.named_children()), None)
        if first_child is None:
            break
        name, module = first_child
        keys.append(name)

    return ".".join(keys), module


def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module:
"""Load pretrained resnet weights to a model.

Expand All @@ -68,27 +85,32 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
If input channels in model != pretrained model input channels
If num output classes in model != pretrained model num classes
"""
in_channels = cast(nn.Module, model.conv1).in_channels
expected_in_channels = state_dict["conv1.weight"].shape[1]
input_module_key, input_module = _get_input_layer_name_and_module(model)
in_channels = input_module.in_channels
expected_in_channels = state_dict[input_module_key + ".weight"].shape[1]

num_classes = cast(nn.Module, model.fc).out_features
output_module_key, output_module = list(model.named_children())[-1]
num_classes = output_module.out_features
expected_num_classes = None
if "fc.weight" in state_dict:
expected_num_classes = state_dict["fc.weight"].shape[0]
if output_module_key + ".weight" in state_dict:
expected_num_classes = state_dict[output_module_key + ".weight"].shape[0]

if in_channels != expected_in_channels:
warnings.warn(
f"input channels {in_channels} != input channels in pretrained"
f" model {expected_in_channels}. Overriding with new input channels"
)
del state_dict["conv1.weight"]
del state_dict[input_module_key + ".weight"]

if expected_num_classes and num_classes != expected_num_classes:
warnings.warn(
f"num classes {num_classes} != num classes in pretrained model"
f" {expected_num_classes}. Overriding with new num classes"
)
del state_dict["fc.weight"], state_dict["fc.bias"]
del (
state_dict[output_module_key + ".weight"],
state_dict[output_module_key + ".bias"],
)

model.load_state_dict(state_dict, strict=False)
return model
Expand Down