diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 6120c5025b0..a96bb4f389f 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -37,6 +37,11 @@ Cropland Data Layer (CDL) .. autoclass:: CDL +EnviroAtlas +^^^^^^^^^^^ + +.. autoclass:: EnviroAtlas + Landsat ^^^^^^^ diff --git a/tests/data/enviroatlas/data.py b/tests/data/enviroatlas/data.py new file mode 100644 index 00000000000..28da174f31a --- /dev/null +++ b/tests/data/enviroatlas/data.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from typing import Any, Dict + +import fiona +import fiona.transform +import numpy as np +import rasterio +import shapely.geometry +from rasterio.crs import CRS +from rasterio.transform import Affine +from torchvision.datasets.utils import calculate_md5 + +suffix_to_key_map = { + "a_naip": "naip", + "b_nlcd": "nlcd", + "c_roads": "roads", + "d_water": "water", + "d1_waterways": "waterways", + "d2_waterbodies": "waterbodies", + "e_buildings": "buildings", + "h_highres_labels": "lc", + "prior_from_cooccurrences_101_31": "prior", + "prior_from_cooccurrences_101_31_no_osm_no_buildings": "prior_no_osm_no_buildings", +} + +layer_data_profiles: Dict[str, Dict[Any, Any]] = { + "a_naip": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 4, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "pixel", + }, + "data_type": "continuous", + "vals": (4, 255), + }, + "b_nlcd": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15], + }, + "c_roads": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [0, 1], + }, + "d1_waterways": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [0, 1], + }, + "d2_waterbodies": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [0, 1], + }, + "d_water": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [0, 1], + }, + "e_buildings": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [0, 1], + }, + "h_highres_labels": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + 
"nodata": None, + "count": 1, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "categorical", + "vals": [10, 20, 30, 40, 70], + }, + "prior_from_cooccurrences_101_31": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 5, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "continuous", + "vals": (0, 225), + }, + "prior_from_cooccurrences_101_31_no_osm_no_buildings": { + "profile": { + "driver": "GTiff", + "dtype": "uint8", + "nodata": None, + "count": 5, + "crs": CRS.from_epsg(26914), + "blockxsize": 512, + "blockysize": 512, + "tiled": True, + "compress": "deflate", + "interleave": "band", + }, + "data_type": "continuous", + "vals": (0, 220), + }, +} + +tile_list = [ + "pittsburgh_pa-2010_1m-train_tiles-debuffered/4007925_se", + "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw", +] + + +def write_data(path: str, profile: Dict[Any, Any], data_type: Any, vals: Any) -> None: + assert all(key in profile for key in ("count", "height", "width", "dtype")) + with rasterio.open(path, "w", **profile) as dst: + size = (profile["count"], profile["height"], profile["width"]) + dtype = np.dtype(profile["dtype"]) + if data_type == "continuous": + data = np.random.randint(vals[0], vals[1] + 1, size=size, dtype=dtype) + elif data_type == "categorical": + data = np.random.choice(vals, size=size).astype(dtype) + else: + raise ValueError(f"{data_type} is not recognized") + dst.write(data) + + +def generate_test_data(root: str) -> str: + """Creates test data archive for the EnviroAtlas dataset and returns its md5 hash. 
+ + Args: + root (str): Path to store test data + + Returns: + str: md5 hash of created archive + """ + size = (64, 64) + folder_path = os.path.join(root, "enviroatlas_lotp") + + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + for prefix in tile_list: + for suffix, data_profile in layer_data_profiles.items(): + + img_path = os.path.join(folder_path, f"{prefix}_{suffix}.tif") + img_dir = os.path.dirname(img_path) + if not os.path.exists(img_dir): + os.makedirs(img_dir) + + data_profile["profile"]["height"] = size[0] + data_profile["profile"]["width"] = size[1] + data_profile["profile"]["transform"] = Affine( + 1.0, 0.0, 608170.0, 0.0, -1.0, 3381430.0 + ) + + write_data( + img_path, + data_profile["profile"], + data_profile["data_type"], + data_profile["vals"], + ) + + # build the spatial index + schema = { + "geometry": "Polygon", + "properties": { + "split": "str", + "naip": "str", + "nlcd": "str", + "roads": "str", + "water": "str", + "waterways": "str", + "waterbodies": "str", + "buildings": "str", + "lc": "str", + "prior_no_osm_no_buildings": "str", + "prior": "str", + }, + } + with fiona.open( + os.path.join(folder_path, "spatial_index.geojson"), + "w", + driver="GeoJSON", + crs="EPSG:3857", + schema=schema, + ) as dst: + for prefix in tile_list: + + img_path = os.path.join(folder_path, f"{prefix}_a_naip.tif") + with rasterio.open(img_path) as f: + geom = shapely.geometry.mapping(shapely.geometry.box(*f.bounds)) + geom = fiona.transform.transform_geom( + f.crs.to_string(), "EPSG:3857", geom + ) + + row = { + "geometry": geom, + "properties": { + "split": prefix.split("/")[0].replace("_tiles-debuffered", "") + }, + } + for suffix, data_profile in layer_data_profiles.items(): + key = suffix_to_key_map[suffix] + row["properties"][key] = f"{prefix}_{suffix}.tif" + dst.write(row) + + # Create archive + archive_path = os.path.join(root, "enviroatlas_lotp") + shutil.make_archive(archive_path, "zip", root_dir=root, base_dir="enviroatlas_lotp") + shutil.rmtree(folder_path) + md5: str = calculate_md5(archive_path + ".zip") + return md5 + + +if __name__ == "__main__": + md5_hash = generate_test_data(os.getcwd()) + print(md5_hash) diff --git a/tests/data/enviroatlas/enviroatlas_lotp.zip b/tests/data/enviroatlas/enviroatlas_lotp.zip new file mode 100644 index 00000000000..fbc869253d5 Binary files /dev/null and b/tests/data/enviroatlas/enviroatlas_lotp.zip differ diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py new file mode 100644 index 00000000000..1123f49f8e6 --- /dev/null +++ b/tests/datasets/test_enviroatlas.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +import shutil +from pathlib import Path +from typing import Generator + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from rasterio.crs import CRS + +import torchgeo.datasets.utils +from torchgeo.datasets import ( + BoundingBox, + EnviroAtlas, + IntersectionDataset, + UnionDataset, +) +from torchgeo.samplers import RandomGeoSampler + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestEnviroAtlas: + @pytest.fixture( + params=[ + (("naip", "prior", "lc"), False), + (("naip", "prior", "buildings", "lc"), True), + (("naip", "prior"), False), + ] + ) + def dataset( + self, + request: SubRequest, + monkeypatch: Generator[MonkeyPatch, None, None], + tmp_path: Path, + ) -> EnviroAtlas: + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.enviroatlas, "download_url", download_url + ) + monkeypatch.setattr( # type: ignore[attr-defined] + EnviroAtlas, "md5", "071ec65c611e1d4915a5247bffb5ad87" + ) + monkeypatch.setattr( # type: ignore[attr-defined] + EnviroAtlas, + "url", + os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"), + ) + monkeypatch.setattr( # type: ignore[attr-defined] + EnviroAtlas, + "files", + ["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"], + ) + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[attr-defined] + return EnviroAtlas( + root, + layers=request.param[0], + transforms=transforms, + prior_as_input=request.param[1], + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: EnviroAtlas) -> None: + sampler = RandomGeoSampler(dataset, size=16, length=32) + bb = next(iter(sampler)) + x = dataset[bb] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["mask"], torch.Tensor) + + def test_and(self, dataset: EnviroAtlas) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: EnviroAtlas) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_already_extracted(self, dataset: EnviroAtlas) -> None: + EnviroAtlas(root=dataset.root, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + root = str(tmp_path) + shutil.copy( + os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"), root + ) + EnviroAtlas(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + EnviroAtlas(str(tmp_path), checksum=True) + + def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] + + def test_multiple_hits_query(self, dataset: EnviroAtlas) -> None: + ds = EnviroAtlas( + root=dataset.root, + splits=["pittsburgh_pa-2010_1m-train", "austin_tx-2012_1m-test"], + layers=dataset.layers, + ) + with pytest.raises( + IndexError, match="query: .* spans multiple tiles which is not valid" + ): + ds[dataset.bounds] + + def test_plot(self, dataset: EnviroAtlas) -> None: + sampler = RandomGeoSampler(dataset, size=16, length=1) + bb = next(iter(sampler)) + x = dataset[bb] + if "naip" not in dataset.layers or "lc" not in dataset.layers: + with pytest.raises(ValueError, match="The 'naip' and"): + dataset.plot(x) + else: + dataset.plot(x, suptitle="Test") + 
plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"][0].clone() + dataset.plot(x) + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 4d9ddbcf97f..9250920c301 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -25,6 +25,7 @@ from .cv4a_kenya_crop_type import CV4AKenyaCropType from .cyclone import TropicalCycloneWindEstimation from .dfc2022 import DFC2022 +from .enviroatlas import EnviroAtlas from .etci2021 import ETCI2021 from .eurosat import EuroSAT from .fair1m import FAIR1M @@ -118,6 +119,7 @@ "COWCDetection", "CV4AKenyaCropType", "DFC2022", + "EnviroAtlas", "ETCI2021", "EuroSAT", "FAIR1M", diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 64a7a4fb4a1..f1df9564d97 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -6,7 +6,7 @@ import abc import os import sys -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence import fiona import numpy as np @@ -402,7 +402,7 @@ def __init__( self, root: str = "data", splits: Sequence[str] = ["de-train"], - layers: List[str] = ["naip-new", "lc"], + layers: Sequence[str] = ["naip-new", "lc"], transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, cache: bool = True, download: bool = False, @@ -427,6 +427,7 @@ def __init__( Raises: FileNotFoundError: if no files are found in ``root`` RuntimeError: if ``download=False`` but dataset is missing or checksum fails + AssertionError: if ``splits`` or ``layers`` are not valid """ for split in splits: assert split in self.splits diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py new file mode 100644 index 00000000000..22650af9284 --- /dev/null +++ b/torchgeo/datasets/enviroatlas.py @@ -0,0 +1,537 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""EnviroAtlas High-Resolution Land Cover datasets.""" + +import os +import sys +from typing import Any, Callable, Dict, Optional, Sequence + +import fiona +import matplotlib.pyplot as plt +import numpy as np +import pyproj +import rasterio +import rasterio.mask +import shapely.geometry +import shapely.ops +import torch +from matplotlib.colors import ListedColormap +from rasterio.crs import CRS +from torch import Tensor + +from .geo import GeoDataset +from .utils import BoundingBox, download_url, extract_archive + + +class EnviroAtlas(GeoDataset): + """EnviroAtlas dataset covering four cities with prior and weak input data layers. + + The `EnviroAtlas + `_ dataset contains NAIP aerial imagery, + NLCD land cover labels, OpenStreetMap roads, water, waterways, and waterbodies, + Microsoft building footprint labels, high-resolution land cover labels from the + EPA EnviroAtlas dataset, and high-resolution land cover prior layers. + + This dataset was organized to accompany the 2022 paper, `"Resolving label + uncertainty with implicit generative models" + `_. More details can be found at + https://github.com/estherrolf/qr_for_landcover. + + If you use this dataset in your research, please cite the following paper: + + * https://openreview.net/forum?id=AEa_UepnMDX + + .. 
versionadded:: 0.3
+    """
+
+    url = "https://zenodo.org/record/5778193/files/enviroatlas_lotp.zip?download=1"
+    filename = "enviroatlas_lotp.zip"
+    md5 = "6142f8d1ebfc7f8ad888337f0683dc7a"
+
+    crs = CRS.from_epsg(3857)
+    res = 1
+
+    valid_prior_layers = ["prior", "prior_no_osm_no_buildings"]
+
+    valid_layers = [
+        "naip",
+        "nlcd",
+        "roads",
+        "water",
+        "waterways",
+        "waterbodies",
+        "buildings",
+        "lc",
+    ] + valid_prior_layers
+
+    cities = [
+        "pittsburgh_pa-2010_1m",
+        "durham_nc-2012_1m",
+        "austin_tx-2012_1m",
+        "phoenix_az-2010_1m",
+    ]
+    splits = (
+        [f"{city}-train" for city in cities[:1]]
+        + [f"{city}-val" for city in cities[:1]]
+        + [f"{city}-test" for city in cities]
+        + [f"{city}-val5" for city in cities]
+    )
+
+    # these are used to check the integrity of the dataset
+    files = [
+        "austin_tx-2012_1m-test_tiles-debuffered",
+        "austin_tx-2012_1m-val5_tiles-debuffered",
+        "durham_nc-2012_1m-test_tiles-debuffered",
+        "durham_nc-2012_1m-val5_tiles-debuffered",
+        "phoenix_az-2010_1m-test_tiles-debuffered",
+        "phoenix_az-2010_1m-val5_tiles-debuffered",
+        "pittsburgh_pa-2010_1m-test_tiles-debuffered",
+        "pittsburgh_pa-2010_1m-train_tiles-debuffered",
+        "pittsburgh_pa-2010_1m-val5_tiles-debuffered",
+        "pittsburgh_pa-2010_1m-val_tiles-debuffered",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_a_naip.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_b_nlcd.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_c_roads.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d1_waterways.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d2_waterbodies.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif",
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif",  # noqa: E501
+        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif",  # noqa: E501
+        "spatial_index.geojson",
+    ]
+
+    p_src_crs = pyproj.CRS("epsg:3857")
+    p_transformers = {
+        "epsg:26917": pyproj.Transformer.from_crs(
+            p_src_crs, pyproj.CRS("epsg:26917"), always_xy=True
+        ).transform,
+        "epsg:26918": pyproj.Transformer.from_crs(
+            p_src_crs, pyproj.CRS("epsg:26918"), always_xy=True
+        ).transform,
+        "epsg:26914": pyproj.Transformer.from_crs(
+            p_src_crs, pyproj.CRS("epsg:26914"), always_xy=True
+        ).transform,
+        "epsg:26912": pyproj.Transformer.from_crs(
+            p_src_crs, pyproj.CRS("epsg:26912"), always_xy=True
+        ).transform,
+    }
+
+    # used to convert the 10 high-res classes (plus 0 for unclassified), labeled as
+    # [0, 10, 20, 30, 40, 52, 70, 80, 82, 91, 92], to sequential labels [0, ..., 10]
+    raw_enviroatlas_to_idx_map: "np.typing.NDArray[np.uint8]" = np.array(
+        [0] * 10 + [1]  # 10 -> 1
+        + [0] * 9 + [2]  # 20 -> 2
+        + [0] * 9 + [3]  # 30 -> 3
+        + [0] * 9 + [4]  # 40 -> 4
+        + [0] * 11 + [5]  # 52 -> 5
+        + [0] * 17 + [6]  # 70 -> 6
+        + [0] * 9 + [7]  # 80 -> 7
+        + [0, 8]  # 82 -> 8
+        + [0] * 8 + [9, 10],  # 91 -> 9, 92 -> 10
+        dtype=np.uint8,
+    )
+
+    highres_classes = [
+        "Unclassified",
+        "Water",
+        "Impervious Surface",
+        "Soil and Barren",
+        "Trees and Forest",
+        "Shrubs",
+        "Grass and Herbaceous",
+        "Agriculture",
+        "Orchards",
+        "Woody Wetlands",
"Emergent Wetlands", + ] + highres_cmap = ListedColormap( + [ + [1.00000000, 1.00000000, 1.00000000], + [0.00000000, 0.77254902, 1.00000000], + [0.61176471, 0.61176471, 0.61176471], + [1.00000000, 0.66666667, 0.00000000], + [0.14901961, 0.45098039, 0.00000000], + [0.80000000, 0.72156863, 0.47450980], + [0.63921569, 1.00000000, 0.45098039], + [0.86274510, 0.85098039, 0.22352941], + [0.67058824, 0.42352941, 0.15686275], + [0.72156863, 0.85098039, 0.92156863], + [0.42352941, 0.62352941, 0.72156863], + ] + ) + + def __init__( + self, + root: str = "data", + splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"], + layers: Sequence[str] = ["naip", "prior"], + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + prior_as_input: bool = False, + cache: bool = True, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Dataset instance. + + Args: + root: root directory where dataset can be found + splits: a list of strings in the format "{state}-{train,val,test}" + indicating the subset of data to use, for example "ny-train" + layers: a list containing a subset of ``valid_layers`` indicating which + layers to load + transforms: a function/transform that takes an input sample + and returns a transformed version + prior_as_input: bool describing whether the prior is used as an input (True) + or as supervision (False) + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + FileNotFoundError: if no files are found in ``root`` + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + AssertionError: if ``splits`` or ``layers`` are not valid + """ + for split in splits: + assert split in self.splits + assert all([layer in self.valid_layers for layer in layers]) + self.root = root + self.layers = layers + self.cache = cache + self.download = download + self.checksum = checksum + self.prior_as_input = prior_as_input + + self._verify() + + super().__init__(transforms) + + # Add all tiles into the index in epsg:3857 based on the included geojson + mint: float = 0 + maxt: float = sys.maxsize + with fiona.open( + os.path.join(root, "enviroatlas_lotp", "spatial_index.geojson"), "r" + ) as f: + for i, row in enumerate(f): + if row["properties"]["split"] in splits: + box = shapely.geometry.shape(row["geometry"]) + minx, miny, maxx, maxy = box.bounds + coords = (minx, maxx, miny, maxy, mint, maxt) + + self.index.insert( + i, + coords, + { + "naip": row["properties"]["naip"], + "nlcd": row["properties"]["nlcd"], + "roads": row["properties"]["roads"], + "water": row["properties"]["water"], + "waterways": row["properties"]["waterways"], + "waterbodies": row["properties"]["waterbodies"], + "buildings": row["properties"]["buildings"], + "lc": row["properties"]["lc"], + "prior_no_osm_no_buildings": row["properties"][ + "naip" + ].replace( + "a_naip", + "prior_from_cooccurrences_101_31_no_osm_no_buildings", + ), + "prior": row["properties"]["naip"].replace( + "a_naip", "prior_from_cooccurrences_101_31" + ), + }, + ) + + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. 
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of image/mask and metadata at that index
+
+        Raises:
+            IndexError: if query is not found in the index or spans multiple tiles
+        """
+        hits = self.index.intersection(tuple(query), objects=True)
+        filepaths = [hit.object for hit in hits]
+
+        sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}
+
+        if len(filepaths) == 0:
+            raise IndexError(
+                f"query: {query} not found in index with bounds: {self.bounds}"
+            )
+        elif len(filepaths) == 1:
+            filenames = filepaths[0]
+            query_geom_transformed = None  # is set by the first layer
+
+            minx, maxx, miny, maxy, mint, maxt = query
+            query_box = shapely.geometry.box(minx, miny, maxx, maxy)
+
+            for layer in self.layers:
+
+                fn = filenames[layer]
+
+                with rasterio.open(
+                    os.path.join(self.root, "enviroatlas_lotp", fn)
+                ) as f:
+                    dst_crs = f.crs.to_string().lower()
+
+                    if query_geom_transformed is None:
+                        query_box_transformed = shapely.ops.transform(
+                            self.p_transformers[dst_crs], query_box
+                        ).envelope
+                        query_geom_transformed = shapely.geometry.mapping(
+                            query_box_transformed
+                        )
+
+                    data, _ = rasterio.mask.mask(
+                        f, [query_geom_transformed], crop=True, all_touched=True
+                    )
+
+                if layer in [
+                    "naip",
+                    "buildings",
+                    "roads",
+                    "waterways",
+                    "waterbodies",
+                    "water",
+                ]:
+                    sample["image"].append(data)
+                elif layer in ["prior", "prior_no_osm_no_buildings"]:
+                    if self.prior_as_input:
+                        sample["image"].append(data)
+                    else:
+                        sample["mask"].append(data)
+                elif layer in ["lc"]:
+                    data = self.raw_enviroatlas_to_idx_map[data]
+                    sample["mask"].append(data)
+        else:
+            raise IndexError(f"query: {query} spans multiple tiles which is not valid")
+
+        sample["image"] = np.concatenate(sample["image"], axis=0)
+        sample["mask"] = np.concatenate(sample["mask"], axis=0)
+
+        sample["image"] = torch.from_numpy(  # type: ignore[attr-defined]
+            sample["image"]
+        )
+        sample["mask"] = torch.from_numpy(sample["mask"])  # type: ignore[attr-defined]
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset.
+
+        Raises:
+            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+        """
+        # Check if the extracted files already exist
+        def exists(filename: str) -> bool:
+            return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))
+
+        if all(map(exists, self.files)):
+            return
+
+        # Check if the zip file has already been downloaded
+        if os.path.exists(os.path.join(self.root, self.filename)):
+            self._extract()
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise RuntimeError(
+                f"Dataset not found in `root={self.root}` and `download=False`, "
+                "either specify a different `root` directory or use `download=True` "
+                "to automatically download the dataset."
+            )
+
+        # Download the dataset
+        self._download()
+        self._extract()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        download_url(self.url, self.root, filename=self.filename, md5=self.md5)
+
+    def _extract(self) -> None:
+        """Extract the dataset."""
+        extract_archive(os.path.join(self.root, self.filename))
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Note: only plots the "naip" and "lc" layers.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        Raises:
+            ValueError: if the "naip" or "lc" layer isn't included in ``self.layers``
+        """
+        if "naip" not in self.layers or "lc" not in self.layers:
+            raise ValueError("The 'naip' and 'lc' layers must be included for plotting")
+
+        image_layers = []
+        mask_layers = []
+        for layer in self.layers:
+            if layer in [
+                "naip",
+                "buildings",
+                "roads",
+                "waterways",
+                "waterbodies",
+                "water",
+            ]:
+                image_layers.append(layer)
+            elif layer in ["prior", "prior_no_osm_no_buildings"]:
+                if self.prior_as_input:
+                    image_layers.append(layer)
+                else:
+                    mask_layers.append(layer)
+            elif layer in ["lc"]:
+                mask_layers.append(layer)
+
+        naip_index = image_layers.index("naip")
+        lc_index = mask_layers.index("lc")
+
+        image = np.rollaxis(
+            sample["image"][naip_index : naip_index + 3, :, :].numpy(), 0, 3
+        )
+        mask = sample["mask"][lc_index].numpy()
+
+        num_panels = 2
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            predictions = sample["prediction"].numpy()
+            num_panels += 1
+
+        fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
+        axs[0].imshow(image)
+        axs[0].axis("off")
+        axs[1].imshow(
+            mask, vmin=0, vmax=10, cmap=self.highres_cmap, interpolation="none"
+        )
+        axs[1].axis("off")
+        if show_titles:
+            axs[0].set_title("Image")
+            axs[1].set_title("Mask")
+
+        if showing_predictions:
+            axs[2].imshow(
+                predictions,
+                vmin=0,
+                vmax=10,
+                cmap=self.highres_cmap,
+                interpolation="none",
+            )
+            axs[2].axis("off")
+            if show_titles:
+                axs[2].set_title("Predictions")
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+        return fig
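
For reviewers, a minimal usage sketch of the new dataset (illustrative, not part of the patch). It assumes the Zenodo archive is reachable, `root="data"` is writable, and uses the split and layer names defined in the class above.

# Usage sketch: download the archive, sample one 256 x 256 patch from the
# Pittsburgh training split, and plot it.
import matplotlib.pyplot as plt

from torchgeo.datasets import EnviroAtlas
from torchgeo.samplers import RandomGeoSampler

ds = EnviroAtlas(
    root="data",  # assumed writable directory
    splits=["pittsburgh_pa-2010_1m-train"],
    layers=["naip", "prior", "lc"],
    download=True,
    checksum=True,
)

# Sample a random patch from the dataset's epsg:3857 spatial index.
sampler = RandomGeoSampler(ds, size=256, length=1)
bbox = next(iter(sampler))
sample = ds[bbox]  # dict with "image", "mask", "crs", "bbox"

fig = ds.plot(sample, suptitle="EnviroAtlas sample")
plt.show()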
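A second sketch of how `prior_as_input` routes the prior layer, following the branching in `__getitem__` above (assumes the dataset already exists under "data"; illustrative only).

from torchgeo.datasets import EnviroAtlas

# prior_as_input=False (default): the prior channels are treated as
# supervision and are concatenated into sample["mask"] alongside "lc".
ds_supervision = EnviroAtlas(root="data", layers=["naip", "prior", "lc"])

# prior_as_input=True: the same prior channels are instead concatenated
# into sample["image"] after the NAIP bands, as extra model inputs.
ds_input = EnviroAtlas(
    root="data", layers=["naip", "prior", "lc"], prior_as_input=True
)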