diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index ed9c68d39a0..2ac68ca4ea5 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -12,5 +12,6 @@ experiment: root: "tests/data/bigearthnet" bands: "all" num_classes: ${experiment.module.num_classes} + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index 6c16bb4e7e0..a5427bff5b1 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -12,5 +12,6 @@ experiment: root: "tests/data/bigearthnet" bands: "s1" num_classes: ${experiment.module.num_classes} + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index 74876350e8f..49ea9c7235e 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -12,5 +12,6 @@ experiment: root: "tests/data/bigearthnet" bands: "s2" num_classes: ${experiment.module.num_classes} + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/byol.yaml b/tests/conf/byol.yaml index a135cb13102..982fcd888e9 100644 --- a/tests/conf/byol.yaml +++ b/tests/conf/byol.yaml @@ -10,6 +10,7 @@ experiment: learning_rate_schedule_patience: 6 datamodule: root: "tests/data/chesapeake/cvpr" + download: true train_splits: - "de-test" val_splits: diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml index ee167904eb4..7e1663c9493 100644 --- a/tests/conf/chesapeake_cvpr_5.yaml +++ b/tests/conf/chesapeake_cvpr_5.yaml @@ -13,6 +13,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/chesapeake/cvpr" + download: true train_splits: - "de-test" val_splits: diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index 04807d5ab4f..5597890e98b 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -13,6 +13,7 @@ experiment: weights: imagenet datamodule: root: "tests/data/chesapeake/cvpr" + download: true train_splits: - "de-test" val_splits: diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior.yaml index 9c2f96c1718..3f4d440e2d2 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior.yaml @@ -13,6 +13,7 @@ experiment: weights: imagenet datamodule: root: "tests/data/chesapeake/cvpr" + download: true train_splits: - "de-test" val_splits: diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index 781f632853e..fe12f68e727 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -10,6 +10,7 @@ experiment: pretrained: True datamodule: root: "tests/data/cowc_counting" + download: true seed: 0 batch_size: 1 num_workers: 0 diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index 38192db9642..72b5c5e0b60 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -4,12 +4,13 @@ experiment: model: "resnet18" weights: "random" num_outputs: 1 - in_channels: 3 + in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 pretrained: False datamodule: root: "tests/data/cyclone" + download: true seed: 0 batch_size: 1 num_workers: 0 diff --git a/tests/conf/etci2021.yaml b/tests/conf/etci2021.yaml index 73fde1dfb87..cbb766ea522 100644 --- a/tests/conf/etci2021.yaml +++ b/tests/conf/etci2021.yaml @@ -12,5 +12,6 @@ experiment: ignore_index: 0 datamodule: root: "tests/data/etci2021" + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index e865c7af8e9..e674d2f0924 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -10,5 +10,6 @@ experiment: num_classes: 2 datamodule: root: "tests/data/eurosat" + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/landcoverai.yaml b/tests/conf/landcoverai.yaml index 04ae4a33030..9bffc96b83d 100644 --- a/tests/conf/landcoverai.yaml +++ b/tests/conf/landcoverai.yaml @@ -14,5 +14,6 @@ experiment: ignore_index: null datamodule: root: "tests/data/landcoverai" + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/naipchesapeake.yaml b/tests/conf/naipchesapeake.yaml index b2d43c44f45..9cd0e2beb96 100644 --- a/tests/conf/naipchesapeake.yaml +++ b/tests/conf/naipchesapeake.yaml @@ -14,6 +14,7 @@ experiment: datamodule: naip_root: "tests/data/naip" chesapeake_root: "tests/data/chesapeake/BAYWIDE" + chesapeake_download: true batch_size: 2 num_workers: 0 patch_size: 32 diff --git a/tests/conf/nasa_marine_debris.yaml b/tests/conf/nasa_marine_debris.yaml index 1a827dd3fe3..5528b38c39c 100644 --- a/tests/conf/nasa_marine_debris.yaml +++ b/tests/conf/nasa_marine_debris.yaml @@ -9,5 +9,6 @@ experiment: verbose: false datamodule: root: "tests/data/nasa_marine_debris" + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/oscd_all.yaml b/tests/conf/oscd_all.yaml index 34857595cd4..dbaede4185b 100644 --- a/tests/conf/oscd_all.yaml +++ b/tests/conf/oscd_all.yaml @@ -14,6 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/oscd" + download: true train_batch_size: 1 num_workers: 0 val_split_pct: 0.5 diff --git a/tests/conf/oscd_rgb.yaml b/tests/conf/oscd_rgb.yaml index ede707145b8..0f7a51f8f1f 100644 --- a/tests/conf/oscd_rgb.yaml +++ b/tests/conf/oscd_rgb.yaml @@ -14,6 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/oscd" + download: true train_batch_size: 1 num_workers: 0 val_split_pct: 0.5 diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index 89f7b8072c4..0b545cc18d9 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -10,5 +10,6 @@ experiment: num_classes: 3 datamodule: root: "tests/data/resisc45" + download: true batch_size: 1 num_workers: 0 diff --git a/tests/conf/spacenet1.yaml b/tests/conf/spacenet1.yaml index de290dd1527..3f05a745573 100644 --- a/tests/conf/spacenet1.yaml +++ b/tests/conf/spacenet1.yaml @@ -14,6 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/spacenet" + download: true batch_size: 1 num_workers: 0 val_split_pct: 0.33 diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index fe39f579c35..7c2995b475f 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -10,5 +10,6 @@ experiment: num_classes: 2 datamodule: root: "tests/data/ucmerced" + download: true batch_size: 1 num_workers: 0 diff --git a/tests/datamodules/test_loveda.py b/tests/datamodules/test_loveda.py index 4b7c3dee1cb..92457ebff64 100644 --- a/tests/datamodules/test_loveda.py +++ b/tests/datamodules/test_loveda.py @@ -19,7 +19,11 @@ def datamodule(self) -> LoveDADataModule: scene = ["rural", "urban"] dm = LoveDADataModule( - root=root, scene=scene, batch_size=batch_size, num_workers=num_workers + root=root, + scene=scene, + batch_size=batch_size, + num_workers=num_workers, + download=True, ) dm.prepare_data() diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 874c502a619..ddcde7e6b26 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -20,7 +20,7 @@ def datamodule(self, request: SubRequest) -> USAVarsDataModule: num_workers = 0 dm = USAVarsDataModule( - root=root, batch_size=batch_size, num_workers=num_workers + root=root, batch_size=batch_size, num_workers=num_workers, download=True ) dm.prepare_data() dm.setup() diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index ae5653b4ae9..25e2f298e1b 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -111,7 +111,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - BigEarthNet(split="train", **self.kwargs) + if self.kwargs.get("download", False): + BigEarthNet(split="train", **self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 2f9aa4437c5..6475e8923dd 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -226,7 +226,8 @@ def prepare_data(self) -> None: This method is called once per node, while :func:`setup` is called once per GPU. """ - ChesapeakeCVPR(splits=self.train_splits, layers=self.layers, **self.kwargs) + if self.kwargs.get("download", False): + ChesapeakeCVPR(splits=self.train_splits, layers=self.layers, **self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Create the train/val/test splits based on the original Dataset objects. diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 1607b6d1e1e..4d5255b5066 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -59,7 +59,8 @@ def prepare_data(self) -> None: This includes optionally downloading the dataset. This is done once per node, while :func:`setup` is done once per GPU. """ - COWCCounting(**self.kwargs) + if self.kwargs.get("download", False): + COWCCounting(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Create the train/val/test splits based on the original Dataset objects. diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 6051cb247ae..fb0b1a54f80 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -71,7 +71,8 @@ def prepare_data(self) -> None: This includes optionally downloading the dataset. This is done once per node, while :func:`setup` is done once per GPU. """ - TropicalCyclone(split="train", **self.kwargs) + if self.kwargs.get("download", False): + TropicalCyclone(split="train", **self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Create the train/val/test splits based on the original Dataset objects. diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 82f5342cf47..101c9a3a318 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -79,7 +79,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - ETCI2021(**self.kwargs) + if self.kwargs.get("download", False): + ETCI2021(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 48634f7c6a9..cb672011fb6 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -94,7 +94,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - EuroSAT(**self.kwargs) + if self.kwargs.get("download", False): + EuroSAT(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 463b0e3c872..17e26b31faa 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -106,7 +106,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - LandCoverAI(**self.kwargs) + if self.kwargs.get("download", False): + LandCoverAI(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index f4f1dc39c28..c51a7e3d939 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -59,7 +59,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - LoveDA(**self.kwargs) + if self.kwargs.get("download", False): + LoveDA(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index cd640cadf0b..b3d08a9ba13 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -31,8 +31,6 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule): def __init__( self, - naip_root: str, - chesapeake_root: str, batch_size: int = 64, num_workers: int = 0, patch_size: int = 256, @@ -41,22 +39,26 @@ def __init__( """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. Args: - naip_root: directory containing NAIP data - chesapeake_root: directory containing Chesapeake data batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders patch_size: size of patches to sample **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.NAIP` and + :class:`~torchgeo.datasets.NAIP` (prefix keys with ``naip_``) and :class:`~torchgeo.datasets.Chesapeake13` + (prefix keys with ``chesapeake_``) """ super().__init__() - self.naip_root = naip_root - self.chesapeake_root = chesapeake_root self.batch_size = batch_size self.num_workers = num_workers self.patch_size = patch_size - self.kwargs = kwargs + + self.naip_kwargs = {} + self.chesapeake_kwargs = {} + for key, val in kwargs.items(): + if key.startswith("naip_"): + self.naip_kwargs[key[5:]] = val + elif key.startswith("chesapeake_"): + self.chesapeake_kwargs[key[11:]] = val def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the NAIP Dataset. @@ -102,7 +104,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - Chesapeake13(self.chesapeake_root, **self.kwargs) + if self.chesapeake_kwargs.get("download", False): + Chesapeake13(**self.chesapeake_kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. @@ -119,14 +122,13 @@ def setup(self, stage: Optional[str] = None) -> None: chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox]) self.chesapeake = Chesapeake13( - self.chesapeake_root, transforms=chesapeak_transforms, **self.kwargs + transforms=chesapeak_transforms, **self.chesapeake_kwargs ) self.naip = NAIP( - self.naip_root, - self.chesapeake.crs, - self.chesapeake.res, + crs=self.chesapeake.crs, + res=self.chesapeake.res, transforms=naip_transforms, - **self.kwargs, + **self.naip_kwargs, ) self.dataset = self.chesapeake & self.naip diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 4184759a5fc..c9b559dce32 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -84,7 +84,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - NASAMarineDebris(**self.kwargs) + if self.kwargs.get("download", False): + NASAMarineDebris(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index e604eed767c..90a4c94797b 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -117,7 +117,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - OSCD(split="train", **self.kwargs) + if self.kwargs.get("download", False): + OSCD(split="train", **self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 8d8feb0022c..19cec81f254 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -107,7 +107,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - RESISC45(**self.kwargs) + if self.kwargs.get("download", False): + RESISC45(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index eb1a91a4b7c..6d6b7c4e5f3 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -103,13 +103,6 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: return sample - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - So2Sat(**self.kwargs) - def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 7f19218d579..739b9d5a5e9 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -123,7 +123,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - SpaceNet1(**self.kwargs) + if self.kwargs.get("download", False): + SpaceNet1(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index e6c1ccbae53..8ec04491578 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -63,7 +63,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - UCMerced(**self.kwargs) + if self.kwargs.get("download", False): + UCMerced(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 2beeb055afd..9c7fa6d0333 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -54,7 +54,8 @@ def prepare_data(self) -> None: This method is only called once per run. """ - USAVars(**self.kwargs) + if self.kwargs.get("download", False): + USAVars(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main Dataset objects.