Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DeepGlobe dataset for land cover #578

Merged
merged 36 commits into from
Jul 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c29b010
add class for Deep Globe Land Cover dataset
Jun 13, 2022
e6c5980
add Lightning data module implementation for deepglobe land cover
Jun 13, 2022
283645f
fix formatting errors
Jun 14, 2022
4f5ea5c
fix urls, formats and add link for paper
Jun 14, 2022
0a478a2
add tests for deepglobe dataset and datamodule
Jun 15, 2022
156af6c
fix a test case and a few more formatting error
Jun 15, 2022
48bf64b
add data.py and modify error match for data download
Jun 16, 2022
502d14e
modify draw_semantic_segmentation_masks for cases when mask is a subs…
Jun 16, 2022
b0b95e5
fix mypy error
Jun 17, 2022
20d346d
add to docs for documentation
Jun 27, 2022
5165830
add deepglobe to the dataset lists csv
Jun 27, 2022
68705d2
fix error in building docs
Jun 27, 2022
1a3852d
Update datamodules.rst
calebrob6 Jun 27, 2022
abb3651
Update datasets.rst
calebrob6 Jun 27, 2022
6bc68e6
Update data.py
calebrob6 Jun 27, 2022
0b86e13
Update utils.py
calebrob6 Jun 29, 2022
d664ba9
change file permissions of non_geo_datasets.csv
Jun 29, 2022
3989c33
Add versionadded
calebrob6 Jul 2, 2022
210d375
Update torchgeo/datasets/deepglobelandcover.py
calebrob6 Jul 2, 2022
073ea9a
Change end of line sequence
calebrob6 Jul 2, 2022
0dca2e1
Update tests/data/deepglobelandcover/data.py
calebrob6 Jul 2, 2022
ff6f2f6
exist_ok
calebrob6 Jul 2, 2022
b1358b6
Update tests/datasets/test_deepglobelandcover.py
calebrob6 Jul 2, 2022
21721cc
Remove datamodule tests
calebrob6 Jul 2, 2022
e65fba7
Remove split monkeypatch
calebrob6 Jul 2, 2022
7554def
Merge branch 'add-DeepGlobe-dataset' of github.com:saumyasinha/torchg…
calebrob6 Jul 2, 2022
6aa6f32
Running black
calebrob6 Jul 2, 2022
4d0212f
Add val percent to test conf
calebrob6 Jul 2, 2022
09d4a23
Sort filelist so indices are the same across platforms
calebrob6 Jul 2, 2022
e49e006
Simplified the file and mask fns
calebrob6 Jul 2, 2022
c63afd8
Re-adding datamodule tests for coverage
calebrob6 Jul 2, 2022
6b0a74b
Add sub-configs to test val_split_pct in the datamodule
calebrob6 Jul 2, 2022
3ccf788
Lets try it
calebrob6 Jul 2, 2022
015aa0f
Update tests/conf/deepglobelandcover_0.yaml
calebrob6 Jul 2, 2022
3142743
nulllllllll
calebrob6 Jul 2, 2022
7024047
ingore_zeros -> ignore_index
adamjstewart Jul 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ COWC

.. autoclass:: COWCCountingDataModule

Deep Globe Land Cover Challenge
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DeepGlobeLandCoverDataModule

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ Kenya Crop Type

.. autoclass:: CV4AKenyaCropType

Deep Globe Land Cover Challenge
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DeepGlobeLandCover

