From 7cd3f617ae75635e7d1a4cdbc65f01e38edfda08 Mon Sep 17 00:00:00 2001 From: Konstantin Klemmer Date: Mon, 31 Jul 2023 18:51:15 -0400 Subject: [PATCH 1/9] Update pretrained_weights.ipynb Fixed an error in the state dict loading of the turorial and added a comment on the num_classes parameter when creating timm models. --- docs/tutorials/pretrained_weights.ipynb | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index 26d97fcbc6c..a64e26da193 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -227,8 +227,26 @@ "outputs": [], "source": [ "in_chans = weights.meta[\"in_chans\"]\n", - "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n", - "model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" + "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=0)\n", + "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setting `num_classes=0` will prevent the creation of a prediction head (fully-connected layer). However, you can create a timm model with a prediction head and still match the keys of all but the last fully-connected layer, the parameters of which will be randomly initialized, when using the `strict=False` flag." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "in_chans = weights.meta[\"in_chans\"]\n", + "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=256)\n", + "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" ] }, { From 1f976cf96294a2e9c5a648eb676c47e56adb86cb Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 17 Dec 2023 21:31:22 -0800 Subject: [PATCH 2/9] Update docs/tutorials/pretrained_weights.ipynb --- docs/tutorials/pretrained_weights.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index a64e26da193..46083d4ae67 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -245,7 +245,7 @@ "outputs": [], "source": [ "in_chans = weights.meta[\"in_chans\"]\n", - "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=256)\n", + "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n", "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" ] }, From 392b30557688ce1c46ee66eaa5bd887078fc2939 Mon Sep 17 00:00:00 2001 From: Konstantin Klemmer Date: Tue, 30 Jan 2024 11:05:35 -0500 Subject: [PATCH 3/9] Update utils.py * Import Tuple from typing * Change return of `load_state_dict` from `model` to `Tuple[List[str], List[str]]`, matching the return of the standard PyTorch builtin function. --- torchgeo/trainers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index e1b6678ed73..f8527efc153 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -5,7 +5,7 @@ import warnings from collections import OrderedDict -from typing import Optional, Union, cast +from typing import Optional, Union, Tuple, cast import torch import torch.nn as nn @@ -116,7 +116,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo ) model.load_state_dict(state_dict, strict=False) - return model + return Tuple[List[str], List[str]] def reinit_initial_conv_layer( From 46afcd7a81ede7f60bab447e7b028cf649bad00e Mon Sep 17 00:00:00 2001 From: Konstantin Klemmer Date: Tue, 30 Jan 2024 11:16:01 -0500 Subject: [PATCH 4/9] Update pretrained_weights.ipynb Remove example of loading pretrained model without prediction head (`num_classes=0`). --- docs/tutorials/pretrained_weights.ipynb | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index 46083d4ae67..e15dd1ebc16 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -225,24 +225,6 @@ "id": "ZaZQ07jorMOO" }, "outputs": [], - "source": [ - "in_chans = weights.meta[\"in_chans\"]\n", - "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=0)\n", - "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Setting `num_classes=0` will prevent the creation of a prediction head (fully-connected layer). However, you can create a timm model with a prediction head and still match the keys of all but the last fully-connected layer, the parameters of which will be randomly initialized, when using the `strict=False` flag." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "in_chans = weights.meta[\"in_chans\"]\n", "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n", From 2324488dc98a0e21e1b7cd55ec801a63e7ea7764 Mon Sep 17 00:00:00 2001 From: Konstantin Klemmer Date: Tue, 30 Jan 2024 11:18:01 -0500 Subject: [PATCH 5/9] Update README.md Adapt new `load_state_dict` function. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 36ae3aa37eb..ec31fcb7753 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights weights = ResNet18_Weights.SENTINEL2_ALL_MOCO model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10) -model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False) +model.load_state_dict(weights.get_state_dict(progress=True), strict=False) ``` These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html). From 53885dcbefa5092cf8ae38c63367d7e65eca5672 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 31 Jan 2024 15:17:28 +0100 Subject: [PATCH 6/9] Mimic return type of builtin load_state_dict --- tests/trainers/test_utils.py | 6 +++--- torchgeo/trainers/byol.py | 2 +- torchgeo/trainers/classification.py | 2 +- torchgeo/trainers/moco.py | 2 +- torchgeo/trainers/regression.py | 2 +- torchgeo/trainers/simclr.py | 2 +- torchgeo/trainers/utils.py | 7 +++---- 7 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 52d7a9be25d..06da0a359eb 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None: def test_load_state_dict(checkpoint: str, model: Module) -> None: _, state_dict = extract_backbone(checkpoint) - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None: @@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) f" model {expected_in_channels}. Overriding with new input channels" ) with pytest.warns(UserWarning, match=warning): - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None: @@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None f" {expected_num_classes}. Overriding with new num classes" ) with pytest.warns(UserWarning, match=warning): - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_reinit_initial_conv_layer() -> None: diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 68bdb6c9c43..d6c0b62765e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -343,7 +343,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - backbone = utils.load_state_dict(backbone, state_dict) + utils.load_state_dict(backbone, state_dict) self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224)) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 9ac312051c7..5d8d10c9dce 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -137,7 +137,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) + utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams["freeze_backbone"]: diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index d2621a8da74..4dbc1e453c9 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -261,7 +261,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.backbone = utils.load_state_dict(self.backbone, state_dict) + utils.load_state_dict(self.backbone, state_dict) # Create projection (and prediction) head batch_norm = version == 3 diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index c58f033b5bc..9cc2ea56441 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -128,7 +128,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) + utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams["freeze_backbone"]: diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index a889be1c96f..27719eda224 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -172,7 +172,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.backbone = utils.load_state_dict(self.backbone, state_dict) + utils.load_state_dict(self.backbone, state_dict) # Create projection head input_dim = self.backbone.num_features diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index f8527efc153..c6a196f596c 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -71,7 +71,7 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]: return key, module -def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module: +def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Tuple[List[str], List[str]]: """Load pretrained resnet weights to a model. Args: @@ -79,7 +79,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo state_dict: dict containing tensor parameters Returns: - the model with pretrained weights + The missing and unexpected keys Warns: If input channels in model != pretrained model input channels @@ -115,8 +115,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo state_dict[output_module_key + ".bias"], ) - model.load_state_dict(state_dict, strict=False) - return Tuple[List[str], List[str]] + return model.load_state_dict(state_dict, strict=False) def reinit_initial_conv_layer( From fd19ab18c3a1a31d643cd4b15862d2d10530c521 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 31 Jan 2024 15:18:52 +0100 Subject: [PATCH 7/9] Modern type hints --- torchgeo/trainers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index c6a196f596c..0b6d3cd1fb9 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -5,7 +5,7 @@ import warnings from collections import OrderedDict -from typing import Optional, Union, Tuple, cast +from typing import Optional, Union, cast import torch import torch.nn as nn @@ -71,7 +71,7 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]: return key, module -def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Tuple[List[str], List[str]]: +def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> tuple[list[str], list[str]]: """Load pretrained resnet weights to a model. Args: From 35b914c91589bd8a414a2fde8bab413ee61a58a7 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 31 Jan 2024 15:20:24 +0100 Subject: [PATCH 8/9] Blacken --- torchgeo/trainers/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 0b6d3cd1fb9..7a54f936cb4 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -71,7 +71,9 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]: return key, module -def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> tuple[list[str], list[str]]: +def load_state_dict( + model: Module, state_dict: "OrderedDict[str, Tensor]" +) -> tuple[list[str], list[str]]: """Load pretrained resnet weights to a model. Args: From b5a9007c70d5c4e91a9da49ac11a862e48898fca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 31 Jan 2024 15:31:34 +0100 Subject: [PATCH 9/9] Try being explicit --- torchgeo/trainers/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 7a54f936cb4..b5cd8f1e923 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -117,7 +117,10 @@ def load_state_dict( state_dict[output_module_key + ".bias"], ) - return model.load_state_dict(state_dict, strict=False) + missing_keys: list[str] + unexpected_keys: list[str] + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + return missing_keys, unexpected_keys def reinit_initial_conv_layer(