Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed May 3, 2023
1 parent dd4f56b commit 3247af7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: true
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
Expand Down
18 changes: 18 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,24 @@ def test_weight_str(
model_kwargs["weights"] = str(mocked_weights)
SemanticSegmentationTask(**model_kwargs)

@pytest.mark.slow
def test_weight_enum_download(
self, model_kwargs: dict[str, Any], weights: WeightsEnum
) -> None:
model_kwargs["backbone"] = weights.meta["model"]
model_kwargs["in_channels"] = weights.meta["in_chans"]
model_kwargs["weights"] = weights
SemanticSegmentationTask(**model_kwargs)

@pytest.mark.slow
def test_weight_str_download(
self, model_kwargs: dict[str, Any], weights: WeightsEnum
) -> None:
model_kwargs["backbone"] = weights.meta["model"]
model_kwargs["in_channels"] = weights.meta["in_chans"]
model_kwargs["weights"] = str(weights)
SemanticSegmentationTask(**model_kwargs)

def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not valid."
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class and used with 'ce' loss
and *freeze_decoder* parameters.
.. versionchanged:: 0.5
The *weights* parameter supports WeightEnums and checkpoint paths.
The *weights* parameter now supports WeightEnums and checkpoint paths.
"""
super().__init__()
Expand Down

0 comments on commit 3247af7

Please sign in to comment.