USAVars: implementing DataModule #441

Merged (25 commits, Jun 27, 2022)
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -99,6 +99,11 @@ UC Merced

.. autoclass:: UCMercedDataModule

USAVars
^^^^^^^

.. autoclass:: USAVarsDataModule

Vaihingen
^^^^^^^^^

16 changes: 16 additions & 0 deletions tests/data/usavars/data.py
@@ -22,6 +22,8 @@
"housing",
"roads",
]
splits = ["train", "val", "test"]

SIZE = 3


@@ -47,9 +49,12 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
# Remove old data
filename = f"{data_dir}.zip"
csvs = glob.glob("*.csv")
txts = glob.glob("*.txt")

for csv in csvs:
os.remove(csv)
for txt in txts:
os.remove(txt)
if os.path.exists(filename):
os.remove(filename)
if os.path.exists(data_dir):
@@ -67,6 +72,17 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
df = pd.DataFrame(fake_vals, columns=cols)
df.to_csv(lab + ".csv")

# Create splits:
with open("train_split.txt", "w") as f:
f.write("tile_0,0.tif" + "\n")
f.write("tile_0,0.tif" + "\n")
f.write("tile_0,0.tif" + "\n")
with open("val_split.txt", "w") as f:
f.write("tile_0,1.tif" + "\n")
f.write("tile_0,1.tif" + "\n")
with open("test_split.txt", "w") as f:
f.write("tile_0,0.tif" + "\n")

# Compress data
shutil.make_archive(data_dir, "zip", ".", data_dir)

1 change: 1 addition & 0 deletions tests/data/usavars/test_split.txt
@@ -0,0 +1 @@
tile_0,0.tif
3 changes: 3 additions & 0 deletions tests/data/usavars/train_split.txt
@@ -0,0 +1,3 @@
tile_0,0.tif
tile_0,0.tif
tile_0,0.tif
2 changes: 2 additions & 0 deletions tests/data/usavars/val_split.txt
@@ -0,0 +1,2 @@
tile_0,1.tif
tile_0,1.tif
37 changes: 37 additions & 0 deletions tests/datamodules/test_usavars.py
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import USAVarsDataModule


class TestUSAVarsDataModule:
@pytest.fixture()
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
root = os.path.join("tests", "data", "usavars")
batch_size = 1
num_workers = 0

dm = USAVarsDataModule(root, batch_size=batch_size, num_workers=num_workers)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.train_dataloader()) == 3
sample = next(iter(datamodule.train_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size

def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.val_dataloader()) == 2
sample = next(iter(datamodule.val_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size

def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.test_dataloader()) == 1
sample = next(iter(datamodule.test_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size
48 changes: 44 additions & 4 deletions tests/datasets/test_usavars.py
@@ -26,7 +26,16 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:


class TestUSAVars:
@pytest.fixture()
@pytest.fixture(
params=zip(
["train", "val", "test"],
[
["elevation", "population", "treecover"],
["elevation", "population"],
["treecover"],
],
)
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> USAVars:
@@ -50,21 +59,49 @@ def dataset(
}
monkeypatch.setattr(USAVars, "label_urls", label_urls)

split_metadata = {
"train": {
"url": os.path.join("tests", "data", "usavars", "train_split.txt"),
"filename": "train_split.txt",
"md5": "b94f3f6f63110b253779b65bc31d91b5",
},
"val": {
"url": os.path.join("tests", "data", "usavars", "val_split.txt"),
"filename": "val_split.txt",
"md5": "e39aa54b646c4c45921fcc9765d5a708",
},
"test": {
"url": os.path.join("tests", "data", "usavars", "test_split.txt"),
"filename": "test_split.txt",
"md5": "4ab0f5549fee944a5690de1bc95ed245",
},
}
monkeypatch.setattr(USAVars, "split_metadata", split_metadata)

root = str(tmp_path)
split, labels = request.param
transforms = nn.Identity() # type: ignore[no-untyped-call]

return USAVars(root, transforms=transforms, download=True, checksum=True)
return USAVars(
root, split, labels, transforms=transforms, download=True, checksum=True
)

def test_getitem(self, dataset: USAVars) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].ndim == 3
assert len(x.keys()) == 2 # image, elevation, population, treecover
assert len(x.keys()) == 2 # image, labels
assert x["image"].shape[0] == 4 # R, G, B, Inf
assert len(dataset.labels) == len(x["labels"])

def test_len(self, dataset: USAVars) -> None:
assert len(dataset) == 2
if dataset.split == "train":
assert len(dataset) == 3
elif dataset.split == "val":
assert len(dataset) == 2
else:
assert len(dataset) == 1

def test_add(self, dataset: USAVars) -> None:
ds = dataset + dataset
@@ -88,6 +125,9 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
]
for csv in csvs:
shutil.copy(os.path.join("tests", "data", "usavars", csv), root)
splits = ["train_split.txt", "val_split.txt", "test_split.txt"]
for split in splits:
shutil.copy(os.path.join("tests", "data", "usavars", split), root)

USAVars(root)

2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -21,6 +21,7 @@
from .sen12ms import SEN12MSDataModule
from .so2sat import So2SatDataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .vaihingen import Vaihingen2DDataModule
from .xview import XView2DataModule

@@ -45,6 +46,7 @@
"So2SatDataModule",
"CycloneDataModule",
"UCMercedDataModule",
"USAVarsDataModule",
"Vaihingen2DDataModule",
"XView2DataModule",
)
95 changes: 95 additions & 0 deletions torchgeo/datamodules/usavars.py
@@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""USAVars datamodule."""

from typing import Any, Callable, Dict, Optional, Sequence

import pytorch_lightning as pl
from torch import Tensor
from torch.utils.data import DataLoader

from ..datasets import USAVars


class USAVarsDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the USAVars dataset.
Uses random train/val/test splits.
.. versionadded:: 0.3
"""

def __init__(
self,
root_dir: str,
labels: Sequence[str] = USAVars.ALL_LABELS,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
batch_size: int = 64,
num_workers: int = 0,
) -> None:
"""Initialize a LightningDataModule for USAVars based DataLoaders.
Args:
root_dir: The root argument passed to the USAVars Dataset classes
labels: The labels argument passed to the USAVars Dataset classes
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__()
self.root_dir = root_dir
self.labels = labels
self.transforms = transforms
self.batch_size = batch_size
self.num_workers = num_workers

def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
USAVars(self.root_dir, labels=self.labels, checksum=False)

def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main Dataset objects.
This method is called once per GPU per run.
"""
self.train_dataset = USAVars(
self.root_dir, "train", self.labels, transforms=self.transforms
)
self.val_dataset = USAVars(
self.root_dir, "val", self.labels, transforms=self.transforms
)
self.test_dataset = USAVars(
self.root_dir, "test", self.labels, transforms=self.transforms
)

def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)

def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)

def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
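
For orientation, a minimal usage sketch of the new datamodule, assuming the USAVars files are already present locally (the root path and batch size below are placeholders, not part of this PR):

from torchgeo.datamodules import USAVarsDataModule

# Hypothetical local root; as written in this PR, prepare_data() only
# instantiates the dataset to verify the files rather than downloading them.
dm = USAVarsDataModule("data/usavars", labels=["treecover"], batch_size=16)
dm.prepare_data()
dm.setup()

batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape)   # (16, 4, H, W): R, G, B, near-infrared
print(batch["labels"].shape)  # (16, 1): one regression target per requested label

Note that all three dataloaders are constructed with shuffle=False, so every epoch iterates the split files in a fixed order.
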
40 changes: 36 additions & 4 deletions torchgeo/datasets/usavars.py
@@ -71,11 +71,30 @@ class USAVars(VisionDataset):
+ f"outcomes_sampled_treecover_{uar_csv_suffix}",
}

split_metadata = {
"train": {
"url": "https://mosaiks.blob.core.windows.net/datasets/train_split.txt",
"filename": "train_split.txt",
"md5": "3f58fffbf5fe177611112550297200e7",
},
"val": {
"url": "https://mosaiks.blob.core.windows.net/datasets/val_split.txt",
"filename": "val_split.txt",
"md5": "bca7183b132b919dec0fc24fb11662a0",
},
"test": {
"url": "https://mosaiks.blob.core.windows.net/datasets/test_split.txt",
"filename": "test_split.txt",
"md5": "97bb36bc003ae0bf556a8d6e8f77141a",
},
}

ALL_LABELS = list(label_urls.keys())

def __init__(
self,
root: str = "data",
split: str = "train",
labels: Sequence[str] = ALL_LABELS,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
@@ -85,6 +104,7 @@ def __init__(
Args:
root: root directory where dataset can be found
split: train/val/test split to load
labels: list of labels to include
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
@@ -99,6 +119,9 @@
"""
self.root = root

assert split in self.split_metadata
self.split = split

for lab in labels:
assert lab in self.ALL_LABELS

@@ -157,8 +180,8 @@ def __len__(self) -> int:

def _load_files(self) -> List[str]:
"""Loads file names."""
file_path = os.path.join(self.root, "uar")
files = os.listdir(file_path)
with open(os.path.join(self.root, f"{self.split}_split.txt")) as f:
files = f.read().splitlines()
return files

def _load_image(self, path: str) -> Tensor:
@@ -184,13 +207,15 @@ def _verify(self) -> None:
# Check if the extracted files already exist
pathname = os.path.join(self.root, "uar")
csv_pathname = os.path.join(self.root, "*.csv")
split_pathname = os.path.join(self.root, "*_split.txt")

if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
csv_split_count = (len(glob.glob(csv_pathname)), len(glob.glob(split_pathname)))
if glob.glob(pathname) and csv_split_count == (7, 3):
return

# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.dirname + ".zip")
if glob.glob(pathname) and len(glob.glob(csv_pathname)) == 7:
if glob.glob(pathname) and csv_split_count == (7, 3):
self._extract()
return

@@ -212,6 +237,13 @@ def _download(self) -> None:

download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None)

for metadata in self.split_metadata.values():
download_url(
metadata["url"],
self.root,
md5=metadata["md5"] if self.checksum else None,
)

def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.dirname + ".zip"))
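
And a corresponding sketch for the dataset itself with the new split and labels arguments (the root path is a placeholder; download=True assumes network access to the MOSAIKS blob storage):

from torchgeo.datasets import USAVars

# Placeholder root; with download=True the imagery archive, the label CSVs,
# and the new *_split.txt files are all fetched on first use.
ds = USAVars(
    root="data/usavars",
    split="val",
    labels=["elevation", "population"],
    download=True,
    checksum=True,
)

sample = ds[0]
print(sample["image"].shape)  # (4, H, W): R, G, B, near-infrared
print(sample["labels"])       # tensor with one value per requested label
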