From 3c8ee4ce30da191f78eb4f46ee7ea8ab277815b1 Mon Sep 17 00:00:00 2001 From: BAHL Gaetan Date: Tue, 3 May 2022 11:13:46 +0200 Subject: [PATCH 1/2] Fix loading encoder weights trained with BYOL --- torchgeo/trainers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 8e04815c937..7cde2b01d6f 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -42,8 +42,8 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: state_dict = OrderedDict( {k.replace("model.", ""): v for k, v in state_dict.items()} ) - elif "encoder" in checkpoint["hyper_parameters"]: - name = checkpoint["hyper_parameters"]["encoder"] + elif "encoder_name" in checkpoint["hyper_parameters"]: + name = checkpoint["hyper_parameters"]["encoder_name"] state_dict = checkpoint["state_dict"] state_dict = OrderedDict( {k: v for k, v in state_dict.items() if "model.encoder.model" in k} From 310b351df8dc91b5f5fca6ae02e3e912b85c2826 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 13 Jun 2022 09:34:30 -0700 Subject: [PATCH 2/2] Update conftest.py --- tests/trainers/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index 287c03b868a..3da741c3013 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -25,7 +25,7 @@ def state_dict(model: Module) -> Dict[str, Tensor]: return model.state_dict() -@pytest.fixture(params=["classification_model", "encoder"]) +@pytest.fixture(params=["classification_model", "encoder_name"]) def checkpoint( state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path ) -> str: