diff --git a/tests/data/nasa_marine_debris/data.py b/tests/data/nasa_marine_debris/data.py
new file mode 100755
index 0000000000..a782dea3d2
--- /dev/null
+++ b/tests/data/nasa_marine_debris/data.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+import rasterio as rio
+from rasterio import Affine
+from rasterio.crs import CRS
+
+SIZE = 32
+DTYPE = np.uint8
+
+np.random.seed(0)
+
+profile = {
+    'driver': 'GTiff',
+    'dtype': DTYPE,
+    'width': SIZE,
+    'height': SIZE,
+    'count': 3,
+    'crs': CRS.from_epsg(4326),
+    'transform': Affine(
+        2.1457672119140625e-05,
+        0.0,
+        -87.626953125,
+        0.0,
+        -2.0629065249348766e-05,
+        15.977172621632805,
+    ),
+}
+
+os.makedirs('source', exist_ok=True)
+os.makedirs('labels', exist_ok=True)
+
+files = [
+    '20160928_153233_0e16_16816-29821-16',
+    '20160928_153233_0e16_16816-29824-16',
+    '20160928_153233_0e16_16816-29825-16',
+    '20160928_153233_0e16_16816-29828-16',
+    '20160928_153233_0e16_16816-29829-16',
+]
+for file in files:
+    with rio.open(os.path.join('source', f'{file}.tif'), 'w', **profile) as f:
+        for i in range(1, 4):
+            Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE)
+            f.write(Z, i)
+
+    count = np.random.randint(5)
+    x = np.random.randint(SIZE, size=count)
+    y = np.random.randint(SIZE, size=count)
+    dx = np.random.randint(5, size=count)
+    dy = np.random.randint(5, size=count)
+    label = np.ones(count)
+    Z = np.stack([x, y, x + dx, y + dy, label], axis=-1)
+    np.save(os.path.join('labels', f'{file}.npy'), Z)
diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29821-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29821-16.npy
new file mode 100644
index 0000000000..104f61ed70
Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29821-16.npy differ
diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29824-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29824-16.npy
new file mode 100644
index 0000000000..4de2e99e4f
Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29824-16.npy differ
diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29825-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29825-16.npy
new file mode 100644
index 0000000000..e90c5b706e
Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29825-16.npy differ
diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29828-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29828-16.npy
new file mode 100644
index 0000000000..99aa923126
Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29828-16.npy differ
diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29829-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29829-16.npy
new file mode 100644
index 0000000000..577edba2c8
Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29829-16.npy differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels.tar.gz b/tests/data/nasa_marine_debris/nasa_marine_debris_labels.tar.gz
deleted file mode 100644
index 3a4c8edbee..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels.tar.gz and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29821-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29821-16/pixel_bounds.npy
deleted file mode 100755
index eeaaf46f29..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29821-16/pixel_bounds.npy and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29824-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29824-16/pixel_bounds.npy
deleted file mode 100755
index eeaaf46f29..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29824-16/pixel_bounds.npy and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29825-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29825-16/pixel_bounds.npy
deleted file mode 100755
index f559349de8..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29825-16/pixel_bounds.npy and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29828-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29828-16/pixel_bounds.npy
deleted file mode 100755
index eeaaf46f29..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29828-16/pixel_bounds.npy and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source.tar.gz b/tests/data/nasa_marine_debris/nasa_marine_debris_source.tar.gz
deleted file mode 100644
index b1a3b53d41..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source.tar.gz and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29821-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29821-16/image_geotiff.tif
deleted file mode 100644
index 471c657e5f..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29821-16/image_geotiff.tif and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29824-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29824-16/image_geotiff.tif
deleted file mode 100644
index c472caabbe..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29824-16/image_geotiff.tif and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29825-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29825-16/image_geotiff.tif
deleted file mode 100644
index d6fd058b20..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29825-16/image_geotiff.tif and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29828-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29828-16/image_geotiff.tif
deleted file mode 100644
index 65ff7677ab..0000000000
Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29828-16/image_geotiff.tif and /dev/null differ
diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29821-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29821-16.tif
new file mode 100644
index 0000000000..66b8ae8e12
Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29821-16.tif differ
diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29824-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29824-16.tif
new file mode 100644
index 0000000000..2b1609165a
Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29824-16.tif differ
diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29825-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29825-16.tif
new file mode 100644
index 0000000000..1366468ec1
Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29825-16.tif differ
diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29828-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29828-16.tif
new file mode 100644
index 0000000000..f8d75a064a
Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29828-16.tif differ
diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29829-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29829-16.tif
new file mode 100644
index 0000000000..85c99a80c6
Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29829-16.tif differ
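The `data.py` generator above writes 3-band uint8 GeoTIFFs into `source/` and, for each tile, a label array of shape `(N, 5)` into `labels/`, where each row is `[xmin, ymin, xmax, ymax, 1]` in pixel coordinates; widths or heights of 0 are possible, which is what exercises the degenerate-box filter kept in `__getitem__` further down. A minimal sketch of how one generated pair could be sanity-checked — the tile name is one of the fixtures added above, but the checks themselves are only illustrative and not part of the patch:

```python
import numpy as np
import rasterio

# One of the fixture tiles created by data.py (any of the five would do).
stem = 'tests/data/nasa_marine_debris/{}/20160928_153233_0e16_16816-29821-16'

with rasterio.open(stem.format('source') + '.tif') as src:
    # 3-band uint8 imagery, 32 x 32 pixels in the fake data
    assert src.count == 3
    assert src.dtypes[0] == 'uint8'

labels = np.load(stem.format('labels') + '.npy')
# Rows are [xmin, ymin, xmax, ymax, class]; the class column is always 1.
assert labels.ndim == 2 and labels.shape[1] == 5
assert (labels[:, 4] == 1).all()
```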
diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py
index 1697195f46..e2787b51b7 100644
--- a/tests/datasets/test_nasa_marine_debris.py
+++ b/tests/datasets/test_nasa_marine_debris.py
@@ -1,9 +1,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 
-import glob
 import os
-import shutil
 from pathlib import Path
 
 import matplotlib.pyplot as plt
@@ -13,41 +11,18 @@ from pytest import MonkeyPatch
 
 from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris
-
-
-class Collection:
-    def download(self, output_dir: str, **kwargs: str) -> None:
-        glob_path = os.path.join('tests', 'data', 'nasa_marine_debris', '*.tar.gz')
-        for tarball in glob.iglob(glob_path):
-            shutil.copy(tarball, output_dir)
-
-
-def fetch(collection_id: str, **kwargs: str) -> Collection:
-    return Collection()
-
-
-class Collection_corrupted:
-    def download(self, output_dir: str, **kwargs: str) -> None:
-        filenames = NASAMarineDebris.filenames
-        for filename in filenames:
-            with open(os.path.join(output_dir, filename), 'w') as f:
-                f.write('bad')
-
-
-def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted:
-    return Collection_corrupted()
+from torchgeo.datasets.utils import Executable
 
 
 class TestNASAMarineDebris:
-    @pytest.fixture()
-    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
-        radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
-        monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
-        md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7']
-        monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s)
-        root = tmp_path
+    @pytest.fixture
+    def dataset(
+        self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
+    ) -> NASAMarineDebris:
+        url = os.path.join('tests', 'data', 'nasa_marine_debris')
+        monkeypatch.setattr(NASAMarineDebris, 'url', url)
         transforms = nn.Identity()
-        return NASAMarineDebris(root, transforms, download=True, checksum=True)
+        return NASAMarineDebris(tmp_path, transforms, download=True)
 
     def test_getitem(self, dataset: NASAMarineDebris) -> None:
         x = dataset[0]
@@ -58,36 +33,12 @@ def test_getitem(self, dataset: NASAMarineDebris) -> None:
         assert x['boxes'].shape[-1] == 4
 
     def test_len(self, dataset: NASAMarineDebris) -> None:
-        assert len(dataset) == 4
+        assert len(dataset) == 5
 
     def test_already_downloaded(
         self, dataset: NASAMarineDebris, tmp_path: Path
     ) -> None:
-        NASAMarineDebris(root=tmp_path, download=True)
-
-    def test_already_downloaded_not_extracted(
-        self, dataset: NASAMarineDebris, tmp_path: Path
-    ) -> None:
-        shutil.rmtree(dataset.root)
-        os.makedirs(tmp_path, exist_ok=True)
-        Collection().download(output_dir=str(tmp_path))
-        NASAMarineDebris(root=tmp_path, download=False)
-
-    def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
-        filenames = NASAMarineDebris.filenames
-        for filename in filenames:
-            with open(os.path.join(tmp_path, filename), 'w') as f:
-                f.write('bad')
-        with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
-            NASAMarineDebris(root=tmp_path, download=False, checksum=True)
-
-    def test_corrupted_new_download(
-        self, tmp_path: Path, monkeypatch: MonkeyPatch
-    ) -> None:
-        with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
-            radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
-            monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted)
-            NASAMarineDebris(root=tmp_path, download=True, checksum=True)
+        NASAMarineDebris(tmp_path, download=True)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
         with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py
index cf018150f2..9c57e29040 100644
--- a/torchgeo/datasets/nasa_marine_debris.py
+++ b/torchgeo/datasets/nasa_marine_debris.py
@@ -3,6 +3,7 @@
 
 """NASA Marine Debris dataset."""
 
+import glob
 import os
 from collections.abc import Callable
 
@@ -16,18 +17,13 @@
 from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
-from .utils import (
-    Path,
-    check_integrity,
-    download_radiant_mlhub_collection,
-    extract_archive,
-)
+from .utils import Path, which
 
 
 class NASAMarineDebris(NonGeoDataset):
     """NASA Marine Debris dataset.
 
-    The `NASA Marine Debris `__
+    The `NASA Marine Debris `__
     dataset is a dataset for detection of floating marine debris in satellite
     imagery.
 
     Dataset features:
@@ -52,26 +48,19 @@ class NASAMarineDebris(NonGeoDataset):
 
     This dataset requires the following additional library to be installed:
 
-    * `radiant-mlhub `_ to download the
-      imagery and labels from the Radiant Earth MLHub
+    * `azcopy `_: to download the
+      dataset from Source Cooperative.
 
     .. versionadded:: 0.2
     """
 
-    collection_ids = ['nasa_marine_debris_source', 'nasa_marine_debris_labels']
-    directories = ['nasa_marine_debris_source', 'nasa_marine_debris_labels']
-    filenames = ['nasa_marine_debris_source.tar.gz', 'nasa_marine_debris_labels.tar.gz']
-    md5s = ['fe8698d1e68b3f24f0b86b04419a797d', 'd8084f5a72778349e07ac90ec1e1d990']
-    class_label = 'marine_debris'
+    url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa-marine-debris'
 
     def __init__(
         self,
         root: Path = 'data',
         transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
         download: bool = False,
-        api_key: str | None = None,
-        checksum: bool = False,
-        verbose: bool = False,
    ) -> None:
         """Initialize a new NASA Marine Debris Dataset instance.
 
@@ -80,9 +69,6 @@ def __init__(
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
-            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-            checksum: if True, check the MD5 of the downloaded files (may be slow)
-            verbose: if True, print messages when new tiles are loaded
 
         Raises:
             DatasetNotFoundError: If dataset is not found and *download* is False.
@@ -90,11 +76,11 @@
         self.root = root
         self.transforms = transforms
         self.download = download
-        self.api_key = api_key
-        self.checksum = checksum
-        self.verbose = verbose
+
         self._verify()
-        self.files = self._load_files()
+
+        self.source = sorted(glob.glob(os.path.join(self.root, 'source', '*.tif')))
+        self.labels = sorted(glob.glob(os.path.join(self.root, 'labels', '*.npy')))
 
     def __getitem__(self, index: int) -> dict[str, Tensor]:
         """Return an index within the dataset.
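With the hunk above, `__init__` reduces to `_verify()` followed by two `glob` calls over `source/*.tif` and `labels/*.npy`, so constructing and indexing the dataset is straightforward. A short usage sketch under assumed paths — the root shown here is illustrative, and passing `download=True` would additionally require `azcopy` on the `PATH`:

```python
from torchgeo.datasets import NASAMarineDebris

# Assumes the data already lives under this root in source/ and labels/.
ds = NASAMarineDebris(root='data/nasa_marine_debris')

sample = ds[0]
sample['image'].shape  # (3, H, W) float tensor read via rasterio
sample['boxes'].shape  # (N, 4) xyxy pixel boxes (see the __getitem__ hunk below)
```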
@@ -105,15 +91,21 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
         Returns:
             data and labels at that index
         """
-        image = self._load_image(self.files[index]['image'])
-        boxes = self._load_target(self.files[index]['target'])
-        sample = {'image': image, 'boxes': boxes}
+        with rasterio.open(self.source[index]) as source:
+            image = torch.from_numpy(source.read()).float()
+
+        labels = np.load(self.labels[index])
+
+        # Boxes contain unnecessary value of 1 after xyxy coords
+        boxes = torch.from_numpy(labels[:, :4])
 
         # Filter invalid boxes
-        w_check = (sample['boxes'][:, 2] - sample['boxes'][:, 0]) > 0
-        h_check = (sample['boxes'][:, 3] - sample['boxes'][:, 1]) > 0
+        w_check = (boxes[:, 2] - boxes[:, 0]) > 0
+        h_check = (boxes[:, 3] - boxes[:, 1]) > 0
         indices = w_check & h_check
-        sample['boxes'] = sample['boxes'][indices]
+        boxes = boxes[indices]
+
+        sample = {'image': image, 'boxes': boxes}
 
         if self.transforms is not None:
             sample = self.transforms(sample)
@@ -126,85 +118,13 @@ def __len__(self) -> int:
         Returns:
             length of the dataset
         """
-        return len(self.files)
-
-    def _load_image(self, path: Path) -> Tensor:
-        """Load a single image.
-
-        Args:
-            path: path to the image
-
-        Returns:
-            the image
-        """
-        with rasterio.open(path) as f:
-            array = f.read()
-            tensor = torch.from_numpy(array).float()
-            return tensor
-
-    def _load_target(self, path: Path) -> Tensor:
-        """Load the target bounding boxes for a single image.
-
-        Args:
-            path: path to the labels
-
-        Returns:
-            the target boxes
-        """
-        array = np.load(path)
-        # boxes contain unecessary value of 1 after xyxy coords
-        array = array[:, :4]
-        tensor = torch.from_numpy(array)
-        return tensor
-
-    def _load_files(self) -> list[dict[str, str]]:
-        """Load a image and label files.
-
-        Returns:
-            list of dicts containing image and label files
-        """
-        image_root = os.path.join(self.root, self.directories[0])
-        target_root = os.path.join(self.root, self.directories[1])
-        image_folders = sorted(
-            f for f in os.listdir(image_root) if not f.endswith('json')
-        )
-
-        files = []
-        for folder in image_folders:
-            files.append(
-                {
-                    'image': os.path.join(image_root, folder, 'image_geotiff.tif'),
-                    'target': os.path.join(
-                        target_root,
-                        folder.replace('source', 'labels'),
-                        'pixel_bounds.npy',
-                    ),
-                }
-            )
-        return files
+        return len(self.source)
 
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
-        # Check if the files already exist
-        exists = [
-            os.path.exists(os.path.join(self.root, directory))
-            for directory in self.directories
-        ]
-        if all(exists):
-            return
-
-        # Check if zip file already exists (if so then extract)
-        exists = []
-        for filename, md5 in zip(self.filenames, self.md5s):
-            filepath = os.path.join(self.root, filename)
-            if os.path.exists(filepath):
-                if self.checksum and not check_integrity(filepath, md5):
-                    raise RuntimeError('Dataset checksum mismatch.')
-                exists.append(True)
-                extract_archive(filepath)
-            else:
-                exists.append(False)
-
+        # Check if the directories already exist
+        dirs = ['source', 'labels']
+        exists = [os.path.exists(os.path.join(self.root, d)) for d in dirs]
         if all(exists):
             return
 
@@ -212,14 +132,14 @@ def _verify(self) -> None:
         if not self.download:
             raise DatasetNotFoundError(self)
 
-        # Download and extract the dataset
-        for collection_id in self.collection_ids:
-            download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
-        for filename, md5 in zip(self.filenames, self.md5s):
-            filepath = os.path.join(self.root, filename)
-            if self.checksum and not check_integrity(filepath, md5):
-                raise RuntimeError('Dataset checksum mismatch.')
-            extract_archive(filepath)
+        # Download the dataset
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        os.makedirs(self.root, exist_ok=True)
+        azcopy = which('azcopy')
+        azcopy('sync', self.url, self.root, '--recursive=true')
 
     def plot(
         self,
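The new `_download` resolves the AzCopy client with `which('azcopy')` and then syncs the public blob container into `root`, replacing the Radiant MLHub collection download and tarball extraction entirely. Roughly what that call expands to, sketched with `subprocess` — the destination directory is illustrative, and azcopy must already be installed and on the `PATH`:

```python
import subprocess

url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa-marine-debris'
root = 'data/nasa_marine_debris'  # hypothetical destination

# Rough equivalent of azcopy('sync', self.url, self.root, '--recursive=true')
subprocess.run(['azcopy', 'sync', url, root, '--recursive=true'], check=True)
```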