Skip to content

Commit

Permalink
Support for radiant-mlhub 0.5+ (#1102)
Browse files Browse the repository at this point in the history
* update datasets and tests to support radiant-mlhub>0.5

* add test coverage for nasa_marine_debris corrupted cases

* style fixes

* Correct return type in test_nasa_marine_debris.py

* Update setup.cfg to limit radiant-mlhub version

Co-authored-by: Adam J. Stewart <[email protected]>

* radiant-mlhub version updates to <0.6

* Update environment.yml to not upper bound radiant-mlhub

Co-authored-by: Adam J. Stewart <[email protected]>

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
SpontaneousDuck and adamjstewart authored Feb 17, 2023
1 parent fec90bf commit dc98b60
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 45 deletions.
3 changes: 0 additions & 3 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ updates:
schedule:
interval: "daily"
ignore:
# radiant-mlhub 0.5+ changed download behavior:
# https://github.com/radiantearth/radiant-mlhub/pull/104
- dependency-name: "radiant-mlhub"
# setuptools releases new versions almost daily
- dependency-name: "setuptools"
update-types: ["version-update:semver-patch"]
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies:
- pytorch-lightning>=1.5.1
- git+https://github.com/pytorch/pytorch_sphinx_theme
- pyupgrade>=2.4
- radiant-mlhub>=0.2.1,<0.5
- radiant-mlhub>=0.2.1
- rtree>=1
- scikit-image>=0.18
- scikit-learn>=0.22
Expand Down
4 changes: 1 addition & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ datasets =
pyvista>=0.25.2,<0.39
# radiant-mlhub 0.2.1+ required for api_key bugfix:
# https://github.com/radiantearth/radiant-mlhub/pull/48
# radiant-mlhub 0.5+ changed download behavior:
# https://github.com/radiantearth/radiant-mlhub/pull/104
radiant-mlhub>=0.2.1,<0.5
radiant-mlhub>=0.2.1,<0.6
# rarfile 4+ required for wheels
rarfile>=4,<5
# scikit-image 0.18+ required for numpy 1.17+ compatibility
Expand Down
8 changes: 4 additions & 4 deletions tests/datasets/test_benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
from torchgeo.datasets import BeninSmallHolderCashews


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestBeninSmallHolderCashews:
Expand All @@ -33,7 +33,7 @@ def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path
) -> BeninSmallHolderCashews:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
source_md5 = "255efff0f03bc6322470949a09bc76db"
labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127"
monkeypatch.setitem(BeninSmallHolderCashews.image_meta, "md5", source_md5)
Expand Down
8 changes: 4 additions & 4 deletions tests/datasets/test_cloud_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchgeo.datasets import CloudCoverDetection


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
"tests", "data", "ref_cloud_cover_detection_challenge_v1", "*.tar.gz"
Expand All @@ -24,15 +24,15 @@ def download(self, output_dir: str, **kwargs: str) -> None:
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestCloudCoverDetection:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)

