From 7b5a92b1c1ad0d221020968e3f25e660bb96d157 Mon Sep 17 00:00:00 2001
From: Maciej Kilian <61431446+iejMac@users.noreply.github.com>
Date: Mon, 27 Jun 2022 02:52:15 +0200
Subject: [PATCH] USAVars: implementing DataModule (#441)

* USAVars: implementing DataModule
* Adding initial version
* add to __init__
* changes
* add transforms argument
* black, isort fix
* fixed shuffle option
* update docs
* fix formatting
* initial split method
* fix formatting
* testing for datamodule
* this is simpler
* testing seed
* fix isort + test seed
* refactor dataset for splits
* black fix
* adding splits to fake data
* change test splits
* working tests locally
* fix black
* fix black
* adapt module to dataset refactor
* complete docstring
* Style fixes

Co-authored-by: Adam J. Stewart
---
 docs/api/datamodules.rst           |  5 ++
 tests/data/usavars/data.py         | 16 +++++
 tests/data/usavars/test_split.txt  |  1 +
 tests/data/usavars/train_split.txt |  3 +
 tests/data/usavars/val_split.txt   |  2 +
 tests/datamodules/test_usavars.py  | 37 ++++++++++++
 tests/datasets/test_usavars.py     | 48 +++++++++++++--
 torchgeo/datamodules/__init__.py   |  2 +
 torchgeo/datamodules/usavars.py    | 95 ++++++++++++++++++++++++++++++
 torchgeo/datasets/usavars.py       | 40 +++++++++++--
 10 files changed, 241 insertions(+), 8 deletions(-)
 create mode 100644 tests/data/usavars/test_split.txt
 create mode 100644 tests/data/usavars/train_split.txt
 create mode 100644 tests/data/usavars/val_split.txt
 create mode 100644 tests/datamodules/test_usavars.py
 create mode 100644 torchgeo/datamodules/usavars.py

diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst
index 33810a9e50f..4edeb0c37a9 100644
--- a/docs/api/datamodules.rst
+++ b/docs/api/datamodules.rst
@@ -99,6 +99,11 @@ UC Merced
 
 .. autoclass:: UCMercedDataModule
 
+USAVars
+^^^^^^^
+
+.. autoclass:: USAVarsDataModule
+
 Vaihingen
 ^^^^^^^^^
 
diff --git a/tests/data/usavars/data.py b/tests/data/usavars/data.py
index c392887b04e..413511dda88 100644
--- a/tests/data/usavars/data.py
+++ b/tests/data/usavars/data.py
@@ -22,6 +22,8 @@
     "housing",
     "roads",
 ]
+splits = ["train", "val", "test"]
+
 
 SIZE = 3
 
@@ -47,9 +49,12 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
 # Remove old data
 filename = f"{data_dir}.zip"
 csvs = glob.glob("*.csv")
+txts = glob.glob("*.txt")
 
 for csv in csvs:
     os.remove(csv)
+for txt in txts:
+    os.remove(txt)
 if os.path.exists(filename):
     os.remove(filename)
 if os.path.exists(data_dir):
@@ -67,6 +72,17 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
     df = pd.DataFrame(fake_vals, columns=cols)
     df.to_csv(lab + ".csv")
 
+# Create splits:
+with open("train_split.txt", "w") as f:
+    f.write("tile_0,0.tif" + "\n")
+    f.write("tile_0,0.tif" + "\n")
+    f.write("tile_0,0.tif" + "\n")
+with open("val_split.txt", "w") as f:
+    f.write("tile_0,1.tif" + "\n")
+    f.write("tile_0,1.tif" + "\n")
+with open("test_split.txt", "w") as f:
+    f.write("tile_0,0.tif" + "\n")
+
 # Compress data
 shutil.make_archive(data_dir, "zip", ".", data_dir)
 
diff --git a/tests/data/usavars/test_split.txt b/tests/data/usavars/test_split.txt
new file mode 100644
index 00000000000..7f982dc7765
--- /dev/null
+++ b/tests/data/usavars/test_split.txt
@@ -0,0 +1 @@
+tile_0,0.tif
diff --git a/tests/data/usavars/train_split.txt b/tests/data/usavars/train_split.txt
new file mode 100644
index 00000000000..1fcebec2e34
--- /dev/null
+++ b/tests/data/usavars/train_split.txt
@@ -0,0 +1,3 @@
+tile_0,0.tif
+tile_0,0.tif
+tile_0,0.tif
diff --git a/tests/data/usavars/val_split.txt b/tests/data/usavars/val_split.txt
new file mode 100644
index 00000000000..7ff523c0ffb
--- /dev/null
+++ b/tests/data/usavars/val_split.txt
@@ -0,0 +1,2 @@
+tile_0,1.tif
+tile_0,1.tif
diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py
new file mode 100644
index 00000000000..f8813261d6d
--- /dev/null
+++ b/tests/datamodules/test_usavars.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import pytest
+from _pytest.fixtures import SubRequest
+
+from torchgeo.datamodules import USAVarsDataModule
+
+
+class TestUSAVarsDataModule:
+    @pytest.fixture()
+    def datamodule(self, request: SubRequest) -> USAVarsDataModule:
+        root = os.path.join("tests", "data", "usavars")
+        batch_size = 1
+        num_workers = 0
+
+        dm = USAVarsDataModule(root, batch_size=batch_size, num_workers=num_workers)
+        dm.prepare_data()
+        dm.setup()
+        return dm
+
+    def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None:
+        assert len(datamodule.train_dataloader()) == 3
+        sample = next(iter(datamodule.train_dataloader()))
+        assert sample["image"].shape[0] == datamodule.batch_size
+
+    def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None:
+        assert len(datamodule.val_dataloader()) == 2
+        sample = next(iter(datamodule.val_dataloader()))
+        assert sample["image"].shape[0] == datamodule.batch_size
+
+    def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
+        assert len(datamodule.test_dataloader()) == 1
+        sample = next(iter(datamodule.test_dataloader()))
+        assert sample["image"].shape[0] == datamodule.batch_size
diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py
index 15aa10ad3cf..1fd95e53ad2 100644
--- a/tests/datasets/test_usavars.py
+++ b/tests/datasets/test_usavars.py
@@ -26,7 +26,16 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
 
 
 class TestUSAVars:
-    @pytest.fixture()
+    @pytest.fixture(
+        params=zip(
+            ["train", "val", "test"],
+            [
+                ["elevation", "population", "treecover"],
+                ["elevation", "population"],
+                ["treecover"],
+            ],
+        )
+    )
     def dataset(
         self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
     ) -> USAVars:
@@ -50,21 +59,49 @@ def dataset(
         }
         monkeypatch.setattr(USAVars, "label_urls", label_urls)
 
+        split_metadata = {
+            "train": {
+                "url": os.path.join("tests", "data", "usavars", "train_split.txt"),
+                "filename": "train_split.txt",
+                "md5": "b94f3f6f63110b253779b65bc31d91b5",
+            },
+            "val": {
+                "url": os.path.join("tests", "data", "usavars", "val_split.txt"),
+                "filename": "val_split.txt",
+                "md5": "e39aa54b646c4c45921fcc9765d5a708",
+            },
+            "test": {
+                "url": os.path.join("tests", "data", "usavars", "test_split.txt"),
+                "filename": "test_split.txt",
+                "md5": "4ab0f5549fee944a5690de1bc95ed245",
+            },
+        }
+        monkeypatch.setattr(USAVars, "split_metadata", split_metadata)
+
         root = str(tmp_path)
+        split, labels = request.param
         transforms = nn.Identity()  # type: ignore[no-untyped-call]
-        return USAVars(root, transforms=transforms, download=True, checksum=True)
+        return USAVars(
+            root, split, labels, transforms=transforms, download=True, checksum=True
+        )
 
     def test_getitem(self, dataset: USAVars) -> None:
         x = dataset[0]
         assert isinstance(x, dict)
         assert isinstance(x["image"], torch.Tensor)
         assert x["image"].ndim == 3
-        assert len(x.keys()) == 2  # image, elevation, population, treecover
+        assert len(x.keys()) == 2  # image, labels
         assert x["image"].shape[0] == 4  # R, G, B, Inf
+        assert len(dataset.labels) == len(x["labels"])
 
     def test_len(self, dataset: USAVars) -> None:
-        assert len(dataset) == 2
+        if dataset.split == "train":
+            assert len(dataset) == 3
+        elif dataset.split == "val":
+            assert len(dataset) == 2
+        else:
+            assert len(dataset) == 1
 
     def test_add(self, dataset: USAVars) -> None:
         ds = dataset + dataset
@@ -88,6 +125,9 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
         ]
         for csv in csvs:
             shutil.copy(os.path.join("tests", "data", "usavars", csv), root)
+        splits = ["train_split.txt", "val_split.txt", "test_split.txt"]
["train_split.txt", "val_split.txt", "test_split.txt"] + for split in splits: + shutil.copy(os.path.join("tests", "data", "usavars", split), root) USAVars(root) diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 64f4c2d8d95..02cb2b83e14 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -21,6 +21,7 @@ from .sen12ms import SEN12MSDataModule from .so2sat import So2SatDataModule from .ucmerced import UCMercedDataModule +from .usavars import USAVarsDataModule from .vaihingen import Vaihingen2DDataModule from .xview import XView2DataModule @@ -45,6 +46,7 @@ "So2SatDataModule", "CycloneDataModule", "UCMercedDataModule", + "USAVarsDataModule", "Vaihingen2DDataModule", "XView2DataModule", ) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py new file mode 100644 index 00000000000..7caa4862307 --- /dev/null +++ b/torchgeo/datamodules/usavars.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""USAVars datamodule.""" + +from typing import Any, Callable, Dict, Optional, Sequence + +import pytorch_lightning as pl +from torch import Tensor +from torch.utils.data import DataLoader + +from ..datasets import USAVars + + +class USAVarsDataModule(pl.LightningModule): + """LightningDataModule implementation for the USAVars dataset. + + Uses random train/val/test splits. + + .. versionadded:: 0.3 + """ + + def __init__( + self, + root_dir: str, + labels: Sequence[str] = USAVars.ALL_LABELS, + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + batch_size: int = 64, + num_workers: int = 0, + ) -> None: + """Initialize a LightningDataModule for USAVars based DataLoaders. + + Args: + root_dir: The root argument passed to the USAVars Dataset classes + labels: The labels argument passed to the USAVars Dataset classes + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() + self.root_dir = root_dir + self.labels = labels + self.transforms = transforms + self.batch_size = batch_size + self.num_workers = num_workers + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + USAVars(self.root_dir, labels=self.labels, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main Dataset objects. + + This method is called once per GPU per run. 
diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py
index 4d259d0447f..af482644812 100644
--- a/torchgeo/datasets/usavars.py
+++ b/torchgeo/datasets/usavars.py
@@ -71,11 +71,30 @@ class USAVars(VisionDataset):
         + f"outcomes_sampled_treecover_{uar_csv_suffix}",
     }
 
+    split_metadata = {
+        "train": {
+            "url": "https://mosaiks.blob.core.windows.net/datasets/train_split.txt",
+            "filename": "train_split.txt",
+            "md5": "3f58fffbf5fe177611112550297200e7",
+        },
+        "val": {
+            "url": "https://mosaiks.blob.core.windows.net/datasets/val_split.txt",
+            "filename": "val_split.txt",
+            "md5": "bca7183b132b919dec0fc24fb11662a0",
+        },
+        "test": {
+            "url": "https://mosaiks.blob.core.windows.net/datasets/test_split.txt",
+            "filename": "test_split.txt",
+            "md5": "97bb36bc003ae0bf556a8d6e8f77141a",
+        },
+    }
+
     ALL_LABELS = list(label_urls.keys())
 
     def __init__(
         self,
         root: str = "data",
+        split: str = "train",
         labels: Sequence[str] = ALL_LABELS,
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         download: bool = False,
@@ -85,6 +104,7 @@ def __init__(
 
         Args:
             root: root directory where dataset can be found
+            split: train/val/test split to load
            labels: list of labels to include
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
@@ -99,6 +119,9 @@ def __init__(
         """
         self.root = root
 
+        assert split in self.split_metadata
+        self.split = split
+
         for lab in labels:
             assert lab in self.ALL_LABELS
 
@@ -157,8 +180,8 @@ def __len__(self) -> int:
 
     def _load_files(self) -> List[str]:
         """Loads file names."""
-        file_path = os.path.join(self.root, "uar")
-        files = os.listdir(file_path)
+        with open(os.path.join(self.root, f"{self.split}_split.txt")) as f:
+            files = f.read().splitlines()
         return files
 
     def _load_image(self, path: str) -> Tensor:
@@ -184,13 +207,15 @@ def _verify(self) -> None:
         # Check if the extracted files already exist
         pathname = os.path.join(self.root, "uar")
         csv_pathname = os.path.join(self.root, "*.csv")
+        split_pathname = os.path.join(self.root, "*_split.txt")
 
-        if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
+        csv_split_count = (len(glob.glob(csv_pathname)), len(glob.glob(split_pathname)))
+        if glob.glob(pathname) and csv_split_count == (7, 3):
             return
 
         # Check if the zip files have already been downloaded
         pathname = os.path.join(self.root, self.dirname + ".zip")
-        if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
+        if glob.glob(pathname) and csv_split_count == (7, 3):
             self._extract()
             return
 
@@ -212,6 +237,13 @@ def _download(self) -> None:
         download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None)
 
+        for metadata in self.split_metadata.values():
+            download_url(
+                metadata["url"],
+                self.root,
+                md5=metadata["md5"] if self.checksum else None,
+            )
+
     def _extract(self) -> None:
         """Extract the dataset."""
         extract_archive(os.path.join(self.root, self.dirname + ".zip"))
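The refactored dataset can also be used directly; a short sketch under the same assumptions (placeholder "usavars_data" root, this patch applied). Each split now reads its file list from the downloaded <split>_split.txt rather than listing the uar/ directory:

    from torchgeo.datasets import USAVars

    # First use downloads the image archive, the 7 label CSVs, and the 3 split files.
    train_ds = USAVars("usavars_data", split="train", download=True, checksum=True)
    val_ds = USAVars("usavars_data", split="val", labels=["elevation", "population"])

    sample = train_ds[0]  # {"image": Tensor[4, H, W], "labels": Tensor}
    print(len(train_ds), len(val_ds), sample["image"].shape)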