diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 2f5dbd8d2a4..51c5cfc6fc6 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn import torchvision +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer @@ -18,7 +19,7 @@ from torchgeo.datamodules import ChesapeakeCVPRDataModule, MisconfigurationException from torchgeo.datasets import ChesapeakeCVPR -from torchgeo.models import ResNet18_Weights +from torchgeo.models import get_model_weights, list_models from torchgeo.samplers import GridGeoSampler from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation @@ -97,13 +98,30 @@ def test_trainer( @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: - return {"backbone": "resnet18", "weights": None, "in_channels": 3} + return { + "backbone": "resnet18", + "in_channels": 13, + "loss": "ce", + "num_classes": 10, + "weights": None, + } + + @pytest.fixture( + params=[ + weights for model in list_models() for weights in get_model_weights(model) + ] + ) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param @pytest.fixture - def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: - weights = ResNet18_Weights.SENTINEL2_RGB_MOCO + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: path = tmp_path / f"{weights}.pth" - model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) torch.save(model.state_dict(), path) monkeypatch.setattr(weights, "url", str(path)) monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) @@ -111,20 +129,43 @@ def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnu def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint - BYOLTask(**model_kwargs) + with pytest.warns(UserWarning): + BYOLTask(**model_kwargs) def test_weight_enum( self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: + model_kwargs["backbone"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = mocked_weights BYOLTask(**model_kwargs) def test_weight_str( self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: + model_kwargs["backbone"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = str(mocked_weights) BYOLTask(**model_kwargs) + @pytest.mark.slow + def test_weight_enum_download( + self, model_kwargs: Dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = weights + BYOLTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_str_download( + self, model_kwargs: Dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = str(weights) + BYOLTask(**model_kwargs) + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictBYOLDataModule( root="tests/data/chesapeake/cvpr", diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 6f8596da35b..15c281d42e4 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -9,6 +9,7 @@ import timm import torch import torchvision +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer @@ -24,7 +25,7 @@ UCMercedDataModule, ) from torchgeo.datasets import BigEarthNet, EuroSAT -from torchgeo.models import ResNet18_Weights +from torchgeo.models import get_model_weights, list_models from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from .test_utils import ClassificationTestModel @@ -110,11 +111,22 @@ def model_kwargs(self) -> Dict[str, Any]: "weights": None, } + @pytest.fixture( + params=[ + weights for model in list_models() for weights in get_model_weights(model) + ] + ) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + @pytest.fixture - def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: - weights = ResNet18_Weights.SENTINEL2_ALL_MOCO + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: path = tmp_path / f"{weights}.pth" - model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) torch.save(model.state_dict(), path) monkeypatch.setattr(weights, "url", str(path)) monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) @@ -128,6 +140,8 @@ def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> Non def test_weight_enum( self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: + model_kwargs["model"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = mocked_weights with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) @@ -135,10 +149,30 @@ def test_weight_enum( def test_weight_str( self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: + model_kwargs["model"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = str(mocked_weights) with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) + @pytest.mark.slow + def test_weight_enum_download( + self, model_kwargs: Dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["model"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = weights + ClassificationTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_str_download( + self, model_kwargs: Dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["model"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = str(weights) + ClassificationTask(**model_kwargs) + def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 65fbeabfeca..3262af2accc 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -9,6 +9,7 @@ import timm import torch import torchvision +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer @@ -20,7 +21,7 @@ TropicalCycloneDataModule, ) from torchgeo.datasets import TropicalCyclone -from torchgeo.models import ResNet18_Weights +from torchgeo.models import get_model_weights, list_models from torchgeo.trainers import RegressionTask from .test_utils import RegressionTestModel @@ -86,11 +87,22 @@ def model_kwargs(self) -> Dict[str, Any]: "in_channels": 3, } + @pytest.fixture( + params=[ + weights for model in list_models() for weights in get_model_weights(model) + ] + ) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + @pytest.fixture - def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: - weights = ResNet18_Weights.SENTINEL2_RGB_MOCO + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: path = tmp_path / f"{weights}.pth" - model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) torch.save(model.state_dict(), path) monkeypatch.setattr(weights, "url", str(path)) monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) @@ -104,6 +116,8 @@ def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> Non def test_weight_enum( self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: + model_kwargs["model"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = mocked_weights with pytest.warns(UserWarning): RegressionTask(**model_kwargs) @@ -111,10 +125,30 @@ def test_weight_enum( def test_weight_str( self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: + model_kwargs["model"] = mocked_weights.meta["model"] + model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] model_kwargs["weights"] = str(mocked_weights) with pytest.warns(UserWarning): RegressionTask(**model_kwargs) + @pytest.mark.slow + def test_weight_enum_download( + self, model_kwargs: Dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["model"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = weights + RegressionTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_str_download( + self, model_kwargs: Dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["model"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = str(weights) + RegressionTask(**model_kwargs) + def test_no_rgb( self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 953a7c0fb7e..05c17d7e4ac 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -6,11 +6,13 @@ from typing import Any, cast import pytest +import timm import torch import torch.nn as nn from torch.nn.modules import Module from torchgeo.trainers.utils import ( + _get_input_layer_name_and_module, extract_backbone, load_state_dict, reinit_initial_conv_layer, @@ -115,3 +117,10 @@ def test_reinit_initial_conv_layer() -> None: assert in_channels == 4 assert k1 == 3 and k2 == 3 assert new_conv_layer.stride[0] == 2 + + +def test_get_input_layer_name_and_module() -> None: + key, module = _get_input_layer_name_and_module(timm.create_model("resnet18")) + assert key == "conv1" + assert isinstance(module, nn.Conv2d) + assert module.in_channels == 3 diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 3b9e8d0db3e..d33c332a22b 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -313,7 +313,7 @@ def config_task(self) -> None: state_dict = get_weight(weights).get_state_dict(progress=True) backbone = utils.load_state_dict(backbone, state_dict) - self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256)) + self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224)) def __init__(self, **kwargs: Any) -> None: """Initialize a LightningModule for pre-training a model with BYOL. diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index cc38b4b6f9d..cf7a8f89d0d 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -54,6 +54,23 @@ def extract_backbone(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: return name, state_dict +def _get_input_layer_name_and_module(model: Module) -> Tuple[str, Module]: + """Retrieve the input layer name and module from a timm model. + + Args: + model: timm model + """ + keys = [] + children = list(model.named_children()) + while children != []: + name, module = children[0] + keys.append(name) + children = list(module.named_children()) + + key = ".".join(keys) + return key, module + + def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module: """Load pretrained resnet weights to a model. @@ -68,27 +85,32 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo If input channels in model != pretrained model input channels If num output classes in model != pretrained model num classes """ - in_channels = cast(nn.Module, model.conv1).in_channels - expected_in_channels = state_dict["conv1.weight"].shape[1] + input_module_key, input_module = _get_input_layer_name_and_module(model) + in_channels = input_module.in_channels + expected_in_channels = state_dict[input_module_key + ".weight"].shape[1] - num_classes = cast(nn.Module, model.fc).out_features + output_module_key, output_module = list(model.named_children())[-1] + num_classes = output_module.out_features expected_num_classes = None - if "fc.weight" in state_dict: - expected_num_classes = state_dict["fc.weight"].shape[0] + if output_module_key + ".weight" in state_dict: + expected_num_classes = state_dict[output_module_key + ".weight"].shape[0] if in_channels != expected_in_channels: warnings.warn( f"input channels {in_channels} != input channels in pretrained" f" model {expected_in_channels}. Overriding with new input channels" ) - del state_dict["conv1.weight"] + del state_dict[input_module_key + ".weight"] if expected_num_classes and num_classes != expected_num_classes: warnings.warn( f"num classes {num_classes} != num classes in pretrained model" f" {expected_num_classes}. Overriding with new num classes" ) - del state_dict["fc.weight"], state_dict["fc.bias"] + del ( + state_dict[output_module_key + ".weight"], + state_dict[output_module_key + ".bias"], + ) model.load_state_dict(state_dict, strict=False) return model