test_image_meta = {
"filename": "ref_cloud_cover_detection_challenge_v1_test_source.tar.gz",
Expand Down
8 changes: 4 additions & 4 deletions tests/datasets/test_cv4a_kenya_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torchgeo.datasets import CV4AKenyaCropType


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
"tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
Expand All @@ -25,15 +25,15 @@ def download(self, output_dir: str, **kwargs: str) -> None:
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestCV4AKenyaCropType:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
source_md5 = "7f4dcb3f33743dddd73f453176308bfb"
labels_md5 = "95fc59f1d94a85ec00931d4d1280bec9"
monkeypatch.setitem(CV4AKenyaCropType.image_meta, "md5", source_md5)
Expand Down
8 changes: 4 additions & 4 deletions tests/datasets/test_cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from torchgeo.datasets import TropicalCyclone


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
for tarball in glob.iglob(os.path.join("tests", "data", "cyclone", "*.tar.gz")):
shutil.copy(tarball, output_dir)


def fetch(collection_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(collection_id: str, **kwargs: str) -> Collection:
return Collection()


class TestTropicalCyclone:
Expand All @@ -33,7 +33,7 @@ def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> TropicalCyclone:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
md5s = {
"train": {
"source": "2b818e0a0873728dabf52c7054a0ce4c",
Expand Down
40 changes: 34 additions & 6 deletions tests/datasets/test_nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,35 @@
from torchgeo.datasets import NASAMarineDebris


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join("tests", "data", "nasa_marine_debris", "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(collection_id: str, **kwargs: str) -> Collection:
return Collection()


class Collection_corrupted:
def download(self, output_dir: str, **kwargs: str) -> None:
filenames = NASAMarineDebris.filenames
for filename in filenames:
with open(os.path.join(output_dir, filename), "w") as f:
f.write("bad")


def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted:
return Collection_corrupted()


class TestNASAMarineDebris:
@pytest.fixture()
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
md5s = ["6f4f0d2313323950e45bf3fc0c09b5de", "540cf1cf4fd2c13b609d0355abe955d7"]
monkeypatch.setattr(NASAMarineDebris, "md5s", md5s)
root = str(tmp_path)
transforms = nn.Identity()
Expand All @@ -58,9 +70,25 @@ def test_already_downloaded_not_extracted(
) -> None:
shutil.rmtree(dataset.root)
os.makedirs(str(tmp_path), exist_ok=True)
Dataset().download(output_dir=str(tmp_path))
Collection().download(output_dir=str(tmp_path))
NASAMarineDebris(root=str(tmp_path), download=False)

def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
filenames = NASAMarineDebris.filenames
for filename in filenames:
with open(os.path.join(tmp_path, filename), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset checksum mismatch."):
NASAMarineDebris(root=str(tmp_path), download=False, checksum=True)

def test_corrupted_new_download(
self, tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
with pytest.raises(RuntimeError, match="Dataset checksum mismatch."):
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_corrupted)
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
Expand Down
6 changes: 4 additions & 2 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


# TODO: read geospatial information from stac.json files
Expand Down Expand Up @@ -56,6 +56,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
"""

dataset_id = "ts_cashew_benin"
collection_ids = ["ts_cashew_benin_source", "ts_cashew_benin_labels"]
image_meta = {
"filename": "ts_cashew_benin_source.tar.gz",
"md5": "957272c86e518a925a4e0d90dab4f92d",
Expand Down Expand Up @@ -416,7 +417,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

image_archive_path = os.path.join(self.root, self.image_meta["filename"])
target_archive_path = os.path.join(self.root, self.target_meta["filename"])
Expand Down
12 changes: 9 additions & 3 deletions torchgeo/datasets/cloud_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


# TODO: read geospatial information from stac.json files
Expand Down Expand Up @@ -54,7 +54,12 @@ class CloudCoverDetection(NonGeoDataset):
.. versionadded:: 0.4
"""

dataset_id = "ref_cloud_cover_detection_challenge_v1"
collection_ids = [
"ref_cloud_cover_detection_challenge_v1_train_source",
"ref_cloud_cover_detection_challenge_v1_train_labels",
"ref_cloud_cover_detection_challenge_v1_test_source",
"ref_cloud_cover_detection_challenge_v1_test_labels",
]

image_meta = {
"train": {
Expand Down Expand Up @@ -332,7 +337,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

image_archive_path = os.path.join(
self.root, self.image_meta[self.split]["filename"]
Expand Down
10 changes: 7 additions & 3 deletions torchgeo/datasets/cv4a_kenya_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


# TODO: read geospatial information from stac.json files
Expand Down Expand Up @@ -56,7 +56,10 @@ class CV4AKenyaCropType(NonGeoDataset):
imagery and labels from the Radiant Earth MLHub
"""

dataset_id = "ref_african_crops_kenya_02"
collection_ids = [
"ref_african_crops_kenya_02_labels",
"ref_african_crops_kenya_02_source",
]
image_meta = {
"filename": "ref_african_crops_kenya_02_source.tar.gz",
"md5": "9c2004782f6dc83abb1bf45ba4d0da46",
Expand Down Expand Up @@ -394,7 +397,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

image_archive_path = os.path.join(self.root, self.image_meta["filename"])
target_archive_path = os.path.join(self.root, self.target_meta["filename"])
Expand Down
11 changes: 9 additions & 2 deletions torchgeo/datasets/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


class TropicalCyclone(NonGeoDataset):
Expand Down Expand Up @@ -45,6 +45,12 @@ class TropicalCyclone(NonGeoDataset):
"""

collection_id = "nasa_tropical_storm_competition"
collection_ids = [
"nasa_tropical_storm_competition_train_source",
"nasa_tropical_storm_competition_test_source",
"nasa_tropical_storm_competition_train_labels",
"nasa_tropical_storm_competition_test_labels",
]
md5s = {
"train": {
"source": "97e913667a398704ea8d28196d91dad6",
Expand Down Expand Up @@ -207,7 +213,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.collection_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

for split, resources in self.md5s.items():
for resource_type in resources:
Expand Down
16 changes: 10 additions & 6 deletions torchgeo/datasets/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchvision.utils import draw_bounding_boxes

from .geo import NonGeoDataset
from .utils import download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


class NASAMarineDebris(NonGeoDataset):
Expand Down Expand Up @@ -51,7 +51,7 @@ class NASAMarineDebris(NonGeoDataset):
.. versionadded:: 0.2
"""

dataset_id = "nasa_marine_debris"
collection_ids = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
directories = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
filenames = ["nasa_marine_debris_source.tar.gz", "nasa_marine_debris_labels.tar.gz"]
md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
Expand Down Expand Up @@ -189,9 +189,11 @@ def _verify(self) -> None:

# Check if zip file already exists (if so then extract)
exists = []
for filename in self.filenames:
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if os.path.exists(filepath):
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset checksum mismatch.")
exists.append(True)
extract_archive(filepath)
else:
Expand All @@ -208,11 +210,13 @@ def _verify(self) -> None:
"to automatically download the dataset."
)

# TODO: need a checksum check in here post downloading
# Download and extract the dataset
download_radiant_mlhub_dataset(self.dataset_id, self.root, self.api_key)
for filename in self.filenames:
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset checksum mismatch.")
extract_archive(filepath)

def plot(
Expand Down

0 comments on commit dc98b60

Please sign in to comment.