USAVars: implementing DataModule (microsoft#441)
* USAVars: implementing DataModule

* Adding initial version

* add to __init__

* changes

* add transforms argument

* black, isort fix

* fixed shuffle option

* update docs

* fix formatting

* initial split method

* fix formatting

* testing for datamodule

* this is simpler

* testing seed

* fix isort + test seed

* refactor dataset for splits

* black fix

* adding splits to fake data

* change test splits

* working tests locally

* fix black

* fix black

* adapt module to dataset refactor

* complete docstring

* Style fixes

Co-authored-by: Adam J. Stewart <[email protected]>
iejMac and adamjstewart authored Jun 27, 2022
1 parent 2c1477a commit 8d89b70
Showing 10 changed files with 241 additions and 8 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -99,6 +99,11 @@ UC Merced

.. autoclass:: UCMercedDataModule

USAVars
^^^^^^^

.. autoclass:: USAVarsDataModule

Vaihingen
^^^^^^^^^

16 changes: 16 additions & 0 deletions tests/data/usavars/data.py
@@ -22,6 +22,8 @@
"housing",
"roads",
]
splits = ["train", "val", "test"]

SIZE = 3


@@ -47,9 +49,12 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
# Remove old data
filename = f"{data_dir}.zip"
csvs = glob.glob("*.csv")
txts = glob.glob("*.txt")

for csv in csvs:
os.remove(csv)
for txt in txts:
os.remove(txt)
if os.path.exists(filename):
os.remove(filename)
if os.path.exists(data_dir):
@@ -67,6 +72,17 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
df = pd.DataFrame(fake_vals, columns=cols)
df.to_csv(lab + ".csv")

# Create splits:
with open("train_split.txt", "w") as f:
f.write("tile_0,0.tif" + "\n")
f.write("tile_0,0.tif" + "\n")
f.write("tile_0,0.tif" + "\n")
with open("val_split.txt", "w") as f:
f.write("tile_0,1.tif" + "\n")
f.write("tile_0,1.tif" + "\n")
with open("test_split.txt", "w") as f:
f.write("tile_0,0.tif" + "\n")

# Compress data
shutil.make_archive(data_dir, "zip", ".", data_dir)
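
The md5 values monkeypatched in tests/datasets/test_usavars.py below presumably correspond to the split files generated here. A minimal sketch, assuming the script has just been run from tests/data/usavars, for recomputing those checksums:

import hashlib

# Hypothetical helper for refreshing the checksums used in the tests;
# run from tests/data/usavars after regenerating the fake data.
for name in ["train_split.txt", "val_split.txt", "test_split.txt"]:
    with open(name, "rb") as f:
        print(name, hashlib.md5(f.read()).hexdigest())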

1 change: 1 addition & 0 deletions tests/data/usavars/test_split.txt
@@ -0,0 +1 @@
tile_0,0.tif
3 changes: 3 additions & 0 deletions tests/data/usavars/train_split.txt
@@ -0,0 +1,3 @@
tile_0,0.tif
tile_0,0.tif
tile_0,0.tif
2 changes: 2 additions & 0 deletions tests/data/usavars/val_split.txt
@@ -0,0 +1,2 @@
tile_0,1.tif
tile_0,1.tif
37 changes: 37 additions & 0 deletions tests/datamodules/test_usavars.py
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import USAVarsDataModule


class TestUSAVarsDataModule:
@pytest.fixture()
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
root = os.path.join("tests", "data", "usavars")
batch_size = 1
num_workers = 0

dm = USAVarsDataModule(root, batch_size=batch_size, num_workers=num_workers)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.train_dataloader()) == 3
sample = next(iter(datamodule.train_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size

def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.val_dataloader()) == 2
sample = next(iter(datamodule.val_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size

def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.test_dataloader()) == 1
sample = next(iter(datamodule.test_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size
48 changes: 44 additions & 4 deletions tests/datasets/test_usavars.py
@@ -26,7 +26,16 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:


class TestUSAVars:
@pytest.fixture()
@pytest.fixture(
params=zip(
["train", "val", "test"],
[
["elevation", "population", "treecover"],
["elevation", "population"],
["treecover"],
],
)
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> USAVars:
@@ -50,21 +59,49 @@ def dataset
}
monkeypatch.setattr(USAVars, "label_urls", label_urls)

split_metadata = {
"train": {
"url": os.path.join("tests", "data", "usavars", "train_split.txt"),
"filename": "train_split.txt",
"md5": "b94f3f6f63110b253779b65bc31d91b5",
},
"val": {
"url": os.path.join("tests", "data", "usavars", "val_split.txt"),
"filename": "val_split.txt",
"md5": "e39aa54b646c4c45921fcc9765d5a708",
},
"test": {
"url": os.path.join("tests", "data", "usavars", "test_split.txt"),
"filename": "test_split.txt",
"md5": "4ab0f5549fee944a5690de1bc95ed245",
},
}
monkeypatch.setattr(USAVars, "split_metadata", split_metadata)

root = str(tmp_path)
split, labels = request.param
transforms = nn.Identity() # type: ignore[no-untyped-call]

return USAVars(root, transforms=transforms, download=True, checksum=True)
return USAVars(
root, split, labels, transforms=transforms, download=True, checksum=True
)

def test_getitem(self, dataset: USAVars) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].ndim == 3
assert len(x.keys()) == 2 # image, elevation, population, treecover
assert len(x.keys()) == 2 # image, labels
assert x["image"].shape[0] == 4 # R, G, B, Inf
assert len(dataset.labels) == len(x["labels"])

def test_len(self, dataset: USAVars) -> None:
assert len(dataset) == 2
if dataset.split == "train":
assert len(dataset) == 3
elif dataset.split == "val":
assert len(dataset) == 2
else:
assert len(dataset) == 1

def test_add(self, dataset: USAVars) -> None:
ds = dataset + dataset
@@ -88,6 +125,9 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
]
for csv in csvs:
shutil.copy(os.path.join("tests", "data", "usavars", csv), root)
splits = ["train_split.txt", "val_split.txt", "test_split.txt"]
for split in splits:
shutil.copy(os.path.join("tests", "data", "usavars", split), root)

USAVars(root)

2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -21,6 +21,7 @@
from .sen12ms import SEN12MSDataModule
from .so2sat import So2SatDataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .vaihingen import Vaihingen2DDataModule
from .xview import XView2DataModule

@@ -45,6 +46,7 @@
"So2SatDataModule",
"CycloneDataModule",
"UCMercedDataModule",
"USAVarsDataModule",
"Vaihingen2DDataModule",
"XView2DataModule",
)
95 changes: 95 additions & 0 deletions torchgeo/datamodules/usavars.py
@@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""USAVars datamodule."""

from typing import Any, Callable, Dict, Optional, Sequence

import pytorch_lightning as pl
from torch import Tensor
from torch.utils.data import DataLoader

from ..datasets import USAVars


class USAVarsDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the USAVars dataset.

Uses the random train/val/test splits provided with the dataset.

.. versionadded:: 0.3
"""

def __init__(
self,
root_dir: str,
labels: Sequence[str] = USAVars.ALL_LABELS,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
batch_size: int = 64,
num_workers: int = 0,
) -> None:
"""Initialize a LightningDataModule for USAVars based DataLoaders.

Args:
root_dir: The root argument passed to the USAVars Dataset classes
labels: The labels argument passed to the USAVars Dataset classes
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__()
self.root_dir = root_dir
self.labels = labels
self.transforms = transforms
self.batch_size = batch_size
self.num_workers = num_workers

def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.

This method is only called once per run.
"""
USAVars(self.root_dir, labels=self.labels, checksum=False)

def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main Dataset objects.

This method is called once per GPU per run.
"""
self.train_dataset = USAVars(
self.root_dir, "train", self.labels, transforms=self.transforms
)
self.val_dataset = USAVars(
self.root_dir, "val", self.labels, transforms=self.transforms
)
self.test_dataset = USAVars(
self.root_dir, "test", self.labels, transforms=self.transforms
)

def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)

def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)

def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
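
A minimal usage sketch for the new datamodule (not part of the diff): the root path and label choice are illustrative, and it assumes the dataset has already been downloaded to that directory, since prepare_data does not pass download=True.

from torchgeo.datamodules import USAVarsDataModule

dm = USAVarsDataModule(root_dir="data/usavars", labels=["treecover"], batch_size=32)
dm.prepare_data()  # instantiates USAVars once to verify the data is present
dm.setup()         # builds the train/val/test USAVars datasets

batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape)   # (batch_size, 4, H, W): R, G, B + infrared
print(batch["labels"].shape)  # (batch_size, number of requested labels)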
40 changes: 36 additions & 4 deletions torchgeo/datasets/usavars.py
@@ -71,11 +71,30 @@ class USAVars(VisionDataset):
+ f"outcomes_sampled_treecover_{uar_csv_suffix}",
}

split_metadata = {
"train": {
"url": "https://mosaiks.blob.core.windows.net/datasets/train_split.txt",
"filename": "train_split.txt",
"md5": "3f58fffbf5fe177611112550297200e7",
},
"val": {
"url": "https://mosaiks.blob.core.windows.net/datasets/val_split.txt",
"filename": "val_split.txt",
"md5": "bca7183b132b919dec0fc24fb11662a0",
},
"test": {
"url": "https://mosaiks.blob.core.windows.net/datasets/test_split.txt",
"filename": "test_split.txt",
"md5": "97bb36bc003ae0bf556a8d6e8f77141a",
},
}

ALL_LABELS = list(label_urls.keys())

def __init__(
self,
root: str = "data",
split: str = "train",
labels: Sequence[str] = ALL_LABELS,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
@@ -85,6 +104,7 @@ def __init__(
Args:
root: root directory where dataset can be found
split: train/val/test split to load
labels: list of labels to include
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
@@ -99,6 +119,9 @@
"""
self.root = root

assert split in self.split_metadata
self.split = split

for lab in labels:
assert lab in self.ALL_LABELS

@@ -157,8 +180,8 @@ def __len__(self) -> int:

def _load_files(self) -> List[str]:
"""Loads file names."""
file_path = os.path.join(self.root, "uar")
files = os.listdir(file_path)
with open(os.path.join(self.root, f"{self.split}_split.txt")) as f:
files = f.read().splitlines()
return files

def _load_image(self, path: str) -> Tensor:
@@ -184,13 +207,15 @@ def _verify(self) -> None:
# Check if the extracted files already exist
pathname = os.path.join(self.root, "uar")
csv_pathname = os.path.join(self.root, "*.csv")
split_pathname = os.path.join(self.root, "*_split.txt")

if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
csv_split_count = (len(glob.glob(csv_pathname)), len(glob.glob(split_pathname)))
if glob.glob(pathname) and csv_split_count == (7, 3):
return

# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.dirname + ".zip")
if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
if glob.glob(pathname) and csv_split_count == (7, 3):
self._extract()
return

@@ -212,6 +237,13 @@ def _download(self) -> None:

download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None)

for metadata in self.split_metadata.values():
download_url(
metadata["url"],
self.root,
md5=metadata["md5"] if self.checksum else None,
)

def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.dirname + ".zip"))
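
And a short hedged sketch of using the refactored dataset directly with the new split and labels arguments (root path and label selection are illustrative):

from torchgeo.datasets import USAVars

# download=True fetches the imagery archive, the label CSVs, and the three
# split files; split selects which *_split.txt file drives _load_files.
ds = USAVars(
    root="data/usavars",
    split="val",
    labels=["elevation", "population"],
    download=True,
    checksum=True,
)

sample = ds[0]
print(sample["image"].shape)  # (4, H, W): R, G, B + infrared
print(sample["labels"])       # one value per requested label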
