From 6b17ffef08898f8bab355a41a637da5f5699a4b8 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Thu, 29 Dec 2022 19:46:06 -0800 Subject: [PATCH] Add LoveDADataModule to the trainer tests (#966) * Add loveda to trainer tests * Delete direct loveda datamodule test * Ignoring deprecation warning for lightning * Remove ignore * test -> predict * Fix typo * Add comment explaining mismatch * More coverage Co-authored-by: Adam J. Stewart --- tests/conf/loveda.yaml | 19 ++++++++++++ tests/datamodules/test_loveda.py | 46 ----------------------------- tests/trainers/test_segmentation.py | 6 +++- torchgeo/datamodules/loveda.py | 17 ++++++----- train.py | 2 ++ 5 files changed, 35 insertions(+), 55 deletions(-) create mode 100644 tests/conf/loveda.yaml delete mode 100644 tests/datamodules/test_loveda.py diff --git a/tests/conf/loveda.yaml b/tests/conf/loveda.yaml new file mode 100644 index 00000000000..df062a0e600 --- /dev/null +++ b/tests/conf/loveda.yaml @@ -0,0 +1,19 @@ +experiment: + task: "loveda" + module: + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 8 + num_filters: 1 + ignore_index: null + datamodule: + root: "tests/data/loveda" + download: true + batch_size: 1 + num_workers: 0 diff --git a/tests/datamodules/test_loveda.py b/tests/datamodules/test_loveda.py deleted file mode 100644 index 92457ebff64..00000000000 --- a/tests/datamodules/test_loveda.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os - -import matplotlib.pyplot as plt -import pytest - -from torchgeo.datamodules import LoveDADataModule -from torchgeo.datasets import unbind_samples - - -class TestLoveDADataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> LoveDADataModule: - root = os.path.join("tests", "data", "loveda") - batch_size = 2 - num_workers = 0 - scene = ["rural", "urban"] - - dm = LoveDADataModule( - root=root, - scene=scene, - batch_size=batch_size, - num_workers=num_workers, - download=True, - ) - - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.test_dataloader())) - - def test_plot(self, datamodule: LoveDADataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) - sample = unbind_samples(batch)[0] - datamodule.plot(sample) - plt.close() diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 73245f6c717..6c9b1f2ff69 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -17,6 +17,7 @@ ETCI2021DataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, + LoveDADataModule, NAIPChesapeakeDataModule, SEN12MSDataModule, SpaceNet1DataModule, @@ -42,6 +43,7 @@ class TestSemanticSegmentationTask: ("inria_val", InriaAerialImageLabelingDataModule), ("inria_test", InriaAerialImageLabelingDataModule), ("landcoverai", LandCoverAIDataModule), + ("loveda", LoveDADataModule), ("naipchesapeake", NAIPChesapeakeDataModule), ("sen12ms_all", SEN12MSDataModule), ("sen12ms_s1", SEN12MSDataModule), @@ -77,7 +79,9 @@ def test_trainer( # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) + + if hasattr(datamodule, "test_dataset") or hasattr(datamodule, "test_sampler"): + trainer.test(model=model, datamodule=datamodule) if hasattr(datamodule, "predict_dataset"): trainer.predict(model=model, datamodule=datamodule) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index d61757f31ea..a82233f9844 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -67,18 +67,19 @@ def setup(self, stage: Optional[str] = None) -> None: stage: stage to set up """ train_transforms = self.preprocess - val_test_transforms = self.preprocess + val_predict_transforms = self.preprocess self.train_dataset = LoveDA( split="train", transforms=train_transforms, **self.kwargs ) self.val_dataset = LoveDA( - split="val", transforms=val_test_transforms, **self.kwargs + split="val", transforms=val_predict_transforms, **self.kwargs ) - self.test_dataset = LoveDA( - split="test", transforms=val_test_transforms, **self.kwargs + # Test set masks are not public, use for prediction instead + self.predict_dataset = LoveDA( + split="test", transforms=val_predict_transforms, **self.kwargs ) def train_dataloader(self) -> DataLoader[Any]: @@ -107,14 +108,14 @@ def val_dataloader(self) -> DataLoader[Any]: shuffle=False, ) - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. + def predict_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for prediction. Returns: - testing data loader + predict data loader """ return DataLoader( - self.test_dataset, + self.predict_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, diff --git a/train.py b/train.py index 0c9fd72a3e0..cb001437686 100755 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ EuroSATDataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, + LoveDADataModule, NAIPChesapeakeDataModule, NASAMarineDebrisDataModule, RESISC45DataModule, @@ -51,6 +52,7 @@ "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule), "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule), + "loveda": (SemanticSegmentationTask, LoveDADataModule), "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule), "nasa_marine_debris": (ObjectDetectionTask, NASAMarineDebrisDataModule), "resisc45": (ClassificationTask, RESISC45DataModule),