From 3247af7e28a1ec6cac11c04c5d9260c5b3c1a19e Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 3 May 2023 20:47:53 +0000 Subject: [PATCH] add tests --- tests/conf/inria.yaml | 2 +- tests/trainers/test_segmentation.py | 18 ++++++++++++++++++ torchgeo/trainers/segmentation.py | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml index 7a47124bee5..df4f4043fc4 100644 --- a/tests/conf/inria.yaml +++ b/tests/conf/inria.yaml @@ -3,7 +3,7 @@ module: loss: "ce" model: "unet" backbone: "resnet18" - weights: true + weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index c04aa99e83a..7fd7badcd34 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -169,6 +169,24 @@ def test_weight_str( model_kwargs["weights"] = str(mocked_weights) SemanticSegmentationTask(**model_kwargs) + @pytest.mark.slow + def test_weight_enum_download( + self, model_kwargs: dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = weights + SemanticSegmentationTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_str_download( + self, model_kwargs: dict[str, Any], weights: WeightsEnum + ) -> None: + model_kwargs["backbone"] = weights.meta["model"] + model_kwargs["in_channels"] = weights.meta["in_chans"] + model_kwargs["weights"] = str(weights) + SemanticSegmentationTask(**model_kwargs) + def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not valid." diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 575a9b2af6d..883110fb8d1 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -145,7 +145,7 @@ class and used with 'ce' loss and *freeze_decoder* parameters. .. versionchanged:: 0.5 - The *weights* parameter supports WeightEnums and checkpoint paths. + The *weights* parameter now supports WeightEnums and checkpoint paths. """ super().__init__()