Skip to content

Commit

Permalink
Add LoveDADataModule to the trainer tests (microsoft#966)
Browse files Browse the repository at this point in the history
* Add loveda to trainer tests

* Delete direct loveda datamodule test

* Ignoring deprecation warning for lightning

* Remove ignore

* test -> predict

* Fix typo

* Add comment explaining mismatch

* More coverage

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
calebrob6 and adamjstewart authored Dec 30, 2022
1 parent e7569a8 commit 6b17ffe
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 55 deletions.
19 changes: 19 additions & 0 deletions tests/conf/loveda.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
experiment:
task: "loveda"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 8
num_filters: 1
ignore_index: null
datamodule:
root: "tests/data/loveda"
download: true
batch_size: 1
num_workers: 0
46 changes: 0 additions & 46 deletions tests/datamodules/test_loveda.py

This file was deleted.

6 changes: 5 additions & 1 deletion tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ETCI2021DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
NAIPChesapeakeDataModule,
SEN12MSDataModule,
SpaceNet1DataModule,
Expand All @@ -42,6 +43,7 @@ class TestSemanticSegmentationTask:
("inria_val", InriaAerialImageLabelingDataModule),
("inria_test", InriaAerialImageLabelingDataModule),
("landcoverai", LandCoverAIDataModule),
("loveda", LoveDADataModule),
("naipchesapeake", NAIPChesapeakeDataModule),
("sen12ms_all", SEN12MSDataModule),
("sen12ms_s1", SEN12MSDataModule),
Expand Down Expand Up @@ -77,7 +79,9 @@ def test_trainer(
# Instantiate trainer
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)

if hasattr(datamodule, "test_dataset") or hasattr(datamodule, "test_sampler"):
trainer.test(model=model, datamodule=datamodule)

if hasattr(datamodule, "predict_dataset"):
trainer.predict(model=model, datamodule=datamodule)
Expand Down
17 changes: 9 additions & 8 deletions torchgeo/datamodules/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,19 @@ def setup(self, stage: Optional[str] = None) -> None:
stage: stage to set up
"""
train_transforms = self.preprocess
val_test_transforms = self.preprocess
val_predict_transforms = self.preprocess

self.train_dataset = LoveDA(
split="train", transforms=train_transforms, **self.kwargs
)

self.val_dataset = LoveDA(
split="val", transforms=val_test_transforms, **self.kwargs
split="val", transforms=val_predict_transforms, **self.kwargs
)

self.test_dataset = LoveDA(
split="test", transforms=val_test_transforms, **self.kwargs
# Test set masks are not public, use for prediction instead
self.predict_dataset = LoveDA(
split="test", transforms=val_predict_transforms, **self.kwargs
)

def train_dataloader(self) -> DataLoader[Any]:
Expand Down Expand Up @@ -107,14 +108,14 @@ def val_dataloader(self) -> DataLoader[Any]:
shuffle=False,
)

def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
def predict_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for prediction.
Returns:
testing data loader
predict data loader
"""
return DataLoader(
self.test_dataset,
self.predict_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EuroSATDataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
NAIPChesapeakeDataModule,
NASAMarineDebrisDataModule,
RESISC45DataModule,
Expand Down Expand Up @@ -51,6 +52,7 @@
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
"inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
"landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
"loveda": (SemanticSegmentationTask, LoveDADataModule),
"naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
"nasa_marine_debris": (ObjectDetectionTask, NASAMarineDebrisDataModule),
"resisc45": (ClassificationTask, RESISC45DataModule),
Expand Down

0 comments on commit 6b17ffe

Please sign in to comment.