From 7efed586b08b8ec0e8eee0bc022306e8bd129188 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 11:51:32 -0500 Subject: [PATCH 01/17] Add plot method to all DataModules --- torchgeo/datamodules/cyclone.py | 7 +++++++ torchgeo/datamodules/deepglobelandcover.py | 7 +++++++ torchgeo/datamodules/fair1m.py | 11 +++++++++-- torchgeo/datamodules/loveda.py | 7 +++++++ torchgeo/datamodules/naip.py | 20 ++++++++++++++++---- torchgeo/datamodules/oscd.py | 7 +++++++ torchgeo/datamodules/potsdam.py | 7 +++++++ torchgeo/datamodules/sen12ms.py | 7 +++++++ torchgeo/datamodules/so2sat.py | 7 +++++++ torchgeo/datamodules/usavars.py | 7 +++++++ torchgeo/datamodules/vaihingen.py | 7 +++++++ torchgeo/datamodules/xview.py | 7 +++++++ 12 files changed, 95 insertions(+), 6 deletions(-) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index d5cbd02450c..f04bac3dd3a 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -148,3 +148,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.TropicalCycloneWindEstimation.plot`. + + .. versionadded:: 0.4 + """ + return self.all_train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 1a5de89aee4..acce2b6ebdb 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -122,3 +122,10 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index a3927537bb7..57d62c2b211 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -84,9 +84,9 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - dataset = FAIR1M(transforms=self.preprocess, **self.kwargs) + self.dataset = FAIR1M(transforms=self.preprocess, **self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) def train_dataloader(self) -> DataLoader[Any]: @@ -130,3 +130,10 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, collate_fn=collate_fn, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.FAIR1M.plot`. + + .. versionadded:: 0.4 + """ + return self.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index d4ec7bd16fc..578c2bf5d50 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -121,3 +121,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.LoveDA.plot`. + + .. versionadded:: 0.4 + """ + return self.train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index c15debc8577..ac0565330ef 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -3,7 +3,7 @@ """National Agriculture Imagery Program (NAIP) datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import pytorch_lightning as pl from torch.utils.data import DataLoader @@ -117,17 +117,17 @@ def setup(self, stage: Optional[str] = None) -> None: naip_transforms = Compose([self.preprocess, self.remove_bbox]) chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox]) - chesapeake = Chesapeake13( + self.chesapeake = Chesapeake13( self.chesapeake_root, transforms=chesapeak_transforms, **self.kwargs ) - naip = NAIP( + self.naip = NAIP( self.naip_root, chesapeake.crs, chesapeake.res, transforms=naip_transforms, **self.kwargs, ) - self.dataset = chesapeake & naip + self.dataset = self.chesapeake & self.naip # TODO: figure out better train/val/test split roi = self.dataset.bounds @@ -183,3 +183,15 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, collate_fn=stack_samples, ) + + def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: + """Run NAIP and Chesapeake plot methods. + + See :meth:`torchgeo.datasets.NAIP.plot` and + :meth:`torchgeo.datasets.Chesapeake.plot`. + + .. versionadded:: 0.4 + """ + image = self.naip.plot(*args, **kwargs) + label = self.chesapeake.plot(*args, **kwargs) + return image, label diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index e604eed767c..9499aa28547 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -196,3 +196,10 @@ def test_dataloader(self) -> DataLoader[Any]: return DataLoader( self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.OSCD.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 215f6f03f19..56351ea67be 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -121,3 +121,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.Potsdam2D.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 61529cc79b1..63452909d3a 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -191,3 +191,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.SEN12MS.plot`. + + .. versionadded:: 0.4 + """ + return self.all_train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 5cee144d777..447ecf5a4e1 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -196,3 +196,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.So2Sat.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index d7c6e119d3c..2787526607f 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -96,3 +96,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.USAVars.plot`. + + .. versionadded:: 0.4 + """ + return self.train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 5f534c027e3..8e8082a953b 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -121,3 +121,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.Vaihingen2D.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index c23fe80826b..40b8f0809fe 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -119,3 +119,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.XView2.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) From 8da46649c9457196bff637de154916c47b7c6bef Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 11:58:18 -0500 Subject: [PATCH 02/17] Add missing imports --- torchgeo/datamodules/cyclone.py | 1 + torchgeo/datamodules/deepglobelandcover.py | 1 + torchgeo/datamodules/fair1m.py | 1 + torchgeo/datamodules/loveda.py | 1 + torchgeo/datamodules/naip.py | 11 ++++++----- torchgeo/datamodules/oscd.py | 1 + torchgeo/datamodules/potsdam.py | 1 + torchgeo/datamodules/sen12ms.py | 1 + torchgeo/datamodules/so2sat.py | 1 + torchgeo/datamodules/usavars.py | 1 + torchgeo/datamodules/vaihingen.py | 1 + torchgeo/datamodules/xview.py | 1 + 12 files changed, 17 insertions(+), 5 deletions(-) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index f04bac3dd3a..909853a90d0 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index acce2b6ebdb..955467048e3 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index 57d62c2b211..88cc9063326 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch import Tensor diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 578c2bf5d50..f4f1dc39c28 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index ac0565330ef..5d155492102 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader from torchvision.transforms import Compose @@ -122,8 +123,8 @@ def setup(self, stage: Optional[str] = None) -> None: ) self.naip = NAIP( self.naip_root, - chesapeake.crs, - chesapeake.res, + self.chesapeake.crs, + self.chesapeake.res, transforms=naip_transforms, **self.kwargs, ) @@ -138,10 +139,10 @@ def setup(self, stage: Optional[str] = None) -> None: test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) self.train_sampler = RandomBatchGeoSampler( - naip, self.patch_size, self.batch_size, self.length, train_roi + self.naip, self.patch_size, self.batch_size, self.length, train_roi ) - self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi) - self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi) + self.val_sampler = GridGeoSampler(self.naip, self.patch_size, self.stride, val_roi) + self.test_sampler = GridGeoSampler(self.naip, self.patch_size, self.stride, test_roi) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 9499aa28547..fbe4ff09d23 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple import kornia.augmentation as K +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from einops import repeat diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 56351ea67be..7721fd15a70 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 63452909d3a..681c908779d 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 447ecf5a4e1..eb1a91a4b7c 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, cast +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 2787526607f..2beeb055afd 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 8e8082a953b..6017b1cce77 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 40b8f0809fe..8f3fce7a8c4 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose From 7b6962e9cdebd4b756f16c4ccc6002cca87d73e7 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 12:08:01 -0500 Subject: [PATCH 03/17] Test missing attributes --- tests/trainers/test_classification.py | 22 ++++++++++++++++++++++ tests/trainers/test_regression.py | 12 ++++++++++++ tests/trainers/test_segmentation.py | 11 +++++++++++ torchgeo/datamodules/naip.py | 8 ++++++-- 4 files changed, 51 insertions(+), 2 deletions(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index d24b723aaaf..c26fac4b9e2 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -124,6 +124,17 @@ def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) + def test_missing_attributes( + self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + ) -> None: + monkeypatch.delattr(EuroSATDataModule, "plot") + datamodule = EuroSATDataModule( + root="tests/data/eurosat", batch_size=1, num_workers=0 + ) + model = ClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) + class TestMultiLabelClassificationTask: @pytest.mark.parametrize( @@ -190,3 +201,14 @@ def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) + + def test_missing_attributes( + self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + ) -> None: + monkeypatch.delattr(BigEarthNetDataModule, "plot") + datamodule = BigEarthNetDataModule( + root="tests/data/bigearthnet", batch_size=1, num_workers=0 + ) + model = MultiLabelClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index a867448b9c5..985cb7f43f0 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -7,6 +7,7 @@ import pytest from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer +from _pytest.monkeypatch import MonkeyPatch from torchgeo.datamodules import COWCCountingDataModule, CycloneDataModule from torchgeo.trainers import RegressionTask @@ -63,3 +64,14 @@ def test_invalid_model(self) -> None: match = "module 'torchvision.models' has no attribute 'invalid_model'" with pytest.raises(AttributeError, match=match): RegressionTask(model="invalid_model", pretrained=False) + + def test_missing_attributes( + self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + ) -> None: + monkeypatch.delattr(COWCCountingDataModule, "plot") + datamodule = COWCCountingDataModule( + root="tests/data/cowc_counting", batch_size=1, num_workers=0 + ) + model = RegressionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 7742a6cb5f6..03b9a22e3dc 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -134,3 +134,14 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: match = "ignore_index has no effect on training when loss='jaccard'" with pytest.warns(UserWarning, match=match): SemanticSegmentationTask(**model_kwargs) + + def test_missing_attributes( + self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + ) -> None: + monkeypatch.delattr(DeepGlobeLandCoverDataModule, "plot") + datamodule = DeepGlobeLandCoverDataModule( + root="tests/data/deepglobelandcover", batch_size=1, num_workers=0 + ) + model = SemanticSegmentationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 5d155492102..cd640cadf0b 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -141,8 +141,12 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_sampler = RandomBatchGeoSampler( self.naip, self.patch_size, self.batch_size, self.length, train_roi ) - self.val_sampler = GridGeoSampler(self.naip, self.patch_size, self.stride, val_roi) - self.test_sampler = GridGeoSampler(self.naip, self.patch_size, self.stride, test_roi) + self.val_sampler = GridGeoSampler( + self.naip, self.patch_size, self.stride, val_roi + ) + self.test_sampler = GridGeoSampler( + self.naip, self.patch_size, self.stride, test_roi + ) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. From 7aa1a539b8b15df7f9086a7d26378fc2b0893eca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 12:30:53 -0500 Subject: [PATCH 04/17] Add default seed --- torchgeo/datamodules/cowc.py | 2 +- torchgeo/datamodules/cyclone.py | 2 +- torchgeo/datamodules/sen12ms.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index ad421dae295..1607b6d1e1e 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -21,7 +21,7 @@ class COWCCountingDataModule(pl.LightningDataModule): """LightningDataModule implementation for the COWC Counting dataset.""" def __init__( - self, seed: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: """Initialize a LightningDataModule for COWC Counting based DataLoaders. diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 909853a90d0..bed564f633b 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -26,7 +26,7 @@ class CycloneDataModule(pl.LightningDataModule): """ def __init__( - self, seed: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, seed = 0: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 681c908779d..3d3360f0008 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -55,7 +55,7 @@ class SEN12MSDataModule(pl.LightningDataModule): def __init__( self, - seed: int, + seed: int = 0, band_set: str = "all", batch_size: int = 64, num_workers: int = 0, From da986314bc2b38722433c2628d95ec42b0c3e8e9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 12:34:25 -0500 Subject: [PATCH 05/17] Fix syntax error --- torchgeo/datamodules/cyclone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index bed564f633b..b4fe1993310 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -26,7 +26,7 @@ class CycloneDataModule(pl.LightningDataModule): """ def __init__( - self, seed = 0: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. From 81e6c930f2649c4e6de99a1a7d6630378f3929de Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 12:48:43 -0500 Subject: [PATCH 06/17] Fix model_kwargs --- tests/trainers/test_classification.py | 2 +- tests/trainers/test_regression.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index c26fac4b9e2..347148de704 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -191,7 +191,7 @@ def model_kwargs(self) -> Dict[Any, Any]: return { "classification_model": "resnet18", "in_channels": 1, - "loss": "ce", + "loss": "bce", "num_classes": 1, "weights": "random", } diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 985cb7f43f0..76e826c786c 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -65,6 +65,13 @@ def test_invalid_model(self) -> None: with pytest.raises(AttributeError, match=match): RegressionTask(model="invalid_model", pretrained=False) + @pytest.fixture + def model_kwargs(self) -> Dict[Any, Any]: + return { + "model": "resnet18", + "pretrained": False, + } + def test_missing_attributes( self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch ) -> None: From 4dff88bf9bab5df7571a7909abc5dd53d8e7eea1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 13:31:57 -0500 Subject: [PATCH 07/17] Fix classification tests --- tests/trainers/test_classification.py | 6 +++--- tests/trainers/test_regression.py | 5 +---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 347148de704..858416ebde4 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -86,7 +86,7 @@ def test_no_logger(self) -> None: def model_kwargs(self) -> Dict[Any, Any]: return { "classification_model": "resnet18", - "in_channels": 1, + "in_channels": 13, "loss": "ce", "num_classes": 2, "weights": "random", @@ -190,9 +190,9 @@ def test_no_logger(self) -> None: def model_kwargs(self) -> Dict[Any, Any]: return { "classification_model": "resnet18", - "in_channels": 1, + "in_channels": 14, "loss": "bce", - "num_classes": 1, + "num_classes": 19, "weights": "random", } diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 76e826c786c..396a1a55e73 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -67,10 +67,7 @@ def test_invalid_model(self) -> None: @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: - return { - "model": "resnet18", - "pretrained": False, - } + return {"model": "resnet18", "pretrained": False} def test_missing_attributes( self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch From 58e52d53afb1edb468f9b22bfbbe0e9c18362c3b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 13:40:56 -0500 Subject: [PATCH 08/17] Fix regression tests --- torchgeo/datasets/cyclone.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index e12a3359d6d..c27c383284f 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -232,6 +232,7 @@ def plot( .. versionadded:: 0.2 """ image, label = sample["image"], sample["label"] + image = image.permute((1, 2, 0)).numpy() showing_predictions = "prediction" in sample if showing_predictions: From e9b6435ff2ff6d5b004cdc5f88ff83b55ab033e7 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 4 Oct 2022 16:21:24 -0500 Subject: [PATCH 09/17] More fixes --- tests/trainers/test_regression.py | 2 +- torchgeo/datasets/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 396a1a55e73..dcc4fabbad5 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -5,9 +5,9 @@ from typing import Any, Dict, Type, cast import pytest +from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -from _pytest.monkeypatch import MonkeyPatch from torchgeo.datamodules import COWCCountingDataModule, CycloneDataModule from torchgeo.trainers import RegressionTask diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 658220d4903..2335dd16060 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -645,7 +645,7 @@ def draw_semantic_segmentation_masks( classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8)) class_masks = mask == classes[:, None, None] img = draw_segmentation_masks( - image=image, masks=class_masks, alpha=alpha, colors=colors + image=image.uint8(), masks=class_masks, alpha=alpha, colors=colors ) img = img.permute((1, 2, 0)).numpy().astype(np.uint8) return cast("np.typing.NDArray[np.uint8]", img) From e4a1c026e8f1fbf45b1196d3ae1fb5fd0969084c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:02:27 -0500 Subject: [PATCH 10/17] Undo changes that break code --- torchgeo/datamodules/cyclone.py | 8 -------- torchgeo/datamodules/deepglobelandcover.py | 8 -------- torchgeo/datamodules/oscd.py | 8 -------- torchgeo/datamodules/sen12ms.py | 8 -------- torchgeo/datasets/cyclone.py | 1 - torchgeo/datasets/utils.py | 2 +- 6 files changed, 1 insertion(+), 34 deletions(-) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index b4fe1993310..5ded6c05265 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional -import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit @@ -149,10 +148,3 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.TropicalCycloneWindEstimation.plot`. - - .. versionadded:: 0.4 - """ - return self.all_train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 955467048e3..1a5de89aee4 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional -import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose @@ -123,10 +122,3 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]: num_workers=self.num_workers, shuffle=False, ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index fbe4ff09d23..e604eed767c 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple import kornia.augmentation as K -import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from einops import repeat @@ -197,10 +196,3 @@ def test_dataloader(self) -> DataLoader[Any]: return DataLoader( self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.OSCD.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 3d3360f0008..e72d295acec 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional -import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit @@ -192,10 +191,3 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.SEN12MS.plot`. - - .. versionadded:: 0.4 - """ - return self.all_train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index c27c383284f..e12a3359d6d 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -232,7 +232,6 @@ def plot( .. versionadded:: 0.2 """ image, label = sample["image"], sample["label"] - image = image.permute((1, 2, 0)).numpy() showing_predictions = "prediction" in sample if showing_predictions: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 2335dd16060..1e7f43bbf87 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -645,7 +645,7 @@ def draw_semantic_segmentation_masks( classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8)) class_masks = mask == classes[:, None, None] img = draw_segmentation_masks( - image=image.uint8(), masks=class_masks, alpha=alpha, colors=colors + image=image.byte(), masks=class_masks, alpha=alpha, colors=colors ) img = img.permute((1, 2, 0)).numpy().astype(np.uint8) return cast("np.typing.NDArray[np.uint8]", img) From 0974502dcfb3dbfb0fd8824960c7ad4a80ef7299 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:05:22 -0500 Subject: [PATCH 11/17] Remove redundant cast --- torchgeo/datasets/cyclone.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index e12a3359d6d..49453519d9e 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -6,7 +6,7 @@ import json import os from functools import lru_cache -from typing import Any, Callable, Dict, Optional, cast +from typing import Any, Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np @@ -245,7 +245,7 @@ def plot( if show_titles: title = f"Label: {label}" if showing_predictions: - title += f"\nPrediction: {cast(str, prediction)}" + title += f"\nPrediction: {prediction}" ax.set_title(title) if suptitle is not None: From d07be1e80cf621b0161466f4c24a74aa1178bb20 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:12:08 -0500 Subject: [PATCH 12/17] DGLC no longer has a plot method --- tests/trainers/test_segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 03b9a22e3dc..84a86ba3e0d 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -138,7 +138,8 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: def test_missing_attributes( self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch ) -> None: - monkeypatch.delattr(DeepGlobeLandCoverDataModule, "plot") + # TODO: uncomment once DGLC has a plot method + # monkeypatch.delattr(DeepGlobeLandCoverDataModule, "plot") datamodule = DeepGlobeLandCoverDataModule( root="tests/data/deepglobelandcover", batch_size=1, num_workers=0 ) From 6c686770a3bb83863de25e9a3c459f9fd4b2dd27 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:28:05 -0500 Subject: [PATCH 13/17] Change no attribute plot dataset --- tests/trainers/test_segmentation.py | 11 +++++------ torchgeo/datamodules/deepglobelandcover.py | 7 +++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 84a86ba3e0d..ff10be7a5e0 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -104,8 +104,8 @@ def model_kwargs(self) -> Dict[Any, Any]: "segmentation_model": "unet", "encoder_name": "resnet18", "encoder_weights": None, - "in_channels": 1, - "num_classes": 2, + "in_channels": 3, + "num_classes": 6, "loss": "ce", "ignore_index": 0, } @@ -138,10 +138,9 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: def test_missing_attributes( self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch ) -> None: - # TODO: uncomment once DGLC has a plot method - # monkeypatch.delattr(DeepGlobeLandCoverDataModule, "plot") - datamodule = DeepGlobeLandCoverDataModule( - root="tests/data/deepglobelandcover", batch_size=1, num_workers=0 + monkeypatch.delattr(LandCoverAIDataModule, "plot") + datamodule = LandCoverAIDataModule( + root="tests/data/landcoverai", batch_size=1, num_workers=0 ) model = SemanticSegmentationTask(**model_kwargs) trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 1a5de89aee4..ad5514962f5 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose @@ -122,3 +123,9 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) From 29c5b101c30537e52bbdb16e6a3f36f4760ac48a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:29:44 -0500 Subject: [PATCH 14/17] correct num classes --- tests/trainers/test_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 858416ebde4..fa5fdb7003f 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -88,7 +88,7 @@ def model_kwargs(self) -> Dict[Any, Any]: "classification_model": "resnet18", "in_channels": 13, "loss": "ce", - "num_classes": 2, + "num_classes": 10, "weights": "random", } From d2aac6de4e4da2d11d56a531a63b276397bf708e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:32:56 -0500 Subject: [PATCH 15/17] Fix pydocstyle --- torchgeo/datamodules/deepglobelandcover.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index ad5514962f5..955467048e3 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -126,6 +126,7 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]: def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. + .. versionadded:: 0.4 """ return self.test_dataset.plot(*args, **kwargs) From 53df886cbc54535a2b9043695deab2964751fbff Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:46:07 -0500 Subject: [PATCH 16/17] Increase test coverage --- tests/datamodules/test_fair1m.py | 8 ++++++++ tests/datamodules/test_loveda.py | 8 ++++++++ tests/datamodules/test_potsdam.py | 8 ++++++++ tests/datamodules/test_usavars.py | 8 ++++++++ tests/datamodules/test_vaihingen.py | 8 ++++++++ tests/datamodules/test_xview2.py | 8 ++++++++ 6 files changed, 48 insertions(+) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index 44f2d4883ff..24745d456e4 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -3,8 +3,10 @@ import os +import matplotlib.pyplot as plt import pytest +from torchgeo.datasets import unbind_samples from torchgeo.datamodules import FAIR1MDataModule @@ -32,3 +34,9 @@ def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: FAIR1MDataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() diff --git a/tests/datamodules/test_loveda.py b/tests/datamodules/test_loveda.py index 2e2f3a89b66..ff444112723 100644 --- a/tests/datamodules/test_loveda.py +++ b/tests/datamodules/test_loveda.py @@ -3,8 +3,10 @@ import os +import matplotlib.pyplot as plt import pytest +from torchgeo.datasets import unbind_samples from torchgeo.datamodules import LoveDADataModule @@ -32,3 +34,9 @@ def test_val_dataloader(self, datamodule: LoveDADataModule) -> None: def test_test_dataloader(self, datamodule: LoveDADataModule) -> None: next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: LoveDADataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() diff --git a/tests/datamodules/test_potsdam.py b/tests/datamodules/test_potsdam.py index daabb6e6d72..f330713f3ed 100644 --- a/tests/datamodules/test_potsdam.py +++ b/tests/datamodules/test_potsdam.py @@ -3,9 +3,11 @@ import os +import matplotlib.pyplot as plt import pytest from _pytest.fixtures import SubRequest +from torchgeo.datasets import unbind_samples from torchgeo.datamodules import Potsdam2DDataModule @@ -34,3 +36,9 @@ def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None: def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None: next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: Potsdam2DDataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index aebfbe67d6b..73530550dd0 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -3,9 +3,11 @@ import os +import matplotlib.pyplot as plt import pytest from _pytest.fixtures import SubRequest +from torchgeo.datasets import unbind_samples from torchgeo.datamodules import USAVarsDataModule @@ -38,3 +40,9 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: assert len(datamodule.test_dataloader()) == 1 sample = next(iter(datamodule.test_dataloader())) assert sample["image"].shape[0] == datamodule.batch_size + + def test_plot(self, datamodule: USAVarsDataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py index 85f16fd56b2..6260df9fa6e 100644 --- a/tests/datamodules/test_vaihingen.py +++ b/tests/datamodules/test_vaihingen.py @@ -3,9 +3,11 @@ import os +import matplotlib.pyplot as plt import pytest from _pytest.fixtures import SubRequest +from torchgeo.datasets import unbind_samples from torchgeo.datamodules import Vaihingen2DDataModule @@ -34,3 +36,9 @@ def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: Vaihingen2DDataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 7086532b4d7..1490f45396b 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -3,9 +3,11 @@ import os +import matplotlib.pyplot as plt import pytest from _pytest.fixtures import SubRequest +from torchgeo.datasets import unbind_samples from torchgeo.datamodules import XView2DataModule @@ -34,3 +36,9 @@ def test_val_dataloader(self, datamodule: XView2DataModule) -> None: def test_test_dataloader(self, datamodule: XView2DataModule) -> None: next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: XView2DataModule) -> None: + batch = next(iter(datamodule.train_dataloader())) + sample = unbind_samples(batch)[0] + datamodule.plot(sample) + plt.close() From 8323831bda778e24b5a571cca82e7f01a64ef14e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 13 Oct 2022 21:50:24 -0500 Subject: [PATCH 17/17] Fix isort --- tests/datamodules/test_fair1m.py | 2 +- tests/datamodules/test_loveda.py | 2 +- tests/datamodules/test_potsdam.py | 2 +- tests/datamodules/test_usavars.py | 2 +- tests/datamodules/test_vaihingen.py | 2 +- tests/datamodules/test_xview2.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index 24745d456e4..9b26a3def61 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -6,8 +6,8 @@ import matplotlib.pyplot as plt import pytest -from torchgeo.datasets import unbind_samples from torchgeo.datamodules import FAIR1MDataModule +from torchgeo.datasets import unbind_samples class TestFAIR1MDataModule: diff --git a/tests/datamodules/test_loveda.py b/tests/datamodules/test_loveda.py index ff444112723..4b7c3dee1cb 100644 --- a/tests/datamodules/test_loveda.py +++ b/tests/datamodules/test_loveda.py @@ -6,8 +6,8 @@ import matplotlib.pyplot as plt import pytest -from torchgeo.datasets import unbind_samples from torchgeo.datamodules import LoveDADataModule +from torchgeo.datasets import unbind_samples class TestLoveDADataModule: diff --git a/tests/datamodules/test_potsdam.py b/tests/datamodules/test_potsdam.py index f330713f3ed..5a8ad1f785f 100644 --- a/tests/datamodules/test_potsdam.py +++ b/tests/datamodules/test_potsdam.py @@ -7,8 +7,8 @@ import pytest from _pytest.fixtures import SubRequest -from torchgeo.datasets import unbind_samples from torchgeo.datamodules import Potsdam2DDataModule +from torchgeo.datasets import unbind_samples class TestPotsdam2DDataModule: diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 73530550dd0..874c502a619 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -7,8 +7,8 @@ import pytest from _pytest.fixtures import SubRequest -from torchgeo.datasets import unbind_samples from torchgeo.datamodules import USAVarsDataModule +from torchgeo.datasets import unbind_samples class TestUSAVarsDataModule: diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py index 6260df9fa6e..13fd2d52e4c 100644 --- a/tests/datamodules/test_vaihingen.py +++ b/tests/datamodules/test_vaihingen.py @@ -7,8 +7,8 @@ import pytest from _pytest.fixtures import SubRequest -from torchgeo.datasets import unbind_samples from torchgeo.datamodules import Vaihingen2DDataModule +from torchgeo.datasets import unbind_samples class TestVaihingen2DDataModule: diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 1490f45396b..c190b5d2acd 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -7,8 +7,8 @@ import pytest from _pytest.fixtures import SubRequest -from torchgeo.datasets import unbind_samples from torchgeo.datamodules import XView2DataModule +from torchgeo.datasets import unbind_samples class TestXView2DataModule: