Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load_state_dict does not return the model #1503

Merged
merged 11 commits into from
Feb 6, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
```

These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/pretrained_weights.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n",
"model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions tests/trainers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None:

def test_load_state_dict(checkpoint: str, model: Module) -> None:
_, state_dict = extract_backbone(checkpoint)
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)


def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None:
Expand All @@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module)
f" model {expected_in_channels}. Overriding with new input channels"
)
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)


def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None:
Expand All @@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None
f" {expected_num_classes}. Overriding with new num classes"
)
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)


def test_reinit_initial_conv_layer() -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
backbone = utils.load_state_dict(backbone, state_dict)
utils.load_state_dict(backbone, state_dict)

self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224))

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
utils.load_state_dict(self.model, state_dict)

# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.backbone = utils.load_state_dict(self.backbone, state_dict)
utils.load_state_dict(self.backbone, state_dict)

# Create projection (and prediction) head
batch_norm = version == 3
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
utils.load_state_dict(self.model, state_dict)

# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.backbone = utils.load_state_dict(self.backbone, state_dict)
utils.load_state_dict(self.backbone, state_dict)

# Create projection head
input_dim = self.backbone.num_features
Expand Down
12 changes: 8 additions & 4 deletions torchgeo/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]:
return key, module


def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module:
def load_state_dict(
model: Module, state_dict: "OrderedDict[str, Tensor]"
) -> tuple[list[str], list[str]]:
"""Load pretrained resnet weights to a model.

Args:
model: model to load the pretrained weights to
state_dict: dict containing tensor parameters

Returns:
the model with pretrained weights
The missing and unexpected keys

Warns:
If input channels in model != pretrained model input channels
Expand Down Expand Up @@ -115,8 +117,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
state_dict[output_module_key + ".bias"],
)

model.load_state_dict(state_dict, strict=False)
return model
missing_keys: list[str]
unexpected_keys: list[str]
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
return missing_keys, unexpected_keys


def reinit_initial_conv_layer(
Expand Down
Loading