DFC2022
^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`BigEarthNet`_,C,Sentinel-1/2,"590,326",19--43,120x120,10,"SAR, MSI"
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB
`Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI
`Deep Globe Land Cover Challenge`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
`DFC2022`_,S,Aerial,,15,"2,000x2,000",0.5,RGB
`ETCI2021 Flood Detection`_,S,Sentinel-1,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"27,000",10,64x64,10,MSI
Expand Down
19 changes: 19 additions & 0 deletions tests/conf/deepglobelandcover_0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
experiment:
task: "deepglobelandcover"
module:
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 7
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/deepglobelandcover"
val_split_pct: 0.0
batch_size: 1
num_workers: 0
19 changes: 19 additions & 0 deletions tests/conf/deepglobelandcover_5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
experiment:
task: "deepglobelandcover"
module:
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 7
num_filters: 1
ignore_index: null
datamodule:
root_dir: "tests/data/deepglobelandcover"
val_split_pct: 0.5
batch_size: 1
num_workers: 0
71 changes: 71 additions & 0 deletions tests/data/deepglobelandcover/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil

import numpy as np
from PIL import Image
from torchvision.datasets.utils import calculate_md5


def generate_test_data(root: str, n_samples: int = 3) -> str:
"""Create test data archive for DeepGlobeLandCover dataset.

Args:
root: path to store test data
n_samples: number of samples.

Returns:
md5 hash of created archive
"""
dtype = np.uint8
size = 2

folder_path = os.path.join(root, "data")

train_img_dir = os.path.join(folder_path, "data", "training_data", "images")
train_mask_dir = os.path.join(folder_path, "data", "training_data", "masks")
test_img_dir = os.path.join(folder_path, "data", "test_data", "images")
test_mask_dir = os.path.join(folder_path, "data", "test_data", "masks")

os.makedirs(train_img_dir, exist_ok=True)
os.makedirs(train_mask_dir, exist_ok=True)
os.makedirs(test_img_dir, exist_ok=True)
os.makedirs(test_mask_dir, exist_ok=True)

train_ids = [1, 2, 3]
test_ids = [8, 9, 10]

for i in range(n_samples):
train_id = train_ids[i]
test_id = test_ids[i]

dtype_max = np.iinfo(dtype).max
train_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype)
train_img = Image.fromarray(train_arr)
train_img.save(os.path.join(train_img_dir, str(train_id) + "_sat.jpg"))

test_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype)
test_img = Image.fromarray(test_arr)
test_img.save(os.path.join(test_img_dir, str(test_id) + "_sat.jpg"))

train_mask_arr = np.full((size, size, 3), (0, 255, 255), dtype=dtype)
train_mask_img = Image.fromarray(train_mask_arr)
train_mask_img.save(os.path.join(train_mask_dir, str(train_id) + "_mask.png"))

test_mask_arr = np.full((size, size, 3), (255, 0, 255), dtype=dtype)
test_mask_img = Image.fromarray(test_mask_arr)
test_mask_img.save(os.path.join(test_mask_dir, str(test_id) + "_mask.png"))

# Create archive
shutil.make_archive(folder_path, "zip", folder_path)
shutil.rmtree(folder_path)
return calculate_md5(f"{folder_path}.zip")


if __name__ == "__main__":
md5_hash = generate_test_data(os.getcwd(), 3)
print(md5_hash + "\n")
Binary file added tests/data/deepglobelandcover/data.zip
Binary file not shown.
74 changes: 74 additions & 0 deletions tests/datasets/test_deepglobelandcover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import DeepGlobeLandCover


class TestDeepGlobeLandCover:
@pytest.fixture(params=["train", "test"])
def dataset(
self, monkeypatch: MonkeyPatch, request: SubRequest
) -> DeepGlobeLandCover:
md5 = "2cbd68d36b1485f09f32d874dde7c5c5"
monkeypatch.setattr(DeepGlobeLandCover, "md5", md5)
root = os.path.join("tests", "data", "deepglobelandcover")
split = request.param
transforms = nn.Identity()
return DeepGlobeLandCover(root, split, transforms, checksum=True)

def test_getitem(self, dataset: DeepGlobeLandCover) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

def test_len(self, dataset: DeepGlobeLandCover) -> None:
assert len(dataset) == 3

def test_extract(self, tmp_path: Path) -> None:
root = os.path.join("tests", "data", "deepglobelandcover")
filename = "data.zip"
shutil.copyfile(
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
)
DeepGlobeLandCover(root=str(tmp_path))

def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "data.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
DeepGlobeLandCover(root=str(tmp_path), checksum=True)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
DeepGlobeLandCover(split="foo")

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(
RuntimeError,
match="Dataset not found in `root`, either"
+ " specify a different `root` directory or manually download"
+ " the dataset to this directory.",
):
DeepGlobeLandCover(str(tmp_path))

def test_plot(self, dataset: DeepGlobeLandCover) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()
3 changes: 3 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from torchgeo.datamodules import (
ChesapeakeCVPRDataModule,
DeepGlobeLandCoverDataModule,
ETCI2021DataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
Expand All @@ -34,6 +35,8 @@ class TestSemanticSegmentationTask:
"name,classname",
[
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
("deepglobelandcover_0", DeepGlobeLandCoverDataModule),
("deepglobelandcover_5", DeepGlobeLandCoverDataModule),
("etci2021", ETCI2021DataModule),
("inria", InriaAerialImageLabelingDataModule),
("landcoverai", LandCoverAIDataModule),
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .chesapeake import ChesapeakeCVPRDataModule
from .cowc import COWCCountingDataModule
from .cyclone import CycloneDataModule
from .deepglobelandcover import DeepGlobeLandCoverDataModule
from .etci2021 import ETCI2021DataModule
from .eurosat import EuroSATDataModule
from .fair1m import FAIR1MDataModule
Expand All @@ -32,6 +33,7 @@
# VisionDataset
"BigEarthNetDataModule",
"COWCCountingDataModule",
"DeepGlobeLandCoverDataModule",
"ETCI2021DataModule",
"EuroSATDataModule",
"FAIR1MDataModule",
Expand Down
122 changes: 122 additions & 0 deletions torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""DeepGlobe Land Cover Classification Challenge datamodule."""

from typing import Any, Dict, Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

from ..datasets import DeepGlobeLandCover
from .utils import dataset_split


class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the DeepGlobe Land Cover dataset.

Uses the train/test splits from the dataset.

"""

def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for DeepGlobe Land Cover based DataLoaders.

Args:
root_dir: The ``root`` argument to pass to the DeepGlobe Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct

def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.

Args:
sample: input image dictionary

Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample

def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.

This method is called once per GPU per run.

Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])

dataset = DeepGlobeLandCover(self.root_dir, "train", transforms=transforms)

self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset
self.val_dataset = dataset

self.test_dataset = DeepGlobeLandCover(
self.root_dir, "test", transforms=transforms
)

def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
"""Return a DataLoader for training.

Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)

def val_dataloader(self) -> DataLoader[Dict[str, Any]]:
"""Return a DataLoader for validation.

Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)

def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
"""Return a DataLoader for testing.

Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .deepglobelandcover import DeepGlobeLandCover
from .dfc2022 import DFC2022
from .eddmaps import EDDMapS
from .enviroatlas import EnviroAtlas
Expand Down Expand Up @@ -148,6 +149,7 @@
"COWCCounting",
"COWCDetection",
"CV4AKenyaCropType",
"DeepGlobeLandCover",
"DFC2022",
"EnviroAtlas",
"ETCI2021",
Expand Down
Loading