diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst new file mode 100644 index 00000000000..ad47f9cd648 --- /dev/null +++ b/docs/api/datamodules.rst @@ -0,0 +1,105 @@ +torchgeo.datamodules +==================== + +.. module:: torchgeo.datamodules + +Geospatial DataModules +---------------------- + +Chesapeake Bay High-Resolution Land Cover Project +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: ChesapeakeCVPRDataModule + +National Agriculture Imagery Program (NAIP) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: NAIPChesapeakeDataModule + +Non-geospatial DataModules +-------------------------- + +BigEarthNet +^^^^^^^^^^^ + +.. autoclass:: BigEarthNetDataModule + +Cars Overhead With Context (COWC) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: COWCCountingDataModule + +ETCI2021 Flood Detection +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: ETCI2021DataModule + +EuroSAT +^^^^^^^ + +.. autoclass:: EuroSATDataModule + +FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: FAIR1MDataModule + +LandCover.ai (Land Cover from Aerial Imagery) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: LandCoverAIDataModule + +LoveDA (Land-cOVEr Domain Adaptive semantic segmentation) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: LoveDADataModule + +NASA Marine Debris +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: NASAMarineDebrisDataModule + +OSCD (Onera Satellite Change Detection) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: OSCDDataModule + +Potsdam +^^^^^^^ + +.. autoclass:: Potsdam2DDataModule + +RESISC45 (Remote Sensing Image Scene Classification) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RESISC45DataModule + +SEN12MS +^^^^^^^ + +.. autoclass:: SEN12MSDataModule + +So2Sat +^^^^^^ + +.. autoclass:: So2SatDataModule + +Tropical Cyclone Wind Estimation Competition +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: CycloneDataModule + +UC Merced +^^^^^^^^^ + +.. autoclass:: UCMercedDataModule + +Vaihingen +^^^^^^^^^ + +.. autoclass:: Vaihingen2DDataModule + +xView2 +^^^^^^ + +.. autoclass:: XView2DataModule diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 4b30c762a76..f3df706a8b4 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -31,7 +31,6 @@ Chesapeake Bay High-Resolution Land Cover Project .. autoclass:: ChesapeakeVA .. autoclass:: ChesapeakeWV .. autoclass:: ChesapeakeCVPR -.. autoclass:: ChesapeakeCVPRDataModule Cropland Data Layer (CDL) ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -57,7 +56,6 @@ National Agriculture Imagery Program (NAIP) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: NAIP -.. autoclass:: NAIPChesapeakeDataModule Sentinel ^^^^^^^^ @@ -86,7 +84,6 @@ BigEarthNet ^^^^^^^^^^^ .. autoclass:: BigEarthNet -.. autoclass:: BigEarthNetDataModule Cars Overhead With Context (COWC) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -94,7 +91,6 @@ Cars Overhead With Context (COWC) .. autoclass:: COWC .. autoclass:: COWCCounting .. autoclass:: COWCDetection -.. autoclass:: COWCCountingDataModule CV4A Kenya Crop Type Competition ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -105,19 +101,16 @@ ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: ETCI2021 -.. autoclass:: ETCI2021DataModule EuroSAT -^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^ .. autoclass:: EuroSAT -.. autoclass:: EuroSATDataModule FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: FAIR1M -.. autoclass:: FAIR1MDataModule GID-15 (Gaofen Image Dataset) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -133,7 +126,6 @@ LandCover.ai (Land Cover from Aerial Imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: LandCoverAI -.. autoclass:: LandCoverAIDataModule LEVIR-CD+ (LEVIR Change Detection +) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -144,19 +136,16 @@ LoveDA (Land-cOVEr Domain Adaptive semantic segmentation) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: LoveDA -.. autoclass:: LoveDADataModule NASA Marine Debris ^^^^^^^^^^^^^^^^^^ .. autoclass:: NASAMarineDebris -.. autoclass:: NASAMarineDebrisDataModule OSCD (Onera Satellite Change Detection) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: OSCD -.. autoclass:: OSCDDataModule PatternNet ^^^^^^^^^^ @@ -167,13 +156,11 @@ Potsdam ^^^^^^^ .. autoclass:: Potsdam2D -.. autoclass:: Potsdam2DDataModule RESISC45 (Remote Sensing Image Scene Classification) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: RESISC45 -.. autoclass:: RESISC45DataModule Seasonal Contrast ^^^^^^^^^^^^^^^^^ @@ -184,13 +171,11 @@ SEN12MS ^^^^^^^ .. autoclass:: SEN12MS -.. autoclass:: SEN12MSDataModule So2Sat ^^^^^^ .. autoclass:: So2Sat -.. autoclass:: So2SatDataModule SpaceNet ^^^^^^^^ @@ -206,30 +191,26 @@ Tropical Cyclone Wind Estimation Competition ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: TropicalCycloneWindEstimation -.. autoclass:: CycloneDataModule + +UC Merced +^^^^^^^^^ + +.. autoclass:: UCMerced Vaihingen ^^^^^^^^^ .. autoclass:: Vaihingen2D -.. autoclass:: Vaihingen2DDataModule NWPU VHR-10 ^^^^^^^^^^^ .. autoclass:: VHR10 -UC Merced -^^^^^^^^^ - -.. autoclass:: UCMerced -.. autoclass:: UCMercedDataModule - xView2 ^^^^^^ .. autoclass:: XView2 -.. autoclass:: XView2DataModule ZueriCrop ^^^^^^^^^ diff --git a/docs/index.rst b/docs/index.rst index d8fe309a2d1..d2e85fa5d5a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ torchgeo :maxdepth: 2 :caption: Package Reference + api/datamodules api/datasets api/models api/samplers diff --git a/experiments/test_chesapeakecvpr_models.py b/experiments/test_chesapeakecvpr_models.py index 249dd20bc7d..38ab24fa842 100755 --- a/experiments/test_chesapeakecvpr_models.py +++ b/experiments/test_chesapeakecvpr_models.py @@ -11,7 +11,7 @@ import pytorch_lightning as pl import torch -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]] diff --git a/pyproject.toml b/pyproject.toml index cfb34131556..f449ed08902 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ strict_equality = true [tool.pydocstyle] convention = "google" -match_dir = "(datasets|models|samplers|torchgeo|trainers|transforms)" +match_dir = "(datamodules|datasets|models|samplers|torchgeo|trainers|transforms)" [tool.pytest.ini_options] # Skip slow tests by default diff --git a/tests/datamodules/__init__.py b/tests/datamodules/__init__.py new file mode 100644 index 00000000000..5b7f7a925cc --- /dev/null +++ b/tests/datamodules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. diff --git a/tests/datamodules/test_bigearthnet.py b/tests/datamodules/test_bigearthnet.py new file mode 100644 index 00000000000..b460877a76a --- /dev/null +++ b/tests/datamodules/test_bigearthnet.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import BigEarthNetDataModule + + +class TestBigEarthNetDataModule: + @pytest.fixture(scope="class", params=["s1", "s2", "all"]) + def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: + bands = request.param + root = os.path.join("tests", "data", "bigearthnet") + num_classes = 19 + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule(root, bands, num_classes, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_chesapeake.py b/tests/datamodules/test_chesapeake.py new file mode 100644 index 00000000000..f04d470b779 --- /dev/null +++ b/tests/datamodules/test_chesapeake.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +import torch +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import ChesapeakeCVPRDataModule + + +class TestChesapeakeCVPRDataModule: + @pytest.fixture(scope="class", params=[5, 7]) + def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule: + dm = ChesapeakeCVPRDataModule( + os.path.join("tests", "data", "chesapeake", "cvpr"), + ["de-test"], + ["de-test"], + ["de-test"], + patch_size=32, + patches_per_tile=2, + batch_size=2, + num_workers=0, + class_set=request.param, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: + next(iter(datamodule.test_dataloader())) + + def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None: + nodata_check = datamodule.nodata_check(4) + sample = { + "image": torch.ones(1, 2, 2), # type: ignore[attr-defined] + "mask": torch.ones(2, 2), # type: ignore[attr-defined] + } + out = nodata_check(sample) + assert torch.equal( # type: ignore[attr-defined] + out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined] + ) + assert torch.equal( # type: ignore[attr-defined] + out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined] + ) diff --git a/tests/datamodules/test_cowc.py b/tests/datamodules/test_cowc.py new file mode 100644 index 00000000000..8b6e974f35b --- /dev/null +++ b/tests/datamodules/test_cowc.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import COWCCountingDataModule + + +class TestCOWCCountingDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> COWCCountingDataModule: + root = os.path.join("tests", "data", "cowc_counting") + seed = 0 + batch_size = 1 + num_workers = 0 + dm = COWCCountingDataModule(root, seed, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_cyclone.py b/tests/datamodules/test_cyclone.py new file mode 100644 index 00000000000..843cc1656ed --- /dev/null +++ b/tests/datamodules/test_cyclone.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import CycloneDataModule + + +class TestCycloneDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> CycloneDataModule: + root = os.path.join("tests", "data", "cyclone") + seed = 0 + batch_size = 1 + num_workers = 0 + dm = CycloneDataModule(root, seed, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: CycloneDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: CycloneDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: CycloneDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_etci2021.py b/tests/datamodules/test_etci2021.py new file mode 100644 index 00000000000..b51e8daf1b1 --- /dev/null +++ b/tests/datamodules/test_etci2021.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import ETCI2021DataModule + + +class TestETCI2021DataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> ETCI2021DataModule: + root = os.path.join("tests", "data", "etci2021") + seed = 0 + batch_size = 2 + num_workers = 0 + dm = ETCI2021DataModule(root, seed, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_eurosat.py b/tests/datamodules/test_eurosat.py new file mode 100644 index 00000000000..a8a51cd7b53 --- /dev/null +++ b/tests/datamodules/test_eurosat.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import EuroSATDataModule + + +class TestEuroSATDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> EuroSATDataModule: + root = os.path.join("tests", "data", "eurosat") + batch_size = 1 + num_workers = 0 + dm = EuroSATDataModule(root, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py new file mode 100644 index 00000000000..1f19922f1eb --- /dev/null +++ b/tests/datamodules/test_fair1m.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import FAIR1MDataModule + + +class TestFAIR1MDataModule: + @pytest.fixture(scope="class", params=[True, False]) + def datamodule(self) -> FAIR1MDataModule: + root = os.path.join("tests", "data", "fair1m") + batch_size = 2 + num_workers = 0 + dm = FAIR1MDataModule( + root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33 + ) + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_landcoverai.py b/tests/datamodules/test_landcoverai.py new file mode 100644 index 00000000000..1d4f2e43150 --- /dev/null +++ b/tests/datamodules/test_landcoverai.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import LandCoverAIDataModule + + +class TestLandCoverAIDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> LandCoverAIDataModule: + root = os.path.join("tests", "data", "landcoverai") + batch_size = 2 + num_workers = 0 + dm = LandCoverAIDataModule(root, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_loveda.py b/tests/datamodules/test_loveda.py new file mode 100644 index 00000000000..c19e8cb0ab9 --- /dev/null +++ b/tests/datamodules/test_loveda.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import LoveDADataModule + + +class TestLoveDADataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> LoveDADataModule: + root = os.path.join("tests", "data", "loveda") + batch_size = 2 + num_workers = 0 + scene = ["rural", "urban"] + + dm = LoveDADataModule( + root_dir=root, scene=scene, batch_size=batch_size, num_workers=num_workers + ) + + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: LoveDADataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: LoveDADataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: LoveDADataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_naip.py b/tests/datamodules/test_naip.py new file mode 100644 index 00000000000..5f9d676f2ba --- /dev/null +++ b/tests/datamodules/test_naip.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import NAIPChesapeakeDataModule + + +class TestNAIPChesapeakeDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> NAIPChesapeakeDataModule: + dm = NAIPChesapeakeDataModule( + os.path.join("tests", "data", "naip"), + os.path.join("tests", "data", "chesapeake", "BAYWIDE"), + batch_size=2, + num_workers=0, + ) + dm.patch_size = 32 + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_nasa_marine_debris.py b/tests/datamodules/test_nasa_marine_debris.py new file mode 100644 index 00000000000..eff571f953c --- /dev/null +++ b/tests/datamodules/test_nasa_marine_debris.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import NASAMarineDebrisDataModule + + +class TestNASAMarineDebrisDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> NASAMarineDebrisDataModule: + root = os.path.join("tests", "data", "nasa_marine_debris") + batch_size = 2 + num_workers = 0 + val_split_pct = 0.3 + test_split_pct = 0.3 + dm = NASAMarineDebrisDataModule( + root, batch_size, num_workers, val_split_pct, test_split_pct + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py new file mode 100644 index 00000000000..7d090f99c97 --- /dev/null +++ b/tests/datamodules/test_oscd.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import OSCDDataModule + + +class TestOSCDDataModule: + @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) + def datamodule(self, request: SubRequest) -> OSCDDataModule: + bands, val_split_pct = request.param + patch_size = (2, 2) + num_patches_per_tile = 2 + root = os.path.join("tests", "data", "oscd") + batch_size = 1 + num_workers = 0 + dm = OSCDDataModule( + root, + bands, + batch_size, + num_workers, + val_split_pct, + patch_size, + num_patches_per_tile, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: + sample = next(iter(datamodule.train_dataloader())) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 + if datamodule.bands == "all": + assert sample["image"].shape[1] == 26 + else: + assert sample["image"].shape[1] == 6 + + def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: + sample = next(iter(datamodule.val_dataloader())) + if datamodule.val_split_pct > 0.0: + assert ( + sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) + ) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert sample["image"].shape[1] == 26 + else: + assert sample["image"].shape[1] == 6 + + def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: + sample = next(iter(datamodule.test_dataloader())) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert sample["image"].shape[1] == 26 + else: + assert sample["image"].shape[1] == 6 diff --git a/tests/datamodules/test_potsdam.py b/tests/datamodules/test_potsdam.py new file mode 100644 index 00000000000..f67be0fea7c --- /dev/null +++ b/tests/datamodules/test_potsdam.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import Potsdam2DDataModule + + +class TestPotsdam2DDataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> Potsdam2DDataModule: + root = os.path.join("tests", "data", "potsdam") + batch_size = 1 + num_workers = 0 + val_split_size = request.param + dm = Potsdam2DDataModule( + root, batch_size, num_workers, val_split_pct=val_split_size + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_resisc45.py b/tests/datamodules/test_resisc45.py new file mode 100644 index 00000000000..e1b9baa83f5 --- /dev/null +++ b/tests/datamodules/test_resisc45.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import RESISC45DataModule + + +class TestRESISC45DataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> RESISC45DataModule: + root = os.path.join("tests", "data", "resisc45") + batch_size = 2 + num_workers = 0 + dm = RESISC45DataModule(root, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_sen12ms.py b/tests/datamodules/test_sen12ms.py new file mode 100644 index 00000000000..7cd6df73857 --- /dev/null +++ b/tests/datamodules/test_sen12ms.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import SEN12MSDataModule + + +class TestSEN12MSDataModule: + @pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"]) + def datamodule(self, request: SubRequest) -> SEN12MSDataModule: + root = os.path.join("tests", "data", "sen12ms") + seed = 0 + bands = request.param + batch_size = 1 + num_workers = 0 + dm = SEN12MSDataModule(root, seed, bands, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_so2sat.py b/tests/datamodules/test_so2sat.py new file mode 100644 index 00000000000..2f732306901 --- /dev/null +++ b/tests/datamodules/test_so2sat.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import So2SatDataModule + +pytest.importorskip("h5py") + + +class TestSo2SatDataModule: + @pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"])) + def datamodule(self, request: SubRequest) -> So2SatDataModule: + unsupervised_mode, bands = request.param + root = os.path.join("tests", "data", "so2sat") + batch_size = 2 + num_workers = 0 + dm = So2SatDataModule(root, batch_size, num_workers, bands, unsupervised_mode) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: So2SatDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: So2SatDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: So2SatDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_ucmerced.py b/tests/datamodules/test_ucmerced.py new file mode 100644 index 00000000000..8dd7ab83360 --- /dev/null +++ b/tests/datamodules/test_ucmerced.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import UCMercedDataModule + + +class TestUCMercedDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> UCMercedDataModule: + root = os.path.join("tests", "data", "ucmerced") + batch_size = 2 + num_workers = 0 + dm = UCMercedDataModule(root, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_utils.py b/tests/datamodules/test_utils.py new file mode 100644 index 00000000000..e5bc527f6c3 --- /dev/null +++ b/tests/datamodules/test_utils.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import torch +from torch.utils.data import TensorDataset + +from torchgeo.datamodules.utils import dataset_split + + +def test_dataset_split() -> None: + num_samples = 24 + x = torch.ones(num_samples, 5) # type: ignore[attr-defined] + y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined] + ds = TensorDataset(x, y) + + # Test only train/val set split + train_ds, val_ds = dataset_split(ds, val_pct=1 / 2) + assert len(train_ds) == num_samples // 2 + assert len(val_ds) == num_samples // 2 + + # Test train/val/test set split + train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3) + assert len(train_ds) == num_samples // 3 + assert len(val_ds) == num_samples // 3 + assert len(test_ds) == num_samples // 3 diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py new file mode 100644 index 00000000000..453a987ecef --- /dev/null +++ b/tests/datamodules/test_vaihingen.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import Vaihingen2DDataModule + + +class TestVaihingen2DDataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule: + root = os.path.join("tests", "data", "vaihingen") + batch_size = 1 + num_workers = 0 + val_split_size = request.param + dm = Vaihingen2DDataModule( + root, batch_size, num_workers, val_split_pct=val_split_size + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py new file mode 100644 index 00000000000..5e1637533d6 --- /dev/null +++ b/tests/datamodules/test_xview2.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import XView2DataModule + + +class TestXView2DataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> XView2DataModule: + root = os.path.join("tests", "data", "xview2") + batch_size = 1 + num_workers = 0 + val_split_size = request.param + dm = XView2DataModule( + root, batch_size, num_workers, val_split_pct=val_split_size + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: XView2DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: XView2DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: XView2DataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 2561eb9f8e5..d1f9ca80f68 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -13,7 +13,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import BigEarthNet, BigEarthNetDataModule +from torchgeo.datasets import BigEarthNet def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -148,26 +148,3 @@ def test_not_downloaded(self, tmp_path: Path) -> None: "to automaticaly download the dataset." with pytest.raises(RuntimeError, match=err): BigEarthNet(str(tmp_path)) - - -class TestBigEarthNetDataModule: - @pytest.fixture(scope="class", params=["s1", "s2", "all"]) - def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: - bands = request.param - root = os.path.join("tests", "data", "bigearthnet") - num_classes = 19 - batch_size = 1 - num_workers = 0 - dm = BigEarthNetDataModule(root, bands, num_classes, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 573719ce9f3..2f321967be9 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -19,7 +19,6 @@ BoundingBox, Chesapeake13, ChesapeakeCVPR, - ChesapeakeCVPRDataModule, IntersectionDataset, UnionDataset, ) @@ -179,45 +178,3 @@ def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None: IndexError, match="query: .* spans multiple tiles which is not valid" ): ds[dataset.bounds] - - -class TestChesapeakeCVPRDataModule: - @pytest.fixture(scope="class", params=[5, 7]) - def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=32, - patches_per_tile=2, - batch_size=2, - num_workers=0, - class_set=request.param, - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: - next(iter(datamodule.test_dataloader())) - - def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None: - nodata_check = datamodule.nodata_check(4) - sample = { - "image": torch.ones(1, 2, 2), # type: ignore[attr-defined] - "mask": torch.ones(2, 2), # type: ignore[attr-defined] - } - out = nodata_check(sample) - assert torch.equal( # type: ignore[attr-defined] - out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined] - ) - assert torch.equal( # type: ignore[attr-defined] - out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined] - ) diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index 6ec7b533007..11cc3744cea 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import COWCCounting, COWCCountingDataModule, COWCDetection +from torchgeo.datasets import COWCCounting, COWCDetection from torchgeo.datasets.cowc import COWC @@ -148,25 +148,3 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): COWCDetection(str(tmp_path)) - - -class TestCOWCCountingDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> COWCCountingDataModule: - root = os.path.join("tests", "data", "cowc_counting") - seed = 0 - batch_size = 1 - num_workers = 0 - dm = COWCCountingDataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index 6955143a1fb..c9bb803c856 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import CycloneDataModule, TropicalCycloneWindEstimation +from torchgeo.datasets import TropicalCycloneWindEstimation class Dataset: @@ -103,25 +103,3 @@ def test_plot(self, dataset: TropicalCycloneWindEstimation) -> None: ) dataset.plot(sample) plt.close() - - -class TestCycloneDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> CycloneDataModule: - root = os.path.join("tests", "data", "cyclone") - seed = 0 - batch_size = 1 - num_workers = 0 - dm = CycloneDataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: CycloneDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: CycloneDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: CycloneDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index 0aaee918be8..89750dd0d1c 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -14,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import ETCI2021, ETCI2021DataModule +from torchgeo.datasets import ETCI2021 def download_url(url: str, root: str, *args: str) -> None: @@ -95,25 +95,3 @@ def test_plot(self, dataset: ETCI2021) -> None: x["prediction"] = x["mask"][0].clone() dataset.plot(x) plt.close() - - -class TestETCI2021DataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> ETCI2021DataModule: - root = os.path.join("tests", "data", "etci2021") - seed = 0 - batch_size = 2 - num_workers = 0 - dm = ETCI2021DataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 008195bb72a..a8b47ea2561 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import EuroSAT, EuroSATDataModule +from torchgeo.datasets import EuroSAT def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -100,24 +100,3 @@ def test_plot(self, dataset: EuroSAT) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() - - -class TestEuroSATDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> EuroSATDataModule: - root = os.path.join("tests", "data", "eurosat") - batch_size = 1 - num_workers = 0 - dm = EuroSATDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index 9b175c649a0..3f188ebb6e1 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -12,7 +12,7 @@ import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import FAIR1M, FAIR1MDataModule +from torchgeo.datasets import FAIR1M class TestFAIR1M: @@ -73,25 +73,3 @@ def test_plot(self, dataset: FAIR1M) -> None: x["prediction_boxes"] = x["boxes"].clone() dataset.plot(x) plt.close() - - -class TestFAIR1MDataModule: - @pytest.fixture(scope="class", params=[True, False]) - def datamodule(self) -> FAIR1MDataModule: - root = os.path.join("tests", "data", "fair1m") - batch_size = 2 - num_workers = 0 - dm = FAIR1MDataModule( - root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33 - ) - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 8a6942e5800..e971f88ae7f 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import LandCoverAI, LandCoverAIDataModule +from torchgeo.datasets import LandCoverAI def download_url(url: str, root: str, *args: str) -> None: @@ -78,24 +78,3 @@ def test_plot(self, dataset: LandCoverAI) -> None: x["prediction"] = x["mask"].clone() dataset.plot(x) plt.close() - - -class TestLandCoverAIDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> LandCoverAIDataModule: - root = os.path.join("tests", "data", "landcoverai") - batch_size = 2 - num_workers = 0 - dm = LandCoverAIDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py index 0bfca7bc6c2..e445ae9d3d4 100644 --- a/tests/datasets/test_loveda.py +++ b/tests/datasets/test_loveda.py @@ -14,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import LoveDA, LoveDADataModule +from torchgeo.datasets import LoveDA def download_url(url: str, root: str, *args: str) -> None: @@ -99,29 +99,3 @@ def test_not_downloaded(self, tmp_path: Path) -> None: def test_plot(self, dataset: LoveDA) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() - - -class TestLoveDADataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> LoveDADataModule: - root = os.path.join("tests", "data", "loveda") - batch_size = 2 - num_workers = 0 - scene = ["rural", "urban"] - - dm = LoveDADataModule( - root_dir=root, scene=scene, batch_size=batch_size, num_workers=num_workers - ) - - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py index 4ef17e7cfc7..2089d09ac45 100644 --- a/tests/datasets/test_naip.py +++ b/tests/datasets/test_naip.py @@ -12,13 +12,7 @@ from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS -from torchgeo.datasets import ( - NAIP, - BoundingBox, - IntersectionDataset, - NAIPChesapeakeDataModule, - UnionDataset, -) +from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset class TestNAIP: @@ -60,27 +54,3 @@ def test_invalid_query(self, dataset: NAIP) -> None: IndexError, match="query: .* not found in index with bounds:" ): dataset[query] - - -class TestNAIPChesapeakeDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> NAIPChesapeakeDataModule: - dm = NAIPChesapeakeDataModule( - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "chesapeake", "BAYWIDE"), - batch_size=2, - num_workers=0, - ) - dm.patch_size = 32 - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index a8b20c2cd29..deb8366ddfd 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -13,7 +13,7 @@ import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import NASAMarineDebris, NASAMarineDebrisDataModule +from torchgeo.datasets import NASAMarineDebris class Dataset: @@ -85,28 +85,3 @@ def test_plot(self, dataset: NASAMarineDebris) -> None: x["prediction_boxes"] = x["boxes"].clone() dataset.plot(x) plt.close() - - -class TestNASAMarineDebrisDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> NASAMarineDebrisDataModule: - root = os.path.join("tests", "data", "nasa_marine_debris") - batch_size = 2 - num_workers = 0 - val_split_pct = 0.3 - test_split_pct = 0.3 - dm = NASAMarineDebrisDataModule( - root, batch_size, num_workers, val_split_pct, test_split_pct - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index 2bfaf25d50d..be61d08b8e0 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -16,7 +16,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import OSCD, OSCDDataModule +from torchgeo.datasets import OSCD def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -105,56 +105,3 @@ def test_not_downloaded(self, tmp_path: Path) -> None: def test_plot(self, dataset: OSCD) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() - - -class TestOSCDDataModule: - @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) - def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands, val_split_pct = request.param - patch_size = (2, 2) - num_patches_per_tile = 2 - root = os.path.join("tests", "data", "oscd") - batch_size = 1 - num_workers = 0 - dm = OSCDDataModule( - root, - bands, - batch_size, - num_workers, - val_split_pct, - patch_size, - num_patches_per_tile, - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.train_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 - else: - assert sample["image"].shape[1] == 6 - - def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.val_dataloader())) - if datamodule.val_split_pct > 0.0: - assert ( - sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - ) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 - else: - assert sample["image"].shape[1] == 6 - - def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 - else: - assert sample["image"].shape[1] == 6 diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py index b11d0dc138e..6a298baf359 100644 --- a/tests/datasets/test_potsdam.py +++ b/tests/datasets/test_potsdam.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import Potsdam2D, Potsdam2DDataModule +from torchgeo.datasets import Potsdam2D class TestPotsdam2D: @@ -75,27 +75,3 @@ def test_plot(self, dataset: Potsdam2D) -> None: x["prediction"] = x["mask"].clone() dataset.plot(x) plt.close() - - -class TestPotsdam2DDataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> Potsdam2DDataModule: - root = os.path.join("tests", "data", "potsdam") - batch_size = 1 - num_workers = 0 - val_split_size = request.param - dm = Potsdam2DDataModule( - root, batch_size, num_workers, val_split_pct=val_split_size - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index 75ed6ee2d58..c8f4d9157d1 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import RESISC45, RESISC45DataModule +from torchgeo.datasets import RESISC45 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -101,24 +101,3 @@ def test_plot(self, dataset: RESISC45) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() - - -class TestRESISC45DataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> RESISC45DataModule: - root = os.path.join("tests", "data", "resisc45") - batch_size = 2 - num_workers = 0 - dm = RESISC45DataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py index 2332e70f39d..94c7964340d 100644 --- a/tests/datasets/test_sen12ms.py +++ b/tests/datasets/test_sen12ms.py @@ -12,7 +12,7 @@ from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import SEN12MS, SEN12MSDataModule +from torchgeo.datasets import SEN12MS class TestSEN12MS: @@ -82,26 +82,3 @@ def test_band_subsets(self) -> None: ds = SEN12MS(root, bands=bands, checksum=False) x = ds[0]["image"] assert x.shape[0] == len(bands) - - -class TestSEN12MSDataModule: - @pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"]) - def datamodule(self, request: SubRequest) -> SEN12MSDataModule: - root = os.path.join("tests", "data", "sen12ms") - seed = 0 - bands = request.param - batch_size = 1 - num_workers = 0 - dm = SEN12MSDataModule(root, seed, bands, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index ab4085ba5e8..7df9fe5dd2b 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import So2Sat, So2SatDataModule +from torchgeo.datasets import So2Sat pytest.importorskip("h5py") @@ -91,25 +91,3 @@ def test_mock_missing_module( match="h5py is not installed and is required to use this dataset", ): So2Sat(dataset.root) - - -class TestSo2SatDataModule: - @pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"])) - def datamodule(self, request: SubRequest) -> So2SatDataModule: - unsupervised_mode, bands = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule(root, batch_size, num_workers, bands, unsupervised_mode) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: So2SatDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: So2SatDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: So2SatDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index ad6efb6628b..600c2595d4d 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import UCMerced, UCMercedDataModule +from torchgeo.datasets import UCMerced def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -102,24 +102,3 @@ def test_plot(self, dataset: UCMerced) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() - - -class TestUCMercedDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> UCMercedDataModule: - root = os.path.join("tests", "data", "ucmerced") - batch_size = 2 - num_workers = 0 - dm = UCMercedDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 9e8732deac4..631897bb7f2 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -18,13 +18,11 @@ import torch from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS -from torch.utils.data import TensorDataset import torchgeo.datasets.utils from torchgeo.datasets.utils import ( BoundingBox, concat_samples, - dataset_split, disambiguate_timestamp, download_and_extract_archive, download_radiant_mlhub_collection, @@ -563,24 +561,6 @@ def test_nonexisting_directory(tmp_path: Path) -> None: assert subdir.cwd() == subdir -def test_dataset_split() -> None: - num_samples = 24 - x = torch.ones(num_samples, 5) # type: ignore[attr-defined] - y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined] - ds = TensorDataset(x, y) - - # Test only train/val set split - train_ds, val_ds = dataset_split(ds, val_pct=1 / 2) - assert len(train_ds) == num_samples // 2 - assert len(val_ds) == num_samples // 2 - - # Test train/val/test set split - train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3) - assert len(train_ds) == num_samples // 3 - assert len(val_ds) == num_samples // 3 - assert len(test_ds) == num_samples // 3 - - def test_percentile_normalization() -> None: img = np.array([[1, 2], [98, 100]]) diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py index 531dd24e592..033017ea0ee 100644 --- a/tests/datasets/test_vaihingen.py +++ b/tests/datasets/test_vaihingen.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import Vaihingen2D, Vaihingen2DDataModule +from torchgeo.datasets import Vaihingen2D class TestVaihingen2D: @@ -84,27 +84,3 @@ def test_plot(self, dataset: Vaihingen2D) -> None: x["prediction"] = x["mask"].clone() dataset.plot(x) plt.close() - - -class TestVaihingen2DDataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule: - root = os.path.join("tests", "data", "vaihingen") - batch_size = 1 - num_workers = 0 - val_split_size = request.param - dm = Vaihingen2DDataModule( - root, batch_size, num_workers, val_split_pct=val_split_size - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index b358337617c..92e00f4c7fd 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import XView2, XView2DataModule +from torchgeo.datasets import XView2 class TestXView2: @@ -95,27 +95,3 @@ def test_plot(self, dataset: XView2) -> None: x["prediction"] = x["mask"][0].clone() dataset.plot(x) plt.close() - - -class TestXView2DataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> XView2DataModule: - root = os.path.join("tests", "data", "xview2") - batch_size = 1 - num_workers = 0 - val_split_size = request.param - dm = XView2DataModule( - root, batch_size, num_workers, val_split_pct=val_split_size - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: XView2DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: XView2DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: XView2DataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index ac5e9e2b792..304d4a5bf17 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -12,7 +12,7 @@ from pytorch_lightning.core.lightning import LightningModule from torchvision.models import resnet18 -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation diff --git a/tests/trainers/test_chesapeake.py b/tests/trainers/test_chesapeake.py index a9c95907dbc..920936802a7 100644 --- a/tests/trainers/test_chesapeake.py +++ b/tests/trainers/test_chesapeake.py @@ -9,7 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py index b14d5b46684..d3e70dfb098 100644 --- a/tests/trainers/test_landcoverai.py +++ b/tests/trainers/test_landcoverai.py @@ -8,7 +8,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import LandCoverAIDataModule +from torchgeo.datamodules import LandCoverAIDataModule from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_naipchesapeake.py b/tests/trainers/test_naipchesapeake.py index 3b8cce5aca0..37d94cb0ed8 100644 --- a/tests/trainers/test_naipchesapeake.py +++ b/tests/trainers/test_naipchesapeake.py @@ -8,7 +8,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import NAIPChesapeakeDataModule +from torchgeo.datamodules import NAIPChesapeakeDataModule from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index cfa7e16924b..ed3af3a3b63 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -8,7 +8,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import CycloneDataModule +from torchgeo.datamodules import CycloneDataModule from torchgeo.trainers import RegressionTask from .test_utils import mocked_log diff --git a/tests/trainers/test_resisc45.py b/tests/trainers/test_resisc45.py index 0b832295faf..1eec36e2fee 100644 --- a/tests/trainers/test_resisc45.py +++ b/tests/trainers/test_resisc45.py @@ -7,7 +7,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import RESISC45DataModule +from torchgeo.datamodules import RESISC45DataModule from torchgeo.trainers.resisc45 import RESISC45ClassificationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 058f94170b7..658e4870f57 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -9,7 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers import SemanticSegmentationTask from .test_utils import FakeTrainer, mocked_log diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py index cc1b2abb965..a447a5f873f 100644 --- a/torchgeo/__init__.py +++ b/torchgeo/__init__.py @@ -13,7 +13,7 @@ import pytorch_lightning as pl -from .datasets import ( +from .datamodules import ( BigEarthNetDataModule, ChesapeakeCVPRDataModule, COWCCountingDataModule, diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py new file mode 100644 index 00000000000..e09fe0ab378 --- /dev/null +++ b/torchgeo/datamodules/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TorchGeo datamodules.""" + +from .bigearthnet import BigEarthNetDataModule +from .chesapeake import ChesapeakeCVPRDataModule +from .cowc import COWCCountingDataModule +from .cyclone import CycloneDataModule +from .etci2021 import ETCI2021DataModule +from .eurosat import EuroSATDataModule +from .fair1m import FAIR1MDataModule +from .landcoverai import LandCoverAIDataModule +from .loveda import LoveDADataModule +from .naip import NAIPChesapeakeDataModule +from .nasa_marine_debris import NASAMarineDebrisDataModule +from .oscd import OSCDDataModule +from .potsdam import Potsdam2DDataModule +from .resisc45 import RESISC45DataModule +from .sen12ms import SEN12MSDataModule +from .so2sat import So2SatDataModule +from .ucmerced import UCMercedDataModule +from .vaihingen import Vaihingen2DDataModule +from .xview import XView2DataModule + +__all__ = ( + # GeoDataset + "ChesapeakeCVPRDataModule", + "NAIPChesapeakeDataModule", + # VisionDataset + "BigEarthNetDataModule", + "COWCCountingDataModule", + "ETCI2021DataModule", + "EuroSATDataModule", + "FAIR1MDataModule", + "LandCoverAIDataModule", + "LoveDADataModule", + "NASAMarineDebrisDataModule", + "OSCDDataModule", + "Potsdam2DDataModule", + "RESISC45DataModule", + "SEN12MSDataModule", + "So2SatDataModule", + "CycloneDataModule", + "UCMercedDataModule", + "Vaihingen2DDataModule", + "XView2DataModule", +) + +# https://stackoverflow.com/questions/40018681 +for module in __all__: + globals()[module].__module__ = "torchgeo.datamodules" diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py new file mode 100644 index 00000000000..11c2e4ed9ab --- /dev/null +++ b/torchgeo/datamodules/bigearthnet.py @@ -0,0 +1,178 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""BigEarthNet datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import BigEarthNet + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class BigEarthNetDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the BigEarthNet dataset. + + Uses the train/val/test splits from the dataset. + """ + + # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) + # min/max band statistics computed on 100k random samples + band_mins_raw = torch.tensor( # type: ignore[attr-defined] + [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] + ) + band_maxs_raw = torch.tensor( # type: ignore[attr-defined] + [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18976.0, + 17874.0, + 16611.0, + 16512.0, + 16394.0, + 16672.0, + 16141.0, + 16097.0, + 15336.0, + 15203.0, + ] + ) + + # min/max band statistics computed by percentile clipping the + # above to samples to [2, 98] + band_mins = torch.tensor( # type: ignore[attr-defined] + [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + band_maxs = torch.tensor( # type: ignore[attr-defined] + [ + 6.0, + 16.0, + 9859.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, + ] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + num_classes: int = 19, + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes + bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} + num_classes: number of classes to load in target. one of {19, 43} + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.num_classes = num_classes + self.batch_size = batch_size + self.num_workers = num_workers + + if bands == "all": + self.mins = self.band_mins[:, None, None] + self.maxs = self.band_maxs[:, None, None] + elif bands == "s1": + self.mins = self.band_mins[:2, None, None] + self.maxs = self.band_maxs[:2, None, None] + else: + self.mins = self.band_mins[2:, None, None] + self.maxs = self.band_maxs[2:, None, None] + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) + sample["image"] = torch.clip( # type: ignore[attr-defined] + sample["image"], min=0.0, max=1.0 + ) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + transforms = Compose([self.preprocess]) + self.train_dataset = BigEarthNet( + self.root_dir, + split="train", + bands=self.bands, + num_classes=self.num_classes, + transforms=transforms, + ) + self.val_dataset = BigEarthNet( + self.root_dir, + split="val", + bands=self.bands, + num_classes=self.num_classes, + transforms=transforms, + ) + self.test_dataset = BigEarthNet( + self.root_dir, + split="test", + bands=self.bands, + num_classes=self.num_classes, + transforms=transforms, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py new file mode 100644 index 00000000000..225f4c31169 --- /dev/null +++ b/torchgeo/datamodules/chesapeake.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Chesapeake Bay High-Resolution Land Cover Project datamodule.""" + +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from pytorch_lightning.core.datamodule import LightningDataModule +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import ChesapeakeCVPR, stack_samples +from ..samplers.batch import RandomBatchGeoSampler +from ..samplers.single import GridGeoSampler + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class ChesapeakeCVPRDataModule(LightningDataModule): + """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. + + Uses the random splits defined per state to partition tiles into train, val, + and test sets. + """ + + def __init__( + self, + root_dir: str, + train_splits: List[str], + val_splits: List[str], + test_splits: List[str], + patches_per_tile: int = 200, + patch_size: int = 256, + batch_size: int = 64, + num_workers: int = 0, + class_set: int = 7, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the ChesapeakeCVPR Dataset + classes + train_splits: The splits used to train the model, e.g. ["ny-train"] + val_splits: The splits used to validate the model, e.g. ["ny-val"] + test_splits: The splits used to test the model, e.g. ["ny-test"] + patches_per_tile: The number of patches per tile to sample + patch_size: The size of each patch in pixels (test patches will be 1.5 times + this size) + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + class_set: The high-resolution land cover class set to use - 5 or 7 + """ + super().__init__() # type: ignore[no-untyped-call] + for state in train_splits + val_splits + test_splits: + assert state in ChesapeakeCVPR.splits + assert class_set in [5, 7] + + self.root_dir = root_dir + self.train_splits = train_splits + self.val_splits = val_splits + self.test_splits = test_splits + self.layers = ["naip-new", "lc"] + self.patches_per_tile = patches_per_tile + self.patch_size = patch_size + # This is a rough estimate of how large of a patch we will need to sample in + # EPSG:3857 in order to guarantee a large enough patch in the local CRS. + self.original_patch_size = int(patch_size * 2.0) + self.batch_size = batch_size + self.num_workers = num_workers + self.class_set = class_set + + def pad_to( + self, size: int = 512, image_value: int = 0, mask_value: int = 0 + ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: + """Returns a function to perform a padding transform on a single sample. + + Args: + size: output image size + image_value: value to pad image with + mask_value: value to pad mask with + + Returns: + function to perform padding + """ + + def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + _, height, width = sample["image"].shape + assert height <= size and width <= size + + height_pad = size - height + width_pad = size - width + + # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + # for a description of the format of the padding tuple + sample["image"] = F.pad( + sample["image"], + (0, width_pad, 0, height_pad), + mode="constant", + value=image_value, + ) + sample["mask"] = F.pad( + sample["mask"], + (0, width_pad, 0, height_pad), + mode="constant", + value=mask_value, + ) + return sample + + return pad_inner + + def center_crop( + self, size: int = 512 + ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: + """Returns a function to perform a center crop transform on a single sample. + + Args: + size: output image size + + Returns: + function to perform center crop + """ + + def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + _, height, width = sample["image"].shape + + y1 = (height - size) // 2 + x1 = (width - size) // 2 + sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] + sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] + + return sample + + return center_crop_inner + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Preprocesses a single sample. + + Args: + sample: sample dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 + sample["mask"] = sample["mask"] + sample["mask"] = sample["mask"].squeeze() + + if self.class_set == 5: + sample["mask"][sample["mask"] == 5] = 4 + sample["mask"][sample["mask"] == 6] = 4 + + sample["image"] = sample["image"].float() + sample["mask"] = sample["mask"].long() + + return sample + + def nodata_check( + self, size: int = 512 + ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: + """Returns a function to check for nodata or mis-sized input. + + Args: + size: output image size + + Returns: + function to check for nodata values + """ + + def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + num_channels, height, width = sample["image"].shape + + if height < size or width < size: + sample["image"] = torch.zeros( # type: ignore[attr-defined] + (num_channels, size, size) + ) + sample["mask"] = torch.zeros((size, size)) # type: ignore[attr-defined] + + return sample + + return nodata_check_inner + + def prepare_data(self) -> None: + """Confirms that the dataset is downloaded on the local node. + + This method is called once per node, while :func:`setup` is called once per GPU. + """ + ChesapeakeCVPR( + self.root_dir, + splits=self.train_splits, + layers=self.layers, + transforms=None, + download=False, + checksum=False, + ) + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + Args: + stage: stage to set up + """ + train_transforms = Compose( + [ + self.center_crop(self.patch_size), + self.nodata_check(self.patch_size), + self.preprocess, + ] + ) + val_transforms = Compose( + [ + self.center_crop(self.patch_size), + self.nodata_check(self.patch_size), + self.preprocess, + ] + ) + test_transforms = Compose( + [ + self.pad_to(self.original_patch_size, image_value=0, mask_value=0), + self.preprocess, + ] + ) + + self.train_dataset = ChesapeakeCVPR( + self.root_dir, + splits=self.train_splits, + layers=self.layers, + transforms=train_transforms, + download=False, + checksum=False, + ) + self.val_dataset = ChesapeakeCVPR( + self.root_dir, + splits=self.val_splits, + layers=self.layers, + transforms=val_transforms, + download=False, + checksum=False, + ) + self.test_dataset = ChesapeakeCVPR( + self.root_dir, + splits=self.test_splits, + layers=self.layers, + transforms=test_transforms, + download=False, + checksum=False, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + sampler = RandomBatchGeoSampler( + self.train_dataset, + size=self.original_patch_size, + batch_size=self.batch_size, + length=self.patches_per_tile * len(self.train_dataset), + ) + return DataLoader( + self.train_dataset, + batch_sampler=sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + sampler = GridGeoSampler( + self.val_dataset, + size=self.original_patch_size, + stride=self.original_patch_size, + ) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + sampler = GridGeoSampler( + self.test_dataset, + size=self.original_patch_size, + stride=self.original_patch_size, + ) + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py new file mode 100644 index 00000000000..4d6e4a7cdb8 --- /dev/null +++ b/torchgeo/datamodules/cowc.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""COWC datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch import Generator # type: ignore[attr-defined] +from torch.utils.data import DataLoader, random_split + +from ..datasets import COWCCounting + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class COWCCountingDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the COWC Counting dataset.""" + + def __init__( + self, + root_dir: str, + seed: int, + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for COWC Counting based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the COWCCounting Dataset class + seed: The seed value to use when doing the dataset random_split + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.seed = seed + self.batch_size = batch_size + self.num_workers = num_workers + + def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and target + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 # scale to [0, 1] + sample["label"] = sample["label"].float() + return sample + + def prepare_data(self) -> None: + """Initialize the main ``Dataset`` objects for use in :func:`setup`. + + This includes optionally downloading the dataset. This is done once per node, + while :func:`setup` is done once per GPU. + """ + COWCCounting(self.root_dir, download=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + Args: + stage: stage to set up + """ + train_val_dataset = COWCCounting( + self.root_dir, split="train", transforms=self.custom_transform + ) + self.test_dataset = COWCCounting( + self.root_dir, split="test", transforms=self.custom_transform + ) + self.train_dataset, self.val_dataset = random_split( + train_val_dataset, + [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], + generator=Generator().manual_seed(self.seed), + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py new file mode 100644 index 00000000000..929628e7c37 --- /dev/null +++ b/torchgeo/datamodules/cyclone.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Tropical Cyclone Wind Estimation Competition datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from sklearn.model_selection import GroupShuffleSplit +from torch.utils.data import DataLoader, Subset + +from ..datasets import TropicalCycloneWindEstimation + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class CycloneDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the NASA Cyclone dataset. + + Implements 80/20 train/val splits based on hurricane storm ids. + See :func:`setup` for more details. + """ + + def __init__( + self, + root_dir: str, + seed: int, + batch_size: int = 64, + num_workers: int = 0, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the + TropicalCycloneWindEstimation Datasets classes + seed: The seed value to use when doing the sklearn based GroupShuffleSplit + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + api_key: The RadiantEarth MLHub API key to use if the dataset needs to be + downloaded + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.seed = seed + self.batch_size = batch_size + self.num_workers = num_workers + self.api_key = api_key + + def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and target + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 # scale to [0,1] + sample["image"] = ( + sample["image"].unsqueeze(0).repeat(3, 1, 1) + ) # convert to 3 channel + sample["label"] = torch.as_tensor( # type: ignore[attr-defined] + sample["label"] + ).float() + + return sample + + def prepare_data(self) -> None: + """Initialize the main ``Dataset`` objects for use in :func:`setup`. + + This includes optionally downloading the dataset. This is done once per node, + while :func:`setup` is done once per GPU. + """ + TropicalCycloneWindEstimation( + self.root_dir, + split="train", + transforms=self.custom_transform, + download=self.api_key is not None, + api_key=self.api_key, + ) + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + We split samples between train/val by the ``storm_id`` property. I.e. all + samples with the same ``storm_id`` value will be either in the train or the val + split. This is important to test one type of generalizability -- given a new + storm, can we predict its windspeed. The test set, however, contains *some* + storms from the training set (specifically, the latter parts of the storms) as + well as some novel storms. + + Args: + stage: stage to set up + """ + self.all_train_dataset = TropicalCycloneWindEstimation( + self.root_dir, + split="train", + transforms=self.custom_transform, + download=False, + ) + + self.all_test_dataset = TropicalCycloneWindEstimation( + self.root_dir, + split="test", + transforms=self.custom_transform, + download=False, + ) + + storm_ids = [] + for item in self.all_train_dataset.collection: + storm_id = item["href"].split("/")[0].split("_")[-2] + storm_ids.append(storm_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( + storm_ids, groups=storm_ids + ) + ) + + self.train_dataset = Subset(self.all_train_dataset, train_indices) + self.val_dataset = Subset(self.all_train_dataset, val_indices) + self.test_dataset = Subset( + self.all_test_dataset, range(len(self.all_test_dataset)) + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py new file mode 100644 index 00000000000..5db89a07379 --- /dev/null +++ b/torchgeo/datamodules/etci2021.py @@ -0,0 +1,151 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""ETCI 2021 datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch import Generator # type: ignore[attr-defined] +from torch.utils.data import DataLoader, random_split +from torchvision.transforms import Normalize + +from ..datasets import ETCI2021 + + +class ETCI2021DataModule(pl.LightningDataModule): + """LightningDataModule implementation for the ETCI2021 dataset. + + Splits the existing train split from the dataset into train/val with 80/20 + proportions, then uses the existing val dataset as the test data. + + .. versionadded:: 0.2 + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1] + ) + + def __init__( + self, + root_dir: str, + seed: int = 0, + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for ETCI2021 based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the ETCI2021 Dataset classes + seed: The seed value to use when doing the dataset random_split + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.seed = seed + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Notably, moves the given water mask to act as an input layer. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + image = sample["image"] + water_mask = sample["mask"][0].unsqueeze(0) + flood_mask = sample["mask"][1] + flood_mask = (flood_mask > 0).long() + + sample["image"] = torch.cat( # type: ignore[attr-defined] + [image, water_mask], dim=0 + ).float() + sample["image"] /= 255.0 + sample["image"] = self.norm(sample["image"]) + sample["mask"] = flood_mask + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + ETCI2021(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_val_dataset = ETCI2021( + self.root_dir, split="train", transforms=self.preprocess + ) + self.test_dataset = ETCI2021( + self.root_dir, split="val", transforms=self.preprocess + ) + + size_train_val = len(train_val_dataset) + size_train = int(0.8 * size_train_val) + size_val = size_train_val - size_train + + self.train_dataset, self.val_dataset = random_split( + train_val_dataset, + [size_train, size_val], + generator=Generator().manual_seed(self.seed), + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py new file mode 100644 index 00000000000..72708e07019 --- /dev/null +++ b/torchgeo/datamodules/eurosat.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""EuroSAT datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import EuroSAT + + +class EuroSATDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the EuroSAT dataset. + + Uses the train/val/test splits from the dataset. + + .. versionadded:: 0.2 + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [ + 1354.40546513, + 1118.24399958, + 1042.92983953, + 947.62620298, + 1199.47283961, + 1999.79090914, + 2369.22292565, + 2296.82608323, + 732.08340178, + 12.11327804, + 1819.01027855, + 1118.92391149, + 2594.14080798, + ] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [ + 245.71762908, + 333.00778264, + 395.09249139, + 593.75055589, + 566.4170017, + 861.18399006, + 1086.63139075, + 1117.98170791, + 404.91978886, + 4.77584468, + 1002.58768311, + 761.30323499, + 1231.58581042, + ] + ) + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for EuroSAT based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the EuroSAT Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + EuroSAT(self.root_dir) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms) + self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms) + self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py new file mode 100644 index 00000000000..15a8cbfca52 --- /dev/null +++ b/torchgeo/datamodules/fair1m.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""FAIR1M datamodule.""" + +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +import torch +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import FAIR1M +from .utils import dataset_split + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: + """Custom object detection collate fn to handle variable number of boxes. + + Args: + batch: list of sample dicts return by dataset + Returns: + batch dict output + """ + output: Dict[str, Any] = {} + output["image"] = torch.stack([sample["image"] for sample in batch]) + output["boxes"] = [sample["boxes"] for sample in batch] + return output + + +class FAIR1MDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the FAIR1M dataset.""" + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for FAIR1M based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the FAIR1M Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = FAIR1M(self.root_dir, transforms=transforms) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py new file mode 100644 index 00000000000..95256dffe5a --- /dev/null +++ b/torchgeo/datamodules/landcoverai.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""LandCover.ai datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + +from ..datasets import LandCoverAI + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class LandCoverAIDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the LandCover.ai dataset. + + Uses the train/val/test splits from the dataset. + """ + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for LandCover.ai based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 + + sample["image"] = sample["image"].float() + sample["mask"] = sample["mask"].float().unsqueeze(0) + 1 + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + _ = LandCoverAI(self.root_dir, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_transforms = self.preprocess + val_test_transforms = self.preprocess + + self.train_dataset = LandCoverAI( + self.root_dir, split="train", transforms=train_transforms + ) + + self.val_dataset = LandCoverAI( + self.root_dir, split="val", transforms=val_test_transforms + ) + + self.test_dataset = LandCoverAI( + self.root_dir, split="test", transforms=val_test_transforms + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py new file mode 100644 index 00000000000..4aeae5323b6 --- /dev/null +++ b/torchgeo/datamodules/loveda.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""LoveDA datamodule.""" + +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + +from ..datasets import LoveDA + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class LoveDADataModule(pl.LightningDataModule): + """LightningDataModule implementation for the LoveDA dataset. + + Uses the train/val/test splits from the dataset. + """ + + def __init__( + self, + root_dir: str, + scene: List[str], + batch_size: int = 32, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for LoveDA based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to LoveDA Dataset classes + scene: specify whether to load only 'urban', only 'rural' or both + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.scene = scene + self.batch_size = batch_size + self.num_workers = num_workers + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + _ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_transforms = self.preprocess + val_test_transforms = self.preprocess + + self.train_dataset = LoveDA( + self.root_dir, split="train", scene=self.scene, transforms=train_transforms + ) + + self.val_dataset = LoveDA( + self.root_dir, split="val", scene=self.scene, transforms=val_test_transforms + ) + + self.test_dataset = LoveDA( + self.root_dir, + split="test", + scene=self.scene, + transforms=val_test_transforms, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py new file mode 100644 index 00000000000..b00d142edee --- /dev/null +++ b/torchgeo/datamodules/naip.py @@ -0,0 +1,161 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""National Agriculture Imagery Program (NAIP) datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + +from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples +from ..samplers.batch import RandomBatchGeoSampler +from ..samplers.single import GridGeoSampler + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class NAIPChesapeakeDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the NAIP and Chesapeake datasets. + + Uses the train/val/test splits from the dataset. + """ + + # TODO: tune these hyperparams + length = 1000 + stride = 128 + + def __init__( + self, + naip_root_dir: str, + chesapeake_root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + patch_size: int = 256, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. + + Args: + naip_root_dir: directory containing NAIP data + chesapeake_root_dir: directory containing Chesapeake data + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + patch_size: size of patches to sample + """ + super().__init__() # type: ignore[no-untyped-call] + self.naip_root_dir = naip_root_dir + self.chesapeake_root_dir = chesapeake_root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.patch_size = patch_size + + def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the NAIP Dataset. + + Args: + sample: NAIP image dictionary + + Returns: + preprocessed NAIP data + """ + sample["image"] = sample["image"] / 255.0 + sample["image"] = sample["image"].float() + return sample + + def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Chesapeake Dataset. + + Args: + sample: Chesapeake mask dictionary + + Returns: + preprocessed Chesapeake data + """ + sample["mask"] = sample["mask"].long()[0] + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: state to set up + """ + # TODO: these transforms will be applied independently, this won't work if we + # add things like random horizontal flip + chesapeake = Chesapeake13( + self.chesapeake_root_dir, transforms=self.chesapeake_transform + ) + naip = NAIP( + self.naip_root_dir, + chesapeake.crs, + chesapeake.res, + transforms=self.naip_transform, + ) + self.dataset = chesapeake & naip + + # TODO: figure out better train/val/test split + roi = self.dataset.bounds + midx = roi.minx + (roi.maxx - roi.minx) / 2 + midy = roi.miny + (roi.maxy - roi.miny) / 2 + train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) + val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) + test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) + + self.train_sampler = RandomBatchGeoSampler( + naip, self.patch_size, self.batch_size, self.length, train_roi + ) + self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi) + self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.dataset, + batch_sampler=self.train_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.val_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.test_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py new file mode 100644 index 00000000000..e6337e9fb6a --- /dev/null +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NASA Marine Debris datamodule.""" + +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +import torch +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import NASAMarineDebris +from .utils import dataset_split + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: + """Custom object detection collate fn to handle variable boxes. + + Args: + batch: list of sample dicts return by dataset + + Returns: + batch dict output + """ + output: Dict[str, Any] = {} + output["image"] = torch.stack([sample["image"] for sample in batch]) + output["boxes"] = [sample["boxes"] for sample in batch] + return output + + +class NASAMarineDebrisDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the NASA Marine Debris dataset.""" + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the Dataset class + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + NASAMarineDebris(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = NASAMarineDebris(self.root_dir, transforms=transforms) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py new file mode 100644 index 00000000000..f77f95310f7 --- /dev/null +++ b/torchgeo/datamodules/oscd.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""OSCD datamodule.""" + +from typing import Any, Dict, List, Optional, Tuple + +import kornia.augmentation as K +import pytorch_lightning as pl +import torch +from einops import repeat +from torch.utils.data import DataLoader, Dataset +from torch.utils.data._utils.collate import default_collate +from torchvision.transforms import Compose, Normalize + +from ..datasets import OSCD +from .utils import dataset_split + + +class OSCDDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the OSCD dataset. + + Uses the train/test splits from the dataset and further splits + the train split into train/val splits. + + .. versionadded: 0.2 + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [ + 1583.0741, + 1374.3202, + 1294.1616, + 1325.6158, + 1478.7408, + 1933.0822, + 2166.0608, + 2076.4868, + 2306.0652, + 690.9814, + 16.2360, + 2080.3347, + 1524.6930, + ] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [ + 52.1937, + 83.4168, + 105.6966, + 151.1401, + 147.4615, + 115.9289, + 123.1974, + 114.6483, + 141.4530, + 73.2758, + 4.8368, + 213.4821, + 179.4793, + ] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + train_batch_size: int = 32, + num_workers: int = 0, + val_split_pct: float = 0.2, + patch_size: Tuple[int, int] = (64, 64), + num_patches_per_tile: int = 32, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for OSCD based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the OSCD Dataset classes + bands: "rgb" or "all" + train_batch_size: The batch size used in the train DataLoader + (val_batch_size == test_batch_size == 1) + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + patch_size: Size of random patch from image and mask (height, width) + num_patches_per_tile: number of random patches per sample + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.train_batch_size = train_batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.patch_size = patch_size + self.num_patches_per_tile = num_patches_per_tile + + if bands == "rgb": + self.band_means = self.band_means[[3, 2, 1], None, None] + self.band_stds = self.band_stds[[3, 2, 1], None, None] + else: + self.band_means = self.band_means[:, None, None] + self.band_stds = self.band_stds[:, None, None] + + self.norm = Normalize(self.band_means, self.band_stds) + self.rcrop = K.AugmentationSequential( + K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True + ) + self.padto = K.PadTo((1280, 1280)) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + sample["mask"] = sample["mask"] + sample["image"] = self.norm(sample["image"]) + sample["image"] = torch.flatten( # type: ignore[attr-defined] + sample["image"], 0, 1 + ) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + OSCD(self.root_dir, split="train", bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + + def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: + images, masks = [], [] + for i in range(self.num_patches_per_tile): + mask = repeat(sample["mask"], "h w -> t h w", t=2).float() + image, mask = self.rcrop(sample["image"], mask) + mask = mask.squeeze()[0] + images.append(image.squeeze()) + masks.append(mask.long()) + sample["image"] = torch.stack(images) + sample["mask"] = torch.stack(masks) + return sample + + def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: + sample["image"] = self.padto(sample["image"])[0] + sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] + return sample + + train_transforms = Compose([self.preprocess, n_random_crop]) + # for testing and validation we pad all inputs to a fixed size to avoid issues + # with the upsampling paths in encoder-decoder architectures + test_transforms = Compose([self.preprocess, pad_to]) + + train_dataset = OSCD( + self.root_dir, split="train", bands=self.bands, transforms=train_transforms + ) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + val_dataset = OSCD( + self.root_dir, + split="train", + bands=self.bands, + transforms=test_transforms, + ) + self.train_dataset, self.val_dataset, _ = dataset_split( + train_dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + self.val_dataset.dataset = val_dataset + else: + self.train_dataset = train_dataset + self.val_dataset = train_dataset + + self.test_dataset = OSCD( + self.root_dir, split="test", bands=self.bands, transforms=test_transforms + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + + def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call] + batch + ) + r_batch["image"] = torch.flatten( # type: ignore[attr-defined] + r_batch["image"], 0, 1 + ) + r_batch["mask"] = torch.flatten( # type: ignore[attr-defined] + r_batch["mask"], 0, 1 + ) + return r_batch + + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + num_workers=self.num_workers, + collate_fn=collate_wrapper, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + return DataLoader( + self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + return DataLoader( + self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + ) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py new file mode 100644 index 00000000000..0ddbd2dba9b --- /dev/null +++ b/torchgeo/datamodules/potsdam.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Potsdam datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import Potsdam2D +from .utils import dataset_split + + +class Potsdam2DDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the Potsdam2D dataset. + + Uses the train/test splits from the dataset. + + .. versionadded: 0.2 + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for Potsdam2D based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = Potsdam2D(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py new file mode 100644 index 00000000000..844ee0968a9 --- /dev/null +++ b/torchgeo/datamodules/resisc45.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""RESISC45 datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import RESISC45 + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class RESISC45DataModule(pl.LightningDataModule): + """LightningDataModule implementation for the RESISC45 dataset. + + Uses the train/val/test splits from the dataset. + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [0.36801773, 0.38097873, 0.343583] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [0.14540215, 0.13558227, 0.13203649] + ) + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for RESISC45 based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the RESISC45 Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + RESISC45(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + self.train_dataset = RESISC45(self.root_dir, "train", transforms=transforms) + self.val_dataset = RESISC45(self.root_dir, "val", transforms=transforms) + self.test_dataset = RESISC45(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py new file mode 100644 index 00000000000..cfe5900c478 --- /dev/null +++ b/torchgeo/datamodules/sen12ms.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SEN12MS datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from sklearn.model_selection import GroupShuffleSplit +from torch.utils.data import DataLoader, Subset + +from ..datasets import SEN12MS + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class SEN12MSDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the SEN12MS dataset. + + Implements 80/20 geographic train/val splits and uses the test split from the + classification dataset definitions. See :func:`setup` for more details. + + Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See + https://arxiv.org/abs/2002.08254. + """ + + #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader + #: here https://github.com/lukasliebel/dfc2020_baseline. + DFC2020_CLASS_MAPPING = torch.tensor( # type: ignore[attr-defined] + [ + 0, # maps 0s to 0 + 1, # maps 1s to 1 + 1, # maps 2s to 1 + 1, # ... + 1, + 1, + 2, + 2, + 3, + 3, + 4, + 5, + 6, + 7, + 6, + 8, + 9, + 10, + ] + ) + + def __init__( + self, + root_dir: str, + seed: int, + band_set: str = "all", + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for SEN12MS based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the SEN12MS Dataset classes + seed: The seed value to use when doing the sklearn based ShuffleSplit + band_set: The subset of S1/S2 bands to use. Options are: "all", + "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: + B2, B3, B4, B8, B11, and B12. + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + assert band_set in SEN12MS.BAND_SETS.keys() + + self.root_dir = root_dir + self.seed = seed + self.band_set = band_set + self.band_indices = SEN12MS.BAND_SETS[band_set] + self.batch_size = batch_size + self.num_workers = num_workers + + def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + + if self.band_set == "all": + sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 + sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 + elif self.band_set == "s1": + sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 + else: + sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 + + sample["mask"] = sample["mask"][0, :, :].long() + sample["mask"] = torch.take( # type: ignore[attr-defined] + self.DFC2020_CLASS_MAPPING, sample["mask"] + ) + + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + We split samples between train and val geographically with proportions of 80/20. + This mimics the geographic test set split. + + Args: + stage: stage to set up + """ + season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} + + self.all_train_dataset = SEN12MS( + self.root_dir, + split="train", + bands=self.band_indices, + transforms=self.custom_transform, + checksum=False, + ) + + self.all_test_dataset = SEN12MS( + self.root_dir, + split="test", + bands=self.band_indices, + transforms=self.custom_transform, + checksum=False, + ) + + # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" + # This patch will belong to the scene that is uniquelly identified by its + # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply + # give each season a large number and representing a `unique_scene_id` as + # `season_id + scene_id`. + scenes = [] + for scene_fn in self.all_train_dataset.ids: + parts = scene_fn.split("_") + season_id = season_to_int[parts[1]] + scene_id = int(parts[3]) + scenes.append(season_id + scene_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( + scenes, groups=scenes + ) + ) + + self.train_dataset = Subset(self.all_train_dataset, train_indices) + self.val_dataset = Subset(self.all_train_dataset, val_indices) + self.test_dataset = Subset( + self.all_test_dataset, range(len(self.all_test_dataset)) + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py new file mode 100644 index 00000000000..9f072edbf43 --- /dev/null +++ b/torchgeo/datamodules/so2sat.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""So2Sat datamodule.""" + +from typing import Any, Dict, Optional, cast + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import So2Sat + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class So2SatDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the So2Sat dataset. + + Uses the train/val/test splits from the dataset. + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [ + -3.591224256609313e-05, + -7.658561276843396e-06, + 5.9373857475971184e-05, + 2.5166231537121083e-05, + 0.04420110659759328, + 0.25761027084996196, + 0.0007556743372573258, + 0.0013503466830024448, + 0.12375696117681859, + 0.1092774636368323, + 0.1010855203267882, + 0.1142398616114001, + 0.1592656692023089, + 0.18147236008771792, + 0.1745740312291377, + 0.19501607349635292, + 0.15428468872076637, + 0.10905050699570007, + ] + ).reshape(18, 1, 1) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [ + 0.17555201137417686, + 0.17556463274968204, + 0.45998793417834255, + 0.455988755730148, + 2.8559909213125763, + 8.324800606439833, + 2.4498757382563103, + 1.4647352984509094, + 0.03958795985905458, + 0.047778262752410296, + 0.06636616706371974, + 0.06358874912497474, + 0.07744387147984592, + 0.09101635085921553, + 0.09218466562387101, + 0.10164581233948201, + 0.09991773043519253, + 0.08780632509122865, + ] + ).reshape(18, 1, 1) + + # this reorders the bands to put S2 RGB first, then remainder of S2, then S1 + reindex_to_rgb_first = [ + 10, + 9, + 8, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + # 0, + # 1, + # 2, + # 3, + # 4, + # 5, + # 6, + # 7, + ] + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + bands: str = "rgb", + unsupervised_mode: bool = False, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for So2Sat based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the So2Sat Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + bands: Either "rgb" or "s2" + unsupervised_mode: Makes the train dataloader return imagery from the train, + val, and test sets + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.bands = bands + self.unsupervised_mode = unsupervised_mode + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image + + Returns: + preprocessed sample + """ + # sample["image"] = (sample["image"] - self.band_means) / self.band_stds + sample["image"] = sample["image"].float() + sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] + + if self.bands == "rgb": + sample["image"] = sample["image"][:3, :, :] + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + So2Sat(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_transforms = Compose([self.preprocess]) + val_test_transforms = self.preprocess + + if not self.unsupervised_mode: + + self.train_dataset = So2Sat( + self.root_dir, split="train", transforms=train_transforms + ) + + self.val_dataset = So2Sat( + self.root_dir, split="validation", transforms=val_test_transforms + ) + + self.test_dataset = So2Sat( + self.root_dir, split="test", transforms=val_test_transforms + ) + + else: + + temp_train = So2Sat( + self.root_dir, split="train", transforms=train_transforms + ) + + self.val_dataset = So2Sat( + self.root_dir, split="validation", transforms=train_transforms + ) + + self.test_dataset = So2Sat( + self.root_dir, split="test", transforms=train_transforms + ) + + self.train_dataset = cast( + So2Sat, temp_train + self.val_dataset + self.test_dataset + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py new file mode 100644 index 00000000000..69cd9773384 --- /dev/null +++ b/torchgeo/datamodules/ucmerced.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""UC Merced datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +import torchvision +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import UCMerced + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class UCMercedDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the UC Merced dataset. + + Uses random train/val/test splits. + """ + + band_means = torch.tensor([0, 0, 0]) # type: ignore[attr-defined] + + band_stds = torch.tensor([1, 1, 1]) # type: ignore[attr-defined] + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for UCMerced based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the UCMerced Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + c, h, w = sample["image"].shape + if h != 256 or w != 256: + sample["image"] = torchvision.transforms.functional.resize( + sample["image"], size=(256, 256) + ) + sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + UCMerced(self.root_dir, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms) + self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms) + self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py new file mode 100644 index 00000000000..ff1f571c2b6 --- /dev/null +++ b/torchgeo/datamodules/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Common datamodule utilities.""" + +from typing import Any, List, Optional + +from torch.utils.data import Dataset, Subset, random_split + + +def dataset_split( + dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None +) -> List[Subset[Any]]: + """Split a torch Dataset into train/val/test sets. + + If ``test_pct`` is not set then only train and validation splits are returned. + + Args: + dataset: dataset to be split into train/val or train/val/test subsets + val_pct: percentage of samples to be in validation set + test_pct: (Optional) percentage of samples to be in test set + Returns: + a list of the subset datasets. Either [train, val] or [train, val, test] + """ + if test_pct is None: + val_length = int(len(dataset) * val_pct) + train_length = len(dataset) - val_length + return random_split(dataset, [train_length, val_length]) + else: + val_length = int(len(dataset) * val_pct) + test_length = int(len(dataset) * test_pct) + train_length = len(dataset) - (val_length + test_length) + return random_split(dataset, [train_length, val_length, test_length]) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py new file mode 100644 index 00000000000..afc36892c10 --- /dev/null +++ b/torchgeo/datamodules/vaihingen.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Vaihingen datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import Vaihingen2D +from .utils import dataset_split + + +class Vaihingen2DDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the Vaihingen2D dataset. + + Uses the train/test splits from the dataset. + + .. versionadded: 0.2 + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for Vaihingen2D based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py new file mode 100644 index 00000000000..a8b5e118835 --- /dev/null +++ b/torchgeo/datamodules/xview.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""xView2 datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import XView2 +from .utils import dataset_split + + +class XView2DataModule(pl.LightningDataModule): + """LightningDataModule implementation for the xView2 dataset. + + Uses the train/val/test splits from the dataset. + + .. versionadded: 0.2 + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for xView2 based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the xView2 Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = XView2(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = XView2(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0f5e24b38cf..7e3cf7811c2 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -5,7 +5,7 @@ from .advance import ADVANCE from .benin_cashews import BeninSmallHolderCashews -from .bigearthnet import BigEarthNet, BigEarthNetDataModule +from .bigearthnet import BigEarthNet from .cbf import CanadianBuildingFootprints from .cdl import CDL from .chesapeake import ( @@ -13,7 +13,6 @@ Chesapeake7, Chesapeake13, ChesapeakeCVPR, - ChesapeakeCVPRDataModule, ChesapeakeDC, ChesapeakeDE, ChesapeakeMD, @@ -22,12 +21,12 @@ ChesapeakeVA, ChesapeakeWV, ) -from .cowc import COWC, COWCCounting, COWCCountingDataModule, COWCDetection +from .cowc import COWC, COWCCounting, COWCDetection from .cv4a_kenya_crop_type import CV4AKenyaCropType -from .cyclone import CycloneDataModule, TropicalCycloneWindEstimation -from .etci2021 import ETCI2021, ETCI2021DataModule -from .eurosat import EuroSAT, EuroSATDataModule -from .fair1m import FAIR1M, FAIR1MDataModule +from .cyclone import TropicalCycloneWindEstimation +from .etci2021 import ETCI2021 +from .eurosat import EuroSAT +from .fair1m import FAIR1M from .geo import ( GeoDataset, IntersectionDataset, @@ -39,7 +38,7 @@ ) from .gid15 import GID15 from .idtrees import IDTReeS -from .landcoverai import LandCoverAI, LandCoverAIDataModule +from .landcoverai import LandCoverAI from .landsat import ( Landsat, Landsat1, @@ -54,23 +53,23 @@ Landsat9, ) from .levircd import LEVIRCDPlus -from .loveda import LoveDA, LoveDADataModule -from .naip import NAIP, NAIPChesapeakeDataModule -from .nasa_marine_debris import NASAMarineDebris, NASAMarineDebrisDataModule +from .loveda import LoveDA +from .naip import NAIP +from .nasa_marine_debris import NASAMarineDebris from .nwpu import VHR10 -from .oscd import OSCD, OSCDDataModule +from .oscd import OSCD from .patternnet import PatternNet -from .potsdam import Potsdam2D, Potsdam2DDataModule -from .resisc45 import RESISC45, RESISC45DataModule +from .potsdam import Potsdam2D +from .resisc45 import RESISC45 from .seco import SeasonalContrastS2 -from .sen12ms import SEN12MS, SEN12MSDataModule +from .sen12ms import SEN12MS from .sentinel import Sentinel, Sentinel2 -from .so2sat import So2Sat, So2SatDataModule +from .so2sat import So2Sat from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 -from .ucmerced import UCMerced, UCMercedDataModule +from .ucmerced import UCMerced from .utils import BoundingBox, concat_samples, merge_samples, stack_samples -from .vaihingen import Vaihingen2D, Vaihingen2DDataModule -from .xview import XView2, XView2DataModule +from .vaihingen import Vaihingen2D +from .xview import XView2 from .zuericrop import ZueriCrop __all__ = ( @@ -88,7 +87,6 @@ "ChesapeakeVA", "ChesapeakeWV", "ChesapeakeCVPR", - "ChesapeakeCVPRDataModule", "Landsat", "Landsat1", "Landsat2", @@ -101,46 +99,32 @@ "Landsat8", "Landsat9", "NAIP", - "NAIPChesapeakeDataModule", "Sentinel", "Sentinel2", # VisionDataset "ADVANCE", "BeninSmallHolderCashews", "BigEarthNet", - "BigEarthNetDataModule", "COWC", "COWCCounting", "COWCDetection", - "COWCCountingDataModule", "CV4AKenyaCropType", "ETCI2021", - "ETCI2021DataModule", "EuroSAT", - "EuroSATDataModule", "FAIR1M", - "FAIR1MDataModule", "GID15", "IDTReeS", "LandCoverAI", - "LandCoverAIDataModule", "LEVIRCDPlus", "LoveDA", - "LoveDADataModule", "NASAMarineDebris", - "NASAMarineDebrisDataModule", "OSCD", - "OSCDDataModule", "PatternNet", "Potsdam2D", - "Potsdam2DDataModule", "RESISC45", - "RESISC45DataModule", "SeasonalContrastS2", "SEN12MS", - "SEN12MSDataModule", "So2Sat", - "So2SatDataModule", "SpaceNet", "SpaceNet1", "SpaceNet2", @@ -148,14 +132,10 @@ "SpaceNet5", "SpaceNet7", "TropicalCycloneWindEstimation", - "CycloneDataModule", "UCMerced", - "UCMercedDataModule", "Vaihingen2D", - "Vaihingen2DDataModule", "VHR10", "XView2", - "XView2DataModule", "ZueriCrop", # Base classes "GeoDataset", diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 409e0b82230..48d24766553 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -6,24 +6,17 @@ import glob import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import numpy as np -import pytorch_lightning as pl import rasterio import torch from rasterio.enums import Resampling from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from .geo import VisionDataset from .utils import download_url, extract_archive, sort_sentinel2_bands -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class BigEarthNet(VisionDataset): """BigEarthNet dataset. @@ -511,164 +504,3 @@ def _extract(self, filepath: str) -> None: """ if not filepath.endswith(".csv"): extract_archive(filepath) - - -class BigEarthNetDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the BigEarthNet dataset. - - Uses the train/val/test splits from the dataset. - """ - - # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) - # min/max band statistics computed on 100k random samples - band_mins_raw = torch.tensor( # type: ignore[attr-defined] - [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] - ) - band_maxs_raw = torch.tensor( # type: ignore[attr-defined] - [ - 31.0, - 35.0, - 18556.0, - 20528.0, - 18976.0, - 17874.0, - 16611.0, - 16512.0, - 16394.0, - 16672.0, - 16141.0, - 16097.0, - 15336.0, - 15203.0, - ] - ) - - # min/max band statistics computed by percentile clipping the - # above to samples to [2, 98] - band_mins = torch.tensor( # type: ignore[attr-defined] - [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - ) - band_maxs = torch.tensor( # type: ignore[attr-defined] - [ - 6.0, - 16.0, - 9859.0, - 12872.0, - 13163.0, - 14445.0, - 12477.0, - 12563.0, - 12289.0, - 15596.0, - 12183.0, - 9458.0, - 5897.0, - 5544.0, - ] - ) - - def __init__( - self, - root_dir: str, - bands: str = "all", - num_classes: int = 19, - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for BigEarthNet based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes - bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} - num_classes: number of classes to load in target. one of {19, 43} - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.bands = bands - self.num_classes = num_classes - self.batch_size = batch_size - self.num_workers = num_workers - - if bands == "all": - self.mins = self.band_mins[:, None, None] - self.maxs = self.band_maxs[:, None, None] - elif bands == "s1": - self.mins = self.band_mins[:2, None, None] - self.maxs = self.band_maxs[:2, None, None] - else: - self.mins = self.band_mins[2:, None, None] - self.maxs = self.band_maxs[2:, None, None] - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) - sample["image"] = torch.clip( # type: ignore[attr-defined] - sample["image"], min=0.0, max=1.0 - ) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - """ - transforms = Compose([self.preprocess]) - self.train_dataset = BigEarthNet( - self.root_dir, - split="train", - bands=self.bands, - num_classes=self.num_classes, - transforms=transforms, - ) - self.val_dataset = BigEarthNet( - self.root_dir, - split="val", - bands=self.bands, - num_classes=self.num_classes, - transforms=transforms, - ) - self.test_dataset = BigEarthNet( - self.root_dir, - split="test", - bands=self.bands, - num_classes=self.num_classes, - transforms=transforms, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 1709b1a9e13..d697d96dea8 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -16,21 +16,10 @@ import shapely.geometry import shapely.ops import torch -import torch.nn.functional as F -from pytorch_lightning.core.datamodule import LightningDataModule from rasterio.crs import CRS -from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler from .geo import GeoDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive, stack_samples - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" +from .utils import BoundingBox, download_url, extract_archive class Chesapeake(RasterDataset, abc.ABC): @@ -537,294 +526,3 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" extract_archive(os.path.join(self.root, self.filename)) - - -class ChesapeakeCVPRDataModule(LightningDataModule): - """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. - - Uses the random splits defined per state to partition tiles into train, val, - and test sets. - """ - - def __init__( - self, - root_dir: str, - train_splits: List[str], - val_splits: List[str], - test_splits: List[str], - patches_per_tile: int = 200, - patch_size: int = 256, - batch_size: int = 64, - num_workers: int = 0, - class_set: int = 7, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the ChesapeakeCVPR Dataset - classes - train_splits: The splits used to train the model, e.g. ["ny-train"] - val_splits: The splits used to validate the model, e.g. ["ny-val"] - test_splits: The splits used to test the model, e.g. ["ny-test"] - patches_per_tile: The number of patches per tile to sample - patch_size: The size of each patch in pixels (test patches will be 1.5 times - this size) - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - class_set: The high-resolution land cover class set to use - 5 or 7 - """ - super().__init__() # type: ignore[no-untyped-call] - for state in train_splits + val_splits + test_splits: - assert state in ChesapeakeCVPR.splits - assert class_set in [5, 7] - - self.root_dir = root_dir - self.train_splits = train_splits - self.val_splits = val_splits - self.test_splits = test_splits - self.layers = ["naip-new", "lc"] - self.patches_per_tile = patches_per_tile - self.patch_size = patch_size - # This is a rough estimate of how large of a patch we will need to sample in - # EPSG:3857 in order to guarantee a large enough patch in the local CRS. - self.original_patch_size = int(patch_size * 2.0) - self.batch_size = batch_size - self.num_workers = num_workers - self.class_set = class_set - - def pad_to( - self, size: int = 512, image_value: int = 0, mask_value: int = 0 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a padding transform on a single sample. - - Args: - size: output image size - image_value: value to pad image with - mask_value: value to pad mask with - - Returns: - function to perform padding - """ - - def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - assert height <= size and width <= size - - height_pad = size - height - width_pad = size - width - - # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - # for a description of the format of the padding tuple - sample["image"] = F.pad( - sample["image"], - (0, width_pad, 0, height_pad), - mode="constant", - value=image_value, - ) - sample["mask"] = F.pad( - sample["mask"], - (0, width_pad, 0, height_pad), - mode="constant", - value=mask_value, - ) - return sample - - return pad_inner - - def center_crop( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a center crop transform on a single sample. - - Args: - size: output image size - - Returns: - function to perform center crop - """ - - def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - - y1 = (height - size) // 2 - x1 = (width - size) // 2 - sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] - sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] - - return sample - - return center_crop_inner - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Preprocesses a single sample. - - Args: - sample: sample dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 - sample["mask"] = sample["mask"] - sample["mask"] = sample["mask"].squeeze() - - if self.class_set == 5: - sample["mask"][sample["mask"] == 5] = 4 - sample["mask"][sample["mask"] == 6] = 4 - - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"].long() - - return sample - - def nodata_check( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to check for nodata or mis-sized input. - - Args: - size: output image size - - Returns: - function to check for nodata values - """ - - def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - num_channels, height, width = sample["image"].shape - - if height < size or width < size: - sample["image"] = torch.zeros( # type: ignore[attr-defined] - (num_channels, size, size) - ) - sample["mask"] = torch.zeros((size, size)) # type: ignore[attr-defined] - - return sample - - return nodata_check_inner - - def prepare_data(self) -> None: - """Confirms that the dataset is downloaded on the local node. - - This method is called once per node, while :func:`setup` is called once per GPU. - """ - ChesapeakeCVPR( - self.root_dir, - splits=self.train_splits, - layers=self.layers, - transforms=None, - download=False, - checksum=False, - ) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - Args: - stage: stage to set up - """ - train_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - ] - ) - val_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - ] - ) - test_transforms = Compose( - [ - self.pad_to(self.original_patch_size, image_value=0, mask_value=0), - self.preprocess, - ] - ) - - self.train_dataset = ChesapeakeCVPR( - self.root_dir, - splits=self.train_splits, - layers=self.layers, - transforms=train_transforms, - download=False, - checksum=False, - ) - self.val_dataset = ChesapeakeCVPR( - self.root_dir, - splits=self.val_splits, - layers=self.layers, - transforms=val_transforms, - download=False, - checksum=False, - ) - self.test_dataset = ChesapeakeCVPR( - self.root_dir, - splits=self.test_splits, - layers=self.layers, - transforms=test_transforms, - download=False, - checksum=False, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - sampler = RandomBatchGeoSampler( - self.train_dataset, - size=self.original_patch_size, - batch_size=self.batch_size, - length=self.patches_per_tile * len(self.train_dataset), - ) - return DataLoader( - self.train_dataset, - batch_sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - sampler = GridGeoSampler( - self.val_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - sampler = GridGeoSampler( - self.test_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 35bbdc54be6..f16448824b3 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -6,22 +6,16 @@ import abc import csv import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import numpy as np -import pytorch_lightning as pl import torch from PIL import Image -from torch import Generator, Tensor # type: ignore[attr-defined] -from torch.utils.data import DataLoader, random_split +from torch import Tensor from .geo import VisionDataset from .utils import check_integrity, download_and_extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class COWC(VisionDataset, abc.ABC): """Abstract base class for the COWC dataset. @@ -268,110 +262,3 @@ class COWCDetection(COWC): # 4. Unknown # # May need new abstract base class. Will need subclasses for different patch sizes. - - -class COWCCountingDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the COWC Counting dataset.""" - - def __init__( - self, - root_dir: str, - seed: int, - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for COWC Counting based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the COWCCounting Dataset class - seed: The seed value to use when doing the dataset random_split - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 # scale to [0, 1] - sample["label"] = sample["label"].float() - return sample - - def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - COWCCounting(self.root_dir, download=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - Args: - stage: stage to set up - """ - train_val_dataset = COWCCounting( - self.root_dir, split="train", transforms=self.custom_transform - ) - self.test_dataset = COWCCounting( - self.root_dir, split="test", transforms=self.custom_transform - ) - self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], - generator=Generator().manual_seed(self.seed), - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 37c20ca42d6..0229f1f85dc 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -10,20 +10,13 @@ import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image -from sklearn.model_selection import GroupShuffleSplit from torch import Tensor -from torch.utils.data import DataLoader, Subset from .geo import VisionDataset from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class TropicalCycloneWindEstimation(VisionDataset): """Tropical Cyclone Wind Estimation Competition dataset. @@ -254,157 +247,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class CycloneDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the NASA Cyclone dataset. - - Implements 80/20 train/val splits based on hurricane storm ids. - See :func:`setup` for more details. - """ - - def __init__( - self, - root_dir: str, - seed: int, - batch_size: int = 64, - num_workers: int = 0, - api_key: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the - TropicalCycloneWindEstimation Datasets classes - seed: The seed value to use when doing the sklearn based GroupShuffleSplit - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - api_key: The RadiantEarth MLHub API key to use if the dataset needs to be - downloaded - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - self.api_key = api_key - - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 # scale to [0,1] - sample["image"] = ( - sample["image"].unsqueeze(0).repeat(3, 1, 1) - ) # convert to 3 channel - sample["label"] = torch.as_tensor( # type: ignore[attr-defined] - sample["label"] - ).float() - - return sample - - def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - TropicalCycloneWindEstimation( - self.root_dir, - split="train", - transforms=self.custom_transform, - download=self.api_key is not None, - api_key=self.api_key, - ) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - We split samples between train/val by the ``storm_id`` property. I.e. all - samples with the same ``storm_id`` value will be either in the train or the val - split. This is important to test one type of generalizability -- given a new - storm, can we predict its windspeed. The test set, however, contains *some* - storms from the training set (specifically, the latter parts of the storms) as - well as some novel storms. - - Args: - stage: stage to set up - """ - self.all_train_dataset = TropicalCycloneWindEstimation( - self.root_dir, - split="train", - transforms=self.custom_transform, - download=False, - ) - - self.all_test_dataset = TropicalCycloneWindEstimation( - self.root_dir, - split="test", - transforms=self.custom_transform, - download=False, - ) - - storm_ids = [] - for item in self.all_train_dataset.collection: - storm_id = item["href"].split("/")[0].split("_")[-2] - storm_ids.append(storm_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( - storm_ids, groups=storm_ids - ) - ) - - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index bb10da22bff..dbcf667ce95 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -5,16 +5,13 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image -from torch import Generator, Tensor # type: ignore[attr-defined] -from torch.utils.data import DataLoader, random_split -from torchvision.transforms import Normalize +from torch import Tensor from .geo import VisionDataset from .utils import download_and_extract_archive @@ -320,140 +317,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class ETCI2021DataModule(pl.LightningDataModule): - """LightningDataModule implementation for the ETCI2021 dataset. - - Splits the existing train split from the dataset into train/val with 80/20 - proportions, then uses the existing val dataset as the test data. - - .. versionadded:: 0.2 - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1] - ) - - def __init__( - self, - root_dir: str, - seed: int = 0, - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for ETCI2021 based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the ETCI2021 Dataset classes - seed: The seed value to use when doing the dataset random_split - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Notably, moves the given water mask to act as an input layer. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - image = sample["image"] - water_mask = sample["mask"][0].unsqueeze(0) - flood_mask = sample["mask"][1] - flood_mask = (flood_mask > 0).long() - - sample["image"] = torch.cat( # type: ignore[attr-defined] - [image, water_mask], dim=0 - ).float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - sample["mask"] = flood_mask - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - ETCI2021(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_val_dataset = ETCI2021( - self.root_dir, split="train", transforms=self.preprocess - ) - self.test_dataset = ETCI2021( - self.root_dir, split="val", transforms=self.preprocess - ) - - size_train_val = len(train_val_dataset) - size_train = int(0.8 * size_train_val) - size_val = size_train_val - size_train - - self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [size_train, size_val], - generator=Generator().manual_seed(self.seed), - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 9ba06c1c683..e1e14072039 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,15 +4,11 @@ """EuroSAT dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl -import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset from .utils import check_integrity, download_url, extract_archive, rasterio_loader @@ -229,138 +225,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class EuroSATDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the EuroSAT dataset. - - Uses the train/val/test splits from the dataset. - - .. versionadded:: 0.2 - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798, - ] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042, - ] - ) - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for EuroSAT based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the EuroSAT Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - EuroSAT(self.root_dir) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms) - self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms) - self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index e75f27be35f..c8e2184f23a 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -11,33 +11,12 @@ import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import check_integrity, dataset_split, extract_archive from .geo import VisionDataset - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - - -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable number of boxes. - - Args: - batch: list of sample dicts return by dataset - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - return output +from .utils import check_integrity, extract_archive def parse_pascal_voc(path: str) -> Dict[str, Any]: @@ -350,102 +329,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class FAIR1MDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the FAIR1M dataset.""" - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - test_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for FAIR1M based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the FAIR1M Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.test_split_pct = test_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = FAIR1M(self.root_dir, transforms=transforms) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index e579d668d63..2fecb5d6e10 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -6,24 +6,18 @@ import hashlib import os from functools import lru_cache -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from matplotlib.colors import ListedColormap from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader from .geo import VisionDataset from .utils import check_integrity, download_and_extract_archive, working_dir -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class LandCoverAI(VisionDataset): r"""LandCover.ai dataset. @@ -266,110 +260,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class LandCoverAIDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the LandCover.ai dataset. - - Uses the train/val/test splits from the dataset. - """ - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for LandCover.ai based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 - - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"].float().unsqueeze(0) + 1 - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - _ = LandCoverAI(self.root_dir, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = self.preprocess - val_test_transforms = self.preprocess - - self.train_dataset = LandCoverAI( - self.root_dir, split="train", transforms=train_transforms - ) - - self.val_dataset = LandCoverAI( - self.root_dir, split="val", transforms=val_test_transforms - ) - - self.test_dataset = LandCoverAI( - self.root_dir, split="test", transforms=val_test_transforms - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index b3a0a52e8ca..30fe98adfb4 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -5,23 +5,17 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader from .geo import VisionDataset from .utils import download_and_extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class LoveDA(VisionDataset): """LoveDA dataset. @@ -305,117 +299,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class LoveDADataModule(pl.LightningDataModule): - """LightningDataModule implementation for the LoveDA dataset. - - Uses the train/val/test splits from the dataset. - """ - - def __init__( - self, - root_dir: str, - scene: List[str], - batch_size: int = 32, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for LoveDA based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to LoveDA Dataset classes - scene: specify whether to load only 'urban', only 'rural' or both - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.scene = scene - self.batch_size = batch_size - self.num_workers = num_workers - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - _ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = self.preprocess - val_test_transforms = self.preprocess - - self.train_dataset = LoveDA( - self.root_dir, split="train", scene=self.scene, transforms=train_transforms - ) - - self.val_dataset = LoveDA( - self.root_dir, split="val", scene=self.scene, transforms=val_test_transforms - ) - - self.test_dataset = LoveDA( - self.root_dir, - split="test", - scene=self.scene, - transforms=val_test_transforms, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index 02cfe1e33f6..b6b4bceceb3 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -3,20 +3,7 @@ """National Agriculture Imagery Program (NAIP) dataset.""" -from typing import Any, Dict, Optional - -import pytorch_lightning as pl -from torch.utils.data import DataLoader - -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler -from .chesapeake import Chesapeake13 from .geo import RasterDataset -from .utils import BoundingBox, stack_samples - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" class NAIP(RasterDataset): @@ -55,147 +42,3 @@ class NAIP(RasterDataset): # Plotting all_bands = ["R", "G", "B", "NIR"] rgb_bands = ["R", "G", "B"] - - -class NAIPChesapeakeDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the NAIP and Chesapeake datasets. - - Uses the train/val/test splits from the dataset. - """ - - # TODO: tune these hyperparams - length = 1000 - stride = 128 - - def __init__( - self, - naip_root_dir: str, - chesapeake_root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - patch_size: int = 256, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. - - Args: - naip_root_dir: directory containing NAIP data - chesapeake_root_dir: directory containing Chesapeake data - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - patch_size: size of patches to sample - """ - super().__init__() # type: ignore[no-untyped-call] - self.naip_root_dir = naip_root_dir - self.chesapeake_root_dir = chesapeake_root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.patch_size = patch_size - - def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the NAIP Dataset. - - Args: - sample: NAIP image dictionary - - Returns: - preprocessed NAIP data - """ - sample["image"] = sample["image"] / 255.0 - sample["image"] = sample["image"].float() - return sample - - def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Chesapeake Dataset. - - Args: - sample: Chesapeake mask dictionary - - Returns: - preprocessed Chesapeake data - """ - sample["mask"] = sample["mask"].long()[0] - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: state to set up - """ - # TODO: these transforms will be applied independently, this won't work if we - # add things like random horizontal flip - chesapeake = Chesapeake13( - self.chesapeake_root_dir, transforms=self.chesapeake_transform - ) - naip = NAIP( - self.naip_root_dir, - chesapeake.crs, - chesapeake.res, - transforms=self.naip_transform, - ) - self.dataset = chesapeake & naip - - # TODO: figure out better train/val/test split - roi = self.dataset.bounds - midx = roi.minx + (roi.maxx - roi.minx) / 2 - midy = roi.miny + (roi.maxy - roi.miny) / 2 - train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) - val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) - test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) - - self.train_sampler = RandomBatchGeoSampler( - naip, self.patch_size, self.batch_size, self.length, train_roi - ) - self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi) - self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.dataset, - batch_sampler=self.train_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.val_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.test_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index bd239e65847..2b502751920 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -4,39 +4,17 @@ """NASA Marine Debris dataset.""" import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import rasterio import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from torchvision.utils import draw_bounding_boxes from .geo import VisionDataset -from .utils import dataset_split, download_radiant_mlhub_dataset, extract_archive - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - - -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable boxes. - - Args: - batch: list of sample dicts return by dataset - - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - return output +from .utils import download_radiant_mlhub_dataset, extract_archive class NASAMarineDebris(VisionDataset): @@ -279,109 +257,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class NASAMarineDebrisDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the NASA Marine Debris dataset.""" - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - test_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to the Dataset class - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.test_split_pct = test_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - NASAMarineDebris(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = NASAMarineDebris(self.root_dir, transforms=transforms) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index c2f807b490e..803d7405c09 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -5,25 +5,23 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Union -import kornia.augmentation as K import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch -from einops import repeat from matplotlib.figure import Figure from numpy import ndarray as Array from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate -from torchvision.transforms import Compose, Normalize -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import download_url, extract_archive, sort_sentinel2_bands +from .utils import ( + download_url, + draw_semantic_segmentation_masks, + extract_archive, + sort_sentinel2_bands, +) class OSCD(VisionDataset): @@ -317,202 +315,3 @@ def get_masked(img: Tensor) -> Array: # type: ignore[type-arg] plt.suptitle(suptitle) return fig - - -class OSCDDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the OSCD dataset. - - Uses the train/test splits from the dataset and further splits - the train split into train/val splits. - - .. versionadded: 0.2 - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [ - 1583.0741, - 1374.3202, - 1294.1616, - 1325.6158, - 1478.7408, - 1933.0822, - 2166.0608, - 2076.4868, - 2306.0652, - 690.9814, - 16.2360, - 2080.3347, - 1524.6930, - ] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [ - 52.1937, - 83.4168, - 105.6966, - 151.1401, - 147.4615, - 115.9289, - 123.1974, - 114.6483, - 141.4530, - 73.2758, - 4.8368, - 213.4821, - 179.4793, - ] - ) - - def __init__( - self, - root_dir: str, - bands: str = "all", - train_batch_size: int = 32, - num_workers: int = 0, - val_split_pct: float = 0.2, - patch_size: Tuple[int, int] = (64, 64), - num_patches_per_tile: int = 32, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for OSCD based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the OSCD Dataset classes - bands: "rgb" or "all" - train_batch_size: The batch size used in the train DataLoader - (val_batch_size == test_batch_size == 1) - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - patch_size: Size of random patch from image and mask (height, width) - num_patches_per_tile: number of random patches per sample - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.bands = bands - self.train_batch_size = train_batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.patch_size = patch_size - self.num_patches_per_tile = num_patches_per_tile - - if bands == "rgb": - self.band_means = self.band_means[[3, 2, 1], None, None] - self.band_stds = self.band_stds[[3, 2, 1], None, None] - else: - self.band_means = self.band_means[:, None, None] - self.band_stds = self.band_stds[:, None, None] - - self.norm = Normalize(self.band_means, self.band_stds) - self.rcrop = K.AugmentationSequential( - K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True - ) - self.padto = K.PadTo((1280, 1280)) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"] - sample["image"] = self.norm(sample["image"]) - sample["image"] = torch.flatten( # type: ignore[attr-defined] - sample["image"], 0, 1 - ) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - OSCD(self.root_dir, split="train", bands=self.bands, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - """ - - def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: - images, masks = [], [] - for i in range(self.num_patches_per_tile): - mask = repeat(sample["mask"], "h w -> t h w", t=2).float() - image, mask = self.rcrop(sample["image"], mask) - mask = mask.squeeze()[0] - images.append(image.squeeze()) - masks.append(mask.long()) - sample["image"] = torch.stack(images) - sample["mask"] = torch.stack(masks) - return sample - - def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = self.padto(sample["image"])[0] - sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] - return sample - - train_transforms = Compose([self.preprocess, n_random_crop]) - # for testing and validation we pad all inputs to a fixed size to avoid issues - # with the upsampling paths in encoder-decoder architectures - test_transforms = Compose([self.preprocess, pad_to]) - - train_dataset = OSCD( - self.root_dir, split="train", bands=self.bands, transforms=train_transforms - ) - if self.val_split_pct > 0.0: - val_dataset = OSCD( - self.root_dir, - split="train", - bands=self.bands, - transforms=test_transforms, - ) - self.train_dataset, self.val_dataset, _ = dataset_split( - train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - self.val_dataset.dataset = val_dataset - else: - self.train_dataset = train_dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = OSCD( - self.root_dir, split="test", bands=self.bands, transforms=test_transforms - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - - def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call] - batch - ) - r_batch["image"] = torch.flatten( # type: ignore[attr-defined] - r_batch["image"], 0, 1 - ) - r_batch["mask"] = torch.flatten( # type: ignore[attr-defined] - r_batch["mask"], 0, 1 - ) - return r_batch - - return DataLoader( - self.train_dataset, - batch_size=self.train_batch_size, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 40149d0429f..a54e4b18f75 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -4,22 +4,23 @@ """Potsdam dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import rasterio import torch from matplotlib.figure import Figure from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import check_integrity, extract_archive, rgb_to_mask +from .utils import ( + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, + rgb_to_mask, +) class Potsdam2D(VisionDataset): @@ -293,111 +294,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class Potsdam2DDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the Potsdam2D dataset. - - Uses the train/test splits from the dataset. - - .. versionadded: 0.2 - """ - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for Potsdam2D based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = Potsdam2D(self.root_dir, "train", transforms=transforms) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 4b5c9560a0b..13117d54645 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -4,23 +4,15 @@ """RESISC45 dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl -import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset from .utils import download_url, extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class RESISC45(VisionClassificationDataset): """RESISC45 dataset. @@ -288,109 +280,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class RESISC45DataModule(pl.LightningDataModule): - """LightningDataModule implementation for the RESISC45 dataset. - - Uses the train/val/test splits from the dataset. - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [0.36801773, 0.38097873, 0.343583] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [0.14540215, 0.13558227, 0.13203649] - ) - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for RESISC45 based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the RESISC45 Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - RESISC45(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = RESISC45(self.root_dir, "train", transforms=transforms) - self.val_dataset = RESISC45(self.root_dir, "val", transforms=transforms) - self.test_dataset = RESISC45(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index f1cf8b2ad2c..8ff0e9a44b8 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -4,23 +4,16 @@ """SEN12MS dataset.""" import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import numpy as np -import pytorch_lightning as pl import rasterio import torch -from sklearn.model_selection import GroupShuffleSplit from torch import Tensor -from torch.utils.data import DataLoader, Subset from .geo import VisionDataset from .utils import check_integrity -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class SEN12MS(VisionDataset): """SEN12MS dataset. @@ -246,188 +239,3 @@ def _check_integrity(self) -> bool: if not check_integrity(filepath, md5 if self.checksum else None): return False return True - - -class SEN12MSDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the SEN12MS dataset. - - Implements 80/20 geographic train/val splits and uses the test split from the - classification dataset definitions. See :func:`setup` for more details. - - Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See - https://arxiv.org/abs/2002.08254. - """ - - #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader - #: here https://github.com/lukasliebel/dfc2020_baseline. - DFC2020_CLASS_MAPPING = torch.tensor( # type: ignore[attr-defined] - [ - 0, # maps 0s to 0 - 1, # maps 1s to 1 - 1, # maps 2s to 1 - 1, # ... - 1, - 1, - 2, - 2, - 3, - 3, - 4, - 5, - 6, - 7, - 6, - 8, - 9, - 10, - ] - ) - - def __init__( - self, - root_dir: str, - seed: int, - band_set: str = "all", - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for SEN12MS based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the SEN12MS Dataset classes - seed: The seed value to use when doing the sklearn based ShuffleSplit - band_set: The subset of S1/S2 bands to use. Options are: "all", - "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: - B2, B3, B4, B8, B11, and B12. - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - assert band_set in SEN12MS.BAND_SETS.keys() - - self.root_dir = root_dir - self.seed = seed - self.band_set = band_set - self.band_indices = SEN12MS.BAND_SETS[band_set] - self.batch_size = batch_size - self.num_workers = num_workers - - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - - if self.band_set == "all": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 - elif self.band_set == "s1": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - else: - sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 - - sample["mask"] = sample["mask"][0, :, :].long() - sample["mask"] = torch.take( # type: ignore[attr-defined] - self.DFC2020_CLASS_MAPPING, sample["mask"] - ) - - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - We split samples between train and val geographically with proportions of 80/20. - This mimics the geographic test set split. - - Args: - stage: stage to set up - """ - season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - - self.all_train_dataset = SEN12MS( - self.root_dir, - split="train", - bands=self.band_indices, - transforms=self.custom_transform, - checksum=False, - ) - - self.all_test_dataset = SEN12MS( - self.root_dir, - split="test", - bands=self.band_indices, - transforms=self.custom_transform, - checksum=False, - ) - - # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" - # This patch will belong to the scene that is uniquelly identified by its - # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply - # give each season a large number and representing a `unique_scene_id` as - # `season_id + scene_id`. - scenes = [] - for scene_fn in self.all_train_dataset.ids: - parts = scene_fn.split("_") - season_id = season_to_int[parts[1]] - scene_id = int(parts[3]) - scenes.append(season_id + scene_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( - scenes, groups=scenes - ) - ) - - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 606a73e62c0..aaee0ecd2d3 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -4,23 +4,16 @@ """So2Sat dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from .geo import VisionDataset from .utils import check_integrity, percentile_normalization -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class So2Sat(VisionDataset): """So2Sat dataset. @@ -250,211 +243,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class So2SatDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the So2Sat dataset. - - Uses the train/val/test splits from the dataset. - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [ - -3.591224256609313e-05, - -7.658561276843396e-06, - 5.9373857475971184e-05, - 2.5166231537121083e-05, - 0.04420110659759328, - 0.25761027084996196, - 0.0007556743372573258, - 0.0013503466830024448, - 0.12375696117681859, - 0.1092774636368323, - 0.1010855203267882, - 0.1142398616114001, - 0.1592656692023089, - 0.18147236008771792, - 0.1745740312291377, - 0.19501607349635292, - 0.15428468872076637, - 0.10905050699570007, - ] - ).reshape(18, 1, 1) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [ - 0.17555201137417686, - 0.17556463274968204, - 0.45998793417834255, - 0.455988755730148, - 2.8559909213125763, - 8.324800606439833, - 2.4498757382563103, - 1.4647352984509094, - 0.03958795985905458, - 0.047778262752410296, - 0.06636616706371974, - 0.06358874912497474, - 0.07744387147984592, - 0.09101635085921553, - 0.09218466562387101, - 0.10164581233948201, - 0.09991773043519253, - 0.08780632509122865, - ] - ).reshape(18, 1, 1) - - # this reorders the bands to put S2 RGB first, then remainder of S2, then S1 - reindex_to_rgb_first = [ - 10, - 9, - 8, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - # 0, - # 1, - # 2, - # 3, - # 4, - # 5, - # 6, - # 7, - ] - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - bands: str = "rgb", - unsupervised_mode: bool = False, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for So2Sat based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the So2Sat Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - bands: Either "rgb" or "s2" - unsupervised_mode: Makes the train dataloader return imagery from the train, - val, and test sets - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.bands = bands - self.unsupervised_mode = unsupervised_mode - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - # sample["image"] = (sample["image"] - self.band_means) / self.band_stds - sample["image"] = sample["image"].float() - sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] - - if self.bands == "rgb": - sample["image"] = sample["image"][:3, :, :] - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - So2Sat(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = Compose([self.preprocess]) - val_test_transforms = self.preprocess - - if not self.unsupervised_mode: - - self.train_dataset = So2Sat( - self.root_dir, split="train", transforms=train_transforms - ) - - self.val_dataset = So2Sat( - self.root_dir, split="validation", transforms=val_test_transforms - ) - - self.test_dataset = So2Sat( - self.root_dir, split="test", transforms=val_test_transforms - ) - - else: - - temp_train = So2Sat( - self.root_dir, split="train", transforms=train_transforms - ) - - self.val_dataset = So2Sat( - self.root_dir, split="validation", transforms=train_transforms - ) - - self.test_dataset = So2Sat( - self.root_dir, split="test", transforms=train_transforms - ) - - self.train_dataset = cast( - So2Sat, temp_train + self.val_dataset + self.test_dataset - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 41cb1f2ff6c..38ac5c6a150 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -24,8 +24,8 @@ from rasterio.transform import Affine from torch import Tensor -from torchgeo.datasets.geo import VisionDataset -from torchgeo.datasets.utils import ( +from .geo import VisionDataset +from .utils import ( check_integrity, download_radiant_mlhub_collection, extract_archive, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 431b526b756..21b09e32a1d 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -3,24 +3,15 @@ """UC Merced dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl -import torch -import torchvision from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset from .utils import check_integrity, download_url, extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class UCMerced(VisionClassificationDataset): """UC Merced dataset. @@ -251,110 +242,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class UCMercedDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the UC Merced dataset. - - Uses random train/val/test splits. - """ - - band_means = torch.tensor([0, 0, 0]) # type: ignore[attr-defined] - - band_stds = torch.tensor([1, 1, 1]) # type: ignore[attr-defined] - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for UCMerced based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the UCMerced Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - c, h, w = sample["image"].shape - if h != 256 or w != 256: - sample["image"] = torchvision.transforms.functional.resize( - sample["image"], size=(256, 256) - ) - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - UCMerced(self.root_dir, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms) - self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms) - self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index fceb6201a38..9a68be30ca3 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -32,7 +32,6 @@ import rasterio import torch from torch import Tensor -from torch.utils.data import Dataset, Subset, random_split from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks @@ -48,7 +47,6 @@ "concat_samples", "merge_samples", "rasterio_loader", - "dataset_split", "sort_sentinel2_bands", "draw_semantic_segmentation_masks", "rgb_to_mask", @@ -519,31 +517,6 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg] return array -def dataset_split( - dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None -) -> List[Subset[Any]]: - """Split a torch Dataset into train/val/test sets. - - If ``test_pct`` is not set then only train and validation splits are returned. - - Args: - dataset: dataset to be split into train/val or train/val/test subsets - val_pct: percentage of samples to be in validation set - test_pct: (Optional) percentage of samples to be in test set - Returns: - a list of the subset datasets. Either [train, val] or [train, val, test] - """ - if test_pct is None: - val_length = int(len(dataset) * val_pct) - train_length = len(dataset) - val_length - return random_split(dataset, [train_length, val_length]) - else: - val_length = int(len(dataset) * val_pct) - test_length = int(len(dataset) * test_pct) - train_length = len(dataset) - (val_length + test_length) - return random_split(dataset, [train_length, val_length, test_length]) - - def sort_sentinel2_bands(x: str) -> str: """Sort Sentinel-2 band files in the correct order.""" x = os.path.basename(x).split("_")[-1] diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index f95e8e72d48..c7bb3e7f4ed 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -4,21 +4,22 @@ """Vaihingen dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from matplotlib.figure import Figure from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import check_integrity, extract_archive, rgb_to_mask +from .utils import ( + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, + rgb_to_mask, +) class Vaihingen2D(VisionDataset): @@ -293,111 +294,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class Vaihingen2DDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the Vaihingen2D dataset. - - Uses the train/test splits from the dataset. - - .. versionadded: 0.2 - """ - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for Vaihingen2D based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index c4e7774e04d..b7ff4d46002 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -5,20 +5,16 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import check_integrity, extract_archive +from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive class XView2(VisionDataset): @@ -282,111 +278,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class XView2DataModule(pl.LightningDataModule): - """LightningDataModule implementation for the xView2 dataset. - - Uses the train/val/test splits from the dataset. - - .. versionadded: 0.2 - """ - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for xView2 based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the xView2 Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = XView2(self.root_dir, "train", transforms=transforms) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = XView2(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 0dcd1a85aef..3dc5dcc4b8e 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -10,9 +10,7 @@ from rtree.index import Index, Property from torch.utils.data import Sampler -from torchgeo.datasets.geo import GeoDataset -from torchgeo.datasets.utils import BoundingBox - +from ..datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 1804d9a2d84..d507f698e3b 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -10,9 +10,7 @@ from rtree.index import Index, Property from torch.utils.data import Sampler -from torchgeo.datasets.geo import GeoDataset -from torchgeo.datasets.utils import BoundingBox - +from ..datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index b8aecd85a11..265859eeb06 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -6,7 +6,7 @@ import random from typing import Tuple, Union -from torchgeo.datasets.utils import BoundingBox +from ..datasets import BoundingBox def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: