QuakeSet dataset #1997

Merged
merged 18 commits on Apr 19, 2024
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -128,6 +128,11 @@ Potsdam

.. autoclass:: Potsdam2DDataModule

QuakeSet
^^^^^^^^

.. autoclass:: QuakeSetDataModule

RESISC45
^^^^^^^^

5 changes: 5 additions & 0 deletions docs/api/datasets.rst
@@ -348,6 +348,11 @@ Potsdam

.. autoclass:: Potsdam2D

QuakeSet
^^^^^^^^

.. autoclass:: QuakeSet

ReforesTree
^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
@@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI
`PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI
`QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR
`ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB
`Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR
14 changes: 14 additions & 0 deletions tests/conf/quakeset.yaml
@@ -0,0 +1,14 @@
model:
  class_path: ClassificationTask
  init_args:
    loss: "ce"
    model: "resnet18"
    in_channels: 4
    num_classes: 2
data:
  class_path: QuakeSetDataModule
  init_args:
    batch_size: 2
    dict_kwargs:
      root: "tests/data/quakeset"
      download: false
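
For context, a rough Python equivalent of what this test config exercises (a sketch, not part of the diff), assuming the ClassificationTask and QuakeSetDataModule signatures used in this PR; the root points at the fake test data generated below:

from torchgeo.datamodules import QuakeSetDataModule
from torchgeo.trainers import ClassificationTask

# Binary classification over 4-channel inputs (2 SAR channels x pre/post images)
task = ClassificationTask(loss="ce", model="resnet18", in_channels=4, num_classes=2)
datamodule = QuakeSetDataModule(
    batch_size=2, root="tests/data/quakeset", download=False
)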
49 changes: 49 additions & 0 deletions tests/data/quakeset/data.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os

import h5py
import numpy as np

NUM_CHANNELS = 2
SIZE = 32

np.random.seed(0)

filename = "earthquakes.h5"

splits = {
    "train": ["611645479", "611658170"],
    "validation": ["611684805", "611744956"],
    "test": ["611798698", "611818836"],
}

# Remove old data
if os.path.exists(filename):
    os.remove(filename)

# Create dataset file
data = np.random.randn(SIZE, SIZE, NUM_CHANNELS)
data = data.astype(np.float32)


with h5py.File(filename, "w") as f:
    for split, keys in splits.items():
        for key in keys:
            sample = f.create_group(key)
            sample.attrs.create(name="magnitude", data=np.float32(0.0))
            sample.attrs.create(name="split", data=split)
            for i in range(2):
                patch = sample.create_group(f"patch_{i}")
                patch.create_dataset("before", data=data)
                patch.create_dataset("pre", data=data)
                patch.create_dataset("post", data=data)

# Compute checksums
with open(filename, "rb") as f:
    md5 = hashlib.md5(f.read()).hexdigest()
    print(f"md5: {md5}")
Binary file added tests/data/quakeset/earthquakes.h5
Binary file not shown.
88 changes: 88 additions & 0 deletions tests/datasets/test_quakeset.py
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, QuakeSet


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestQuakeSet:
    @pytest.fixture(params=["train", "val", "test"])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> QuakeSet:
        monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url)
        url = os.path.join("tests", "data", "quakeset", "earthquakes.h5")
        md5 = "127d0d6a1f82d517129535f50053a4c9"
        monkeypatch.setattr(QuakeSet, "md5", md5)
        monkeypatch.setattr(QuakeSet, "url", url)
        root = str(tmp_path)
        split = request.param
        transforms = nn.Identity()
        return QuakeSet(
            root, split, transforms=transforms, download=True, checksum=True
        )

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "h5py":
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", mocked_import)

    def test_mock_missing_module(
        self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match="h5py is not installed and is required to use this dataset",
        ):
            QuakeSet(dataset.root, download=True, checksum=True)

    def test_getitem(self, dataset: QuakeSet) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["image"].shape[0] == 4

    def test_len(self, dataset: QuakeSet) -> None:
        assert len(dataset) == 8

    def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
        QuakeSet(root=str(tmp_path), download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            QuakeSet(str(tmp_path))

    def test_plot(self, dataset: QuakeSet) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["label"].clone()
        x["magnitude"] = torch.tensor(0.0)
        dataset.plot(x)
        plt.close()
3 changes: 2 additions & 1 deletion tests/trainers/test_classification.py
@@ -76,6 +76,7 @@ class TestClassificationTask:
"eurosat",
"eurosat100",
"fire_risk",
"quakeset",
"resisc45",
"so2sat_all",
"so2sat_s1",
@@ -87,7 +88,7 @@ class TestClassificationTask:
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        if name.startswith("so2sat"):
        if name.startswith("so2sat") or name == "quakeset":
            pytest.importorskip("h5py", minversion="3")

        config = os.path.join("tests", "conf", name + ".yaml")
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -26,6 +26,7 @@
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
from .potsdam import Potsdam2DDataModule
from .quakeset import QuakeSetDataModule
from .resisc45 import RESISC45DataModule
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
@@ -76,6 +77,7 @@
"NASAMarineDebrisDataModule",
"OSCDDataModule",
"Potsdam2DDataModule",
"QuakeSetDataModule",
"RESISC45DataModule",
"SeasonalContrastS2DataModule",
"SEN12MSDataModule",
42 changes: 42 additions & 0 deletions torchgeo/datamodules/quakeset.py
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""QuakeSet datamodule."""

from typing import Any

import kornia.augmentation as K
import torch

from ..datasets import QuakeSet
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


class QuakeSetDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the QuakeSet dataset.

    .. versionadded:: 0.6
    """

    mean = torch.tensor(0.0)
    std = torch.tensor(1.0)

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new QuakeSetDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.QuakeSet`.
        """
        super().__init__(QuakeSet, batch_size, num_workers, **kwargs)
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=["image"],
        )
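
A minimal usage sketch for the new datamodule (not part of this diff), assuming the standard Lightning workflow torchgeo uses for its other tasks; the data root is a placeholder:

from lightning.pytorch import Trainer

from torchgeo.datamodules import QuakeSetDataModule
from torchgeo.trainers import ClassificationTask

datamodule = QuakeSetDataModule(
    batch_size=64, num_workers=4, root="data/quakeset", download=True
)
task = ClassificationTask(model="resnet18", loss="ce", in_channels=4, num_classes=2)
trainer = Trainer(max_epochs=10)
trainer.fit(model=task, datamodule=datamodule)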
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -90,6 +90,7 @@
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .prisma import PRISMA
from .quakeset import QuakeSet
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rwanda_field_boundary import RwandaFieldBoundary
@@ -226,6 +227,7 @@
"PASTIS",
"PatternNet",
"Potsdam2D",
"QuakeSet",
"RESISC45",
"ReforesTree",
"RwandaFieldBoundary",