diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 4edeb0c37a9..1ff0044c30b 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -29,6 +29,11 @@ COWC .. autoclass:: COWCCountingDataModule +Deep Globe Land Cover Challenge +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: DeepGlobeLandCoverDataModule + ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index f5e171313c9..41a4960e354 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -164,6 +164,11 @@ Kenya Crop Type .. autoclass:: CV4AKenyaCropType +Deep Globe Land Cover Challenge +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: DeepGlobeLandCover + DFC2022 ^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 8e11d7a23b0..7fca8bebec7 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -4,6 +4,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `BigEarthNet`_,C,Sentinel-1/2,"590,326",19--43,120x120,10,"SAR, MSI" `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB `Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI +`Deep Globe Land Cover Challenge`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,,15,"2,000x2,000",0.5,RGB `ETCI2021 Flood Detection`_,S,Sentinel-1,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"27,000",10,64x64,10,MSI diff --git a/tests/conf/deepglobelandcover_0.yaml b/tests/conf/deepglobelandcover_0.yaml new file mode 100644 index 00000000000..7a696c7615f --- /dev/null +++ b/tests/conf/deepglobelandcover_0.yaml @@ -0,0 +1,19 @@ +experiment: + task: "deepglobelandcover" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 7 + num_filters: 1 + ignore_index: null + datamodule: + root_dir: "tests/data/deepglobelandcover" + val_split_pct: 0.0 + batch_size: 1 + num_workers: 0 diff --git a/tests/conf/deepglobelandcover_5.yaml b/tests/conf/deepglobelandcover_5.yaml new file mode 100644 index 00000000000..18499deebec --- /dev/null +++ b/tests/conf/deepglobelandcover_5.yaml @@ -0,0 +1,19 @@ +experiment: + task: "deepglobelandcover" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 7 + num_filters: 1 + ignore_index: null + datamodule: + root_dir: "tests/data/deepglobelandcover" + val_split_pct: 0.5 + batch_size: 1 + num_workers: 0 diff --git a/tests/data/deepglobelandcover/data.py b/tests/data/deepglobelandcover/data.py new file mode 100644 index 00000000000..1c8778bf8d4 --- /dev/null +++ b/tests/data/deepglobelandcover/data.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil + +import numpy as np +from PIL import Image +from torchvision.datasets.utils import calculate_md5 + + +def generate_test_data(root: str, n_samples: int = 3) -> str: + """Create test data archive for DeepGlobeLandCover dataset. + + Args: + root: path to store test data + n_samples: number of samples. + + Returns: + md5 hash of created archive + """ + dtype = np.uint8 + size = 2 + + folder_path = os.path.join(root, "data") + + train_img_dir = os.path.join(folder_path, "data", "training_data", "images") + train_mask_dir = os.path.join(folder_path, "data", "training_data", "masks") + test_img_dir = os.path.join(folder_path, "data", "test_data", "images") + test_mask_dir = os.path.join(folder_path, "data", "test_data", "masks") + + os.makedirs(train_img_dir, exist_ok=True) + os.makedirs(train_mask_dir, exist_ok=True) + os.makedirs(test_img_dir, exist_ok=True) + os.makedirs(test_mask_dir, exist_ok=True) + + train_ids = [1, 2, 3] + test_ids = [8, 9, 10] + + for i in range(n_samples): + train_id = train_ids[i] + test_id = test_ids[i] + + dtype_max = np.iinfo(dtype).max + train_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype) + train_img = Image.fromarray(train_arr) + train_img.save(os.path.join(train_img_dir, str(train_id) + "_sat.jpg")) + + test_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype) + test_img = Image.fromarray(test_arr) + test_img.save(os.path.join(test_img_dir, str(test_id) + "_sat.jpg")) + + train_mask_arr = np.full((size, size, 3), (0, 255, 255), dtype=dtype) + train_mask_img = Image.fromarray(train_mask_arr) + train_mask_img.save(os.path.join(train_mask_dir, str(train_id) + "_mask.png")) + + test_mask_arr = np.full((size, size, 3), (255, 0, 255), dtype=dtype) + test_mask_img = Image.fromarray(test_mask_arr) + test_mask_img.save(os.path.join(test_mask_dir, str(test_id) + "_mask.png")) + + # Create archive + shutil.make_archive(folder_path, "zip", folder_path) + shutil.rmtree(folder_path) + return calculate_md5(f"{folder_path}.zip") + + +if __name__ == "__main__": + md5_hash = generate_test_data(os.getcwd(), 3) + print(md5_hash + "\n") diff --git a/tests/data/deepglobelandcover/data.zip b/tests/data/deepglobelandcover/data.zip new file mode 100644 index 00000000000..10c7bcbf308 Binary files /dev/null and b/tests/data/deepglobelandcover/data.zip differ diff --git a/tests/datasets/test_deepglobelandcover.py b/tests/datasets/test_deepglobelandcover.py new file mode 100644 index 00000000000..e6f08bbf4e6 --- /dev/null +++ b/tests/datasets/test_deepglobelandcover.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch + +from torchgeo.datasets import DeepGlobeLandCover + + +class TestDeepGlobeLandCover: + @pytest.fixture(params=["train", "test"]) + def dataset( + self, monkeypatch: MonkeyPatch, request: SubRequest + ) -> DeepGlobeLandCover: + md5 = "2cbd68d36b1485f09f32d874dde7c5c5" + monkeypatch.setattr(DeepGlobeLandCover, "md5", md5) + root = os.path.join("tests", "data", "deepglobelandcover") + split = request.param + transforms = nn.Identity() + return DeepGlobeLandCover(root, split, transforms, checksum=True) + + def test_getitem(self, dataset: DeepGlobeLandCover) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + + def test_len(self, dataset: DeepGlobeLandCover) -> None: + assert len(dataset) == 3 + + def test_extract(self, tmp_path: Path) -> None: + root = os.path.join("tests", "data", "deepglobelandcover") + filename = "data.zip" + shutil.copyfile( + os.path.join(root, filename), os.path.join(str(tmp_path), filename) + ) + DeepGlobeLandCover(root=str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "data.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + DeepGlobeLandCover(root=str(tmp_path), checksum=True) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + DeepGlobeLandCover(split="foo") + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises( + RuntimeError, + match="Dataset not found in `root`, either" + + " specify a different `root` directory or manually download" + + " the dataset to this directory.", + ): + DeepGlobeLandCover(str(tmp_path)) + + def test_plot(self, dataset: DeepGlobeLandCover) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"].clone() + dataset.plot(x) + plt.close() diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index bfd61291008..c399293345f 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -13,6 +13,7 @@ from torchgeo.datamodules import ( ChesapeakeCVPRDataModule, + DeepGlobeLandCoverDataModule, ETCI2021DataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, @@ -34,6 +35,8 @@ class TestSemanticSegmentationTask: "name,classname", [ ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule), + ("deepglobelandcover_0", DeepGlobeLandCoverDataModule), + ("deepglobelandcover_5", DeepGlobeLandCoverDataModule), ("etci2021", ETCI2021DataModule), ("inria", InriaAerialImageLabelingDataModule), ("landcoverai", LandCoverAIDataModule), diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 02cb2b83e14..2d4ef64c3b5 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -7,6 +7,7 @@ from .chesapeake import ChesapeakeCVPRDataModule from .cowc import COWCCountingDataModule from .cyclone import CycloneDataModule +from .deepglobelandcover import DeepGlobeLandCoverDataModule from .etci2021 import ETCI2021DataModule from .eurosat import EuroSATDataModule from .fair1m import FAIR1MDataModule @@ -32,6 +33,7 @@ # VisionDataset "BigEarthNetDataModule", "COWCCountingDataModule", + "DeepGlobeLandCoverDataModule", "ETCI2021DataModule", "EuroSATDataModule", "FAIR1MDataModule", diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py new file mode 100644 index 00000000000..632add64a92 --- /dev/null +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""DeepGlobe Land Cover Classification Challenge datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import DeepGlobeLandCover +from .utils import dataset_split + + +class DeepGlobeLandCoverDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the DeepGlobe Land Cover dataset. + + Uses the train/test splits from the dataset. + + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for DeepGlobe Land Cover based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the DeepGlobe Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = DeepGlobeLandCover(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = DeepGlobeLandCover( + self.root_dir, "test", transforms=transforms + ) + + def train_dataloader(self) -> DataLoader[Dict[str, Any]]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Dict[str, Any]]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Dict[str, Any]]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index d2b0ff1858e..e7bbb25b9b8 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -27,6 +27,7 @@ from .cowc import COWC, COWCCounting, COWCDetection from .cv4a_kenya_crop_type import CV4AKenyaCropType from .cyclone import TropicalCycloneWindEstimation +from .deepglobelandcover import DeepGlobeLandCover from .dfc2022 import DFC2022 from .eddmaps import EDDMapS from .enviroatlas import EnviroAtlas @@ -148,6 +149,7 @@ "COWCCounting", "COWCDetection", "CV4AKenyaCropType", + "DeepGlobeLandCover", "DFC2022", "EnviroAtlas", "ETCI2021", diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py new file mode 100644 index 00000000000..645bef5049c --- /dev/null +++ b/torchgeo/datasets/deepglobelandcover.py @@ -0,0 +1,270 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""DeepGlobe Land Cover Classification Challenge dataset.""" + +import os +from typing import Callable, Dict, Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .geo import VisionDataset +from .utils import ( + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, + rgb_to_mask, +) + + +class DeepGlobeLandCover(VisionDataset): + """DeepGlobe Land Cover Classification Challenge dataset. + + The `DeepGlobe Land Cover Classification Challenge + `__ dataset + offers high-resolution sub-meter satellite imagery focusing for the task of + semantic segmentation to detect areas of urban, agriculture, rangeland, forest, + water, barren, and unknown. It contains 1,146 satellite images of size + 2448 x 2448 pixels in total, split into training/validation/test sets, the original + dataset can be downloaded from `Kaggle `__. + However, we only use the training dataset with 803 images since the original test + and valid dataset are not accompanied by labels. The dataset that we use with a + custom train/test split can be downloaded from `Kaggle `__ (created as a + part of Computer Vision by Deep Learning (CS4245) course offered at TU Delft). + + Dataset format: + + * images are RGB data + * masks are RGB image with with unique RGB values representing the class + + Dataset classes: + + 0. Urban land + 1. Agriculture land + 2. Rangeland + 3. Forest land + 4. Water + 5. Barren land + 6. Unknown + + File names for satellite images and the corresponding mask image are id_sat.jpg and + id_mask.png, where id is an integer assigned to every image. + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/pdf/1805.06561.pdf + + .. versionadded:: 0.3 + """ + + filename = "data.zip" + data_root = "data" + md5 = "f32684b0b2bf6f8d604cd359a399c061" + splits = ["train", "test"] + classes = [ + "Urban land", + "Agriculture land", + "Rangeland", + "Forest land", + "Water", + "Barren land", + "Unknown", + ] + colormap = [ + (0, 255, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 0), + (0, 0, 255), + (255, 255, 255), + (0, 0, 0), + ] + + def __init__( + self, + root: str = "data", + split: str = "train", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + checksum: bool = False, + ) -> None: + """Initialize a new DeepGlobeLandCover dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train" or "test" + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + """ + assert split in self.splits + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + if split == "train": + split_folder = "training_data" + else: + split_folder = "test_data" + + self.image_fns = [] + self.mask_fns = [] + for image in sorted( + os.listdir(os.path.join(root, self.data_root, split_folder, "images")) + ): + if image.endswith(".jpg"): + id = image[:-8] + image_path = os.path.join( + root, self.data_root, split_folder, "images", image + ) + mask_path = os.path.join( + root, self.data_root, split_folder, "masks", str(id) + "_mask.png" + ) + + self.image_fns.append(image_path) + self.mask_fns.append(mask_path) + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + image = self._load_image(index) + mask = self._load_target(index) + sample = {"image": image, "mask": mask} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.image_fns) + + def _load_image(self, index: int) -> Tensor: + """Load a single image. + + Args: + index: index to return + + Returns: + the image + """ + path = self.image_fns[index] + + with Image.open(path) as img: + array: "np.typing.NDArray[np.int_]" = np.array(img) + tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target(self, index: int) -> Tensor: + """Load the target mask for a single image. + + Args: + index: index to return + + Returns: + the target mask + """ + path = self.mask_fns[index] + with Image.open(path) as img: + array: "np.typing.NDArray[np.uint8]" = np.array(img) + array = rgb_to_mask(array, self.colormap) + tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.to(torch.long) + return tensor + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if checksum fails or the dataset is not downloaded + """ + # Check if the files already exist + if os.path.exists(os.path.join(self.root, self.data_root)): + return + + # Check if .zip file already exists (if so extract) + filepath = os.path.join(self.root, self.filename) + + if os.path.isfile(filepath): + if self.checksum and not check_integrity(filepath, self.md5): + raise RuntimeError("Dataset found, but corrupted.") + extract_archive(filepath) + return + + # Check if the user requested to download the dataset + raise RuntimeError( + "Dataset not found in `root`, either specify a different" + + " `root` directory or manually download the dataset to this directory." + ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + alpha: float = 0.5, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + alpha: opacity with which to render predictions on top of the imagery + + Returns: + a matplotlib Figure with the rendered sample + """ + ncols = 1 + image1 = draw_semantic_segmentation_masks( + sample["image"], sample["mask"], alpha=alpha, colors=self.colormap + ) + if "prediction" in sample: + ncols += 1 + image2 = draw_semantic_segmentation_masks( + sample["image"], sample["prediction"], alpha=alpha, colors=self.colormap + ) + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) + if ncols > 1: + (ax0, ax1) = axs + else: + ax0 = axs + + ax0.imshow(image1) + ax0.axis("off") + if ncols > 1: + ax1.imshow(image2) + ax1.axis("off") + + if show_titles: + ax0.set_title("Ground Truth") + if ncols > 1: + ax1.set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 3981202d24a..be7ffd0d48b 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -642,8 +642,7 @@ def draw_semantic_segmentation_masks( a version of ``image`` overlayed with the colors given by ``mask`` and ``colors`` """ - classes = torch.unique(mask) - classes = classes[1:] + classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8)) class_masks = mask == classes[:, None, None] img = draw_segmentation_masks( image=image, masks=class_masks, alpha=alpha, colors=colors