diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst
index f41463b9b76..ef7a91b937d 100644
--- a/docs/api/datamodules.rst
+++ b/docs/api/datamodules.rst
@@ -128,6 +128,11 @@ Potsdam
 
 .. autoclass:: Potsdam2DDataModule
 
+QuakeSet
+^^^^^^^^
+
+.. autoclass:: QuakeSetDataModule
+
 RESISC45
 ^^^^^^^^
 
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 341be4d4916..e3241b6315c 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -348,6 +348,11 @@ Potsdam
 
 .. autoclass:: Potsdam2D
 
+QuakeSet
+^^^^^^^^
+
+.. autoclass:: QuakeSet
+
 ReforesTree
 ^^^^^^^^^^^
 
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index a34b918b5ec..2dac9021daa 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI
 `PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB
 `Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI
+`QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR
 `ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB
 `RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB
 `Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR
diff --git a/tests/conf/quakeset.yaml b/tests/conf/quakeset.yaml
new file mode 100644
index 00000000000..9d54e1b6d4f
--- /dev/null
+++ b/tests/conf/quakeset.yaml
@@ -0,0 +1,14 @@
+model:
+  class_path: ClassificationTask
+  init_args:
+    loss: "ce"
+    model: "resnet18"
+    in_channels: 4
+    num_classes: 2
+data:
+  class_path: QuakeSetDataModule
+  init_args:
+    batch_size: 2
+  dict_kwargs:
+    root: "tests/data/quakeset"
+    download: false
diff --git a/tests/data/quakeset/data.py b/tests/data/quakeset/data.py
new file mode 100644
index 00000000000..3d6eb66938b
--- /dev/null
+++ b/tests/data/quakeset/data.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+
+import h5py
+import numpy as np
+
+NUM_CHANNELS = 2
+SIZE = 32
+
+np.random.seed(0)
+
+filename = "earthquakes.h5"
+
+splits = {
+    "train": ["611645479", "611658170"],
+    "validation": ["611684805", "611744956"],
+    "test": ["611798698", "611818836"],
+}
+
+# Remove old data
+if os.path.exists(filename):
+    os.remove(filename)
+
+# Create dataset file
+data = np.random.randn(SIZE, SIZE, NUM_CHANNELS)
+data = data.astype(np.float32)
+
+
+with h5py.File(filename, "w") as f:
+    for split, keys in splits.items():
+        for key in keys:
+            sample = f.create_group(key)
+            sample.attrs.create(name="magnitude", data=np.float32(0.0))
+            sample.attrs.create(name="split", data=split)
+            for i in range(2):
+                patch = sample.create_group(f"patch_{i}")
+                patch.create_dataset("before", data=data)
+                patch.create_dataset("pre", data=data)
+                patch.create_dataset("post", data=data)
+
+# Compute checksums
+with open(filename, "rb") as f:
+    md5 = hashlib.md5(f.read()).hexdigest()
+    print(f"md5: {md5}")
diff --git a/tests/data/quakeset/earthquakes.h5 b/tests/data/quakeset/earthquakes.h5
new file mode 100644
index 00000000000..71deb28b2df
Binary files /dev/null and b/tests/data/quakeset/earthquakes.h5 differ
diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py
new file mode 100644
index 00000000000..0c361aa27a9
--- /dev/null
+++ b/tests/datasets/test_quakeset.py
@@ -0,0 +1,88 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import builtins
+import os
+import shutil
+from pathlib import Path
+from typing import Any
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import DatasetNotFoundError, QuakeSet
+
+
+def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+    shutil.copy(url, root)
+
+
+class TestQuakeSet:
+    @pytest.fixture(params=["train", "val", "test"])
+    def dataset(
+        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+    ) -> QuakeSet:
+        monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url)
+        url = os.path.join("tests", "data", "quakeset", "earthquakes.h5")
+        md5 = "127d0d6a1f82d517129535f50053a4c9"
+        monkeypatch.setattr(QuakeSet, "md5", md5)
+        monkeypatch.setattr(QuakeSet, "url", url)
+        root = str(tmp_path)
+        split = request.param
+        transforms = nn.Identity()
+        return QuakeSet(
+            root, split, transforms=transforms, download=True, checksum=True
+        )
+
+    @pytest.fixture
+    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
+        import_orig = builtins.__import__
+
+        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
+            if name == "h5py":
+                raise ImportError()
+            return import_orig(name, *args, **kwargs)
+
+        monkeypatch.setattr(builtins, "__import__", mocked_import)
+
+    def test_mock_missing_module(
+        self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
+    ) -> None:
+        with pytest.raises(
+            ImportError,
+            match="h5py is not installed and is required to use this dataset",
+        ):
+            QuakeSet(dataset.root, download=True, checksum=True)
+
+    def test_getitem(self, dataset: QuakeSet) -> None:
+        x = dataset[0]
+        assert isinstance(x, dict)
+        assert isinstance(x["image"], torch.Tensor)
+        assert isinstance(x["label"], torch.Tensor)
+        assert x["image"].shape[0] == 4
+
+    def test_len(self, dataset: QuakeSet) -> None:
+        assert len(dataset) == 8
+
+    def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
+        QuakeSet(root=str(tmp_path), download=True)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+            QuakeSet(str(tmp_path))
+
+    def test_plot(self, dataset: QuakeSet) -> None:
+        x = dataset[0].copy()
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+        dataset.plot(x, show_titles=False)
+        plt.close()
+        x["prediction"] = x["label"].clone()
+        x["magnitude"] = torch.tensor(0.0)
+        dataset.plot(x)
+        plt.close()
diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index 02183978995..fb176a53071 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -76,6 +76,7 @@ class TestClassificationTask:
             "eurosat",
             "eurosat100",
             "fire_risk",
+            "quakeset",
             "resisc45",
             "so2sat_all",
             "so2sat_s1",
@@ -87,7 +88,7 @@
     def test_trainer(
         self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
-        if name.startswith("so2sat"):
+        if name.startswith("so2sat") or name == "quakeset":
             pytest.importorskip("h5py", minversion="3")
 
         config = os.path.join("tests", "conf", name + ".yaml")
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index 5ee3a47aaaa..7f0ee8a263d 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -26,6 +26,7 @@
 from .nasa_marine_debris import NASAMarineDebrisDataModule
 from .oscd import OSCDDataModule
 from .potsdam import Potsdam2DDataModule
+from .quakeset import QuakeSetDataModule
 from .resisc45 import RESISC45DataModule
 from .seco import SeasonalContrastS2DataModule
 from .sen12ms import SEN12MSDataModule
@@ -76,6 +77,7 @@
     "NASAMarineDebrisDataModule",
     "OSCDDataModule",
     "Potsdam2DDataModule",
+    "QuakeSetDataModule",
     "RESISC45DataModule",
     "SeasonalContrastS2DataModule",
     "SEN12MSDataModule",
diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py
new file mode 100644
index 00000000000..1963ba48ae2
--- /dev/null
+++ b/torchgeo/datamodules/quakeset.py
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""QuakeSet datamodule."""
+
+from typing import Any
+
+import kornia.augmentation as K
+import torch
+
+from ..datasets import QuakeSet
+from ..transforms import AugmentationSequential
+from .geo import NonGeoDataModule
+
+
+class QuakeSetDataModule(NonGeoDataModule):
+    """LightningDataModule implementation for the QuakeSet dataset.
+
+    .. versionadded:: 0.6
+    """
+
+    mean = torch.tensor(0.0)
+    std = torch.tensor(1.0)
+
+    def __init__(
+        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
+    ) -> None:
+        """Initialize a new QuakeSetDataModule instance.
+
+        Args:
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            **kwargs: Additional keyword arguments passed to
+                :class:`~torchgeo.datasets.QuakeSet`.
+        """
+        super().__init__(QuakeSet, batch_size, num_workers, **kwargs)
+        self.train_aug = AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std),
+            K.RandomHorizontalFlip(p=0.5),
+            K.RandomVerticalFlip(p=0.5),
+            data_keys=["image"],
+        )
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 5f3f974e2b2..739eeeaec27 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -90,6 +90,7 @@
 from .patternnet import PatternNet
 from .potsdam import Potsdam2D
 from .prisma import PRISMA
+from .quakeset import QuakeSet
 from .reforestree import ReforesTree
 from .resisc45 import RESISC45
 from .rwanda_field_boundary import RwandaFieldBoundary
@@ -226,6 +227,7 @@
     "PASTIS",
     "PatternNet",
     "Potsdam2D",
+    "QuakeSet",
     "RESISC45",
     "ReforesTree",
     "RwandaFieldBoundary",
diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py
new file mode 100644
index 00000000000..025b5f4987b
--- /dev/null
+++ b/torchgeo/datasets/quakeset.py
@@ -0,0 +1,290 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""QuakeSet dataset."""
+
+import os
+from collections.abc import Callable
+from typing import Any, cast
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import DatasetNotFoundError, download_url, percentile_normalization
+
+
+class QuakeSet(NonGeoDataset):
+    """QuakeSet dataset.
+
+    `QuakeSet <https://huggingface.co/datasets/DarthReca/quakeset>`__
+    is a dataset for Earthquake Change Detection and Magnitude Estimation and is used
+    for the Seismic Monitoring and Analysis (SMAC) ECML-PKDD 2024 Discovery Challenge.
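+
+    A minimal usage sketch (illustrative only; the ``root`` path below is a
+    placeholder, and ``download=True`` fetches ``earthquakes.h5`` on first use):
+
+    .. code-block:: python
+
+        from torchgeo.datasets import QuakeSet
+
+        dataset = QuakeSet(root="data/quakeset", split="train", download=True)
+        sample = dataset[0]
+        # sample["image"] stacks pre- and post-event VV/VH bands (4 channels);
+        # sample["label"] is 0 (unaffected) or 1 (affected by an earthquake)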
+
+    Dataset features:
+
+    * Sentinel-1 SAR imagery
+    * before/pre/post imagery of areas affected by earthquakes
+    * 2 SAR bands (VV/VH)
+    * 3,327 pairs of pre and post images with 10 m per pixel resolution (512x512 px)
+    * 2 classification labels (unaffected / affected by earthquake)
+    * pre/post image pairs represent earthquake affected areas
+    * before/pre image pairs represent hard negative unaffected areas
+    * earthquake magnitudes for each sample
+
+    Dataset format:
+
+    * single HDF5 file containing images, magnitudes, hypocenters, and splits
+
+    Dataset classes:
+
+    0. unaffected area
+    1. earthquake affected area
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://arxiv.org/abs/2403.18116
+
+    .. note::
+
+       This dataset requires the following additional library to be installed:
+
+       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
+
+    .. versionadded:: 0.6
+    """
+
+    filename = "earthquakes.h5"
+    url = "https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5"
+    md5 = "76fc7c76b7ca56f4844d852e175e1560"
+    splits = {"train": "train", "val": "validation", "test": "test"}
+    classes = ["unaffected_area", "earthquake_affected_area"]
+
+    def __init__(
+        self,
+        root: str = "data",
+        split: str = "train",
+        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new QuakeSet dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            split: one of "train", "val", or "test"
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: If ``split`` argument is invalid.
+            DatasetNotFoundError: If dataset is not found and *download* is False.
+            ImportError: If h5py is not installed.
+        """
+        assert split in self.splits
+
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.download = download
+        self.checksum = checksum
+        self.filepath = os.path.join(root, self.filename)
+
+        self._verify()
+
+        try:
+            import h5py  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "h5py is not installed and is required to use this dataset"
+            )
+
+        self.data = self._load_data()
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            sample containing image, label, and magnitude
+        """
+        image = self._load_image(index)
+        label = torch.tensor(self.data[index]["label"])
+        magnitude = torch.tensor(self.data[index]["magnitude"])
+
+        sample = {"image": image, "label": label, "magnitude": magnitude}
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.data)
+
+    def _load_data(self) -> list[dict[str, Any]]:
+        """Return the metadata for a given split.
+
+        Returns:
+            the sample keys, patches, images, labels, and magnitudes
+        """
+        import h5py
+
+        data = []
+        with h5py.File(self.filepath) as f:
+            for k in sorted(f.keys()):
+                if f[k].attrs["split"] != self.splits[self.split]:
+                    continue
+
+                for patch in sorted(f[k].keys()):
+                    if patch not in ["x", "y"]:
+                        # positive sample
+                        magnitude = float(f[k].attrs["magnitude"])
+                        data.append(
+                            dict(
+                                key=k,
+                                patch=patch,
+                                images=("pre", "post"),
+                                label=1,
+                                magnitude=magnitude,
+                            )
+                        )
+
+                        # hard negative sample
+                        if "before" in f[k][patch].keys():
+                            data.append(
+                                dict(
+                                    key=k,
+                                    patch=patch,
+                                    images=("before", "pre"),
+                                    label=0,
+                                    magnitude=0.0,
+                                )
+                            )
+        return data
+
+    def _load_image(self, index: int) -> Tensor:
+        """Load a single image.
+
+        Args:
+            index: index to return
+
+        Returns:
+            the image
+        """
+        import h5py
+
+        key = self.data[index]["key"]
+        patch = self.data[index]["patch"]
+        images = self.data[index]["images"]
+
+        with h5py.File(self.filepath) as f:
+            pre_array = f[key][patch][images[0]][:]
+            pre_array = np.nan_to_num(pre_array, nan=0)
+            post_array = f[key][patch][images[1]][:]
+            post_array = np.nan_to_num(post_array, nan=0)
+            array = np.concatenate([pre_array, post_array], axis=-1)
+            array = array.astype(np.float32)
+
+        tensor = torch.from_numpy(array)
+        # Convert from HxWxC to CxHxW
+        tensor = tensor.permute((2, 0, 1))
+        return tensor
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # Check if the files already exist
+        if os.path.exists(self.filepath):
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise DatasetNotFoundError(self)
+
+        # Download the dataset
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        if not os.path.exists(self.filepath):
+            download_url(
+                self.url,
+                self.root,
+                filename=self.filename,
+                md5=self.md5 if self.checksum else None,
+            )
+
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: str | None = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional suptitle to use for figure
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        image = sample["image"].permute((1, 2, 0)).numpy()
+        label = cast(int, sample["label"].item())
+        label_class = self.classes[label]
+
+        # Create false color image for the pre-event image
+        vv = percentile_normalization(image[..., 0]) + 1e-16
+        vh = percentile_normalization(image[..., 1]) + 1e-16
+        fci1 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1)
+
+        # Create false color image for the post-event image
+        vv = percentile_normalization(image[..., 2]) + 1e-16
+        vh = percentile_normalization(image[..., 3]) + 1e-16
+        fci2 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1)
+
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            prediction = cast(int, sample["prediction"].item())
+            prediction_class = self.classes[prediction]
+
+        ncols = 2
+        fig, axs = plt.subplots(
+            nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True
+        )
+
+        axs[0].imshow(fci1)
+        axs[0].axis("off")
+        axs[0].set_title("Image Pre")
+        axs[1].imshow(fci2)
+        axs[1].axis("off")
+        axs[1].set_title("Image Post")
+
+        if show_titles:
+            title = f"Label: {label_class}"
+            if "magnitude" in sample:
+                magnitude = cast(float, sample["magnitude"].item())
+                title += f" | Magnitude: {magnitude:.2f}"
+            if showing_predictions:
+                title += f"\nPrediction: {prediction_class}"
+            fig.supxlabel(title, y=0.22)
+
+        if suptitle is not None:
+            fig.suptitle(suptitle, y=0.8)
+
+        fig.tight_layout()
+
+        return fig
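
For reference, a minimal end-to-end training sketch mirroring tests/conf/quakeset.yaml; the root path, batch size, and trainer settings here are illustrative placeholders, not values fixed by this patch:

    from lightning.pytorch import Trainer

    from torchgeo.datamodules import QuakeSetDataModule
    from torchgeo.trainers import ClassificationTask

    # 4 input channels: pre-event VV/VH stacked with post-event VV/VH
    task = ClassificationTask(
        loss="ce", model="resnet18", in_channels=4, num_classes=2
    )
    datamodule = QuakeSetDataModule(
        batch_size=64, num_workers=4, root="data/quakeset", download=True
    )
    trainer = Trainer(max_epochs=10)
    trainer.fit(model=task, datamodule=datamodule)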