diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py
index 30387142b53..ff2283c83d3 100644
--- a/tests/datasets/test_advance.py
+++ b/tests/datasets/test_advance.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Any, Generator
 
+import matplotlib.pyplot as plt
 import pytest
 import torch
 import torch.nn as nn
@@ -86,3 +87,13 @@ def test_mock_missing_module(
             match="scipy is not installed and is required to use this dataset",
         ):
             dataset[0]
+
+    def test_plot(self, dataset: ADVANCE) -> None:
+        x = dataset[0].copy()
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+        dataset.plot(x, show_titles=False)
+        plt.close()
+        x["prediction"] = x["label"].clone()
+        dataset.plot(x)
+        plt.close()
diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py
index f4f6d285f2b..ea19ad3a378 100644
--- a/tests/datasets/test_benin_cashews.py
+++ b/tests/datasets/test_benin_cashews.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Generator
 
+import matplotlib.pyplot as plt
 import pytest
 import torch
 import torch.nn as nn
@@ -49,9 +50,12 @@ def dataset(
         )
         root = str(tmp_path)
         transforms = nn.Identity()  # type: ignore[attr-defined]
+        bands = BeninSmallHolderCashews.ALL_BANDS
+
         return BeninSmallHolderCashews(
             root,
             transforms=transforms,
+            bands=bands,
             download=True,
             api_key="",
             checksum=True,
@@ -87,3 +91,19 @@ def test_invalid_bands(self) -> None:
 
         with pytest.raises(ValueError, match="is an invalid band name."):
             BeninSmallHolderCashews(bands=("foo", "bar"))
+
+    def test_plot(self, dataset: BeninSmallHolderCashews) -> None:
+        x = dataset[0].copy()
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+        dataset.plot(x, show_titles=False)
+        plt.close()
+        x["prediction"] = x["mask"].clone()
+        dataset.plot(x)
+        plt.close()
+
+    def test_failed_plot(self, dataset: BeninSmallHolderCashews) -> None:
+        single_band_dataset = BeninSmallHolderCashews(root=dataset.root, bands=("B01",))
+        x = single_band_dataset[0].copy()
+        with pytest.raises(ValueError, match="Dataset doesn't contain"):
+            single_band_dataset.plot(x, suptitle="Test")
diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py
index d1f9ca80f68..84307a484a6 100644
--- a/tests/datasets/test_bigearthnet.py
+++ b/tests/datasets/test_bigearthnet.py
@@ -6,6 +6,7 @@
 from pathlib import Path
 from typing import Generator
 
+import matplotlib.pyplot as plt
 import pytest
 import torch
 import torch.nn as nn
@@ -148,3 +149,13 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
             "to automaticaly download the dataset."
         with pytest.raises(RuntimeError, match=err):
             BigEarthNet(str(tmp_path))
+
+    def test_plot(self, dataset: BigEarthNet) -> None:
+        x = dataset[0].copy()
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+        dataset.plot(x, show_titles=False)
+        plt.close()
+        x["prediction"] = x["label"].clone()
+        dataset.plot(x)
+        plt.close()
diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py
index e5eee11b97b..c22819b0b7d 100644
--- a/torchgeo/datasets/advance.py
+++ b/torchgeo/datasets/advance.py
@@ -5,8 +5,9 @@
 
 import glob
 import os
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, List, Optional, cast
 
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from PIL import Image
@@ -229,3 +230,43 @@ def _download(self) -> None:
             download_and_extract_archive(
                 url, self.root, filename=filename, md5=md5 if self.checksum else None
             )
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        .. versionadded:: 0.2
+        """
+        image = np.rollaxis(sample["image"].numpy(), 0, 3)
+        label = cast(int, sample["label"].item())
+        label_class = self.classes[label]
+
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            prediction = cast(int, sample["prediction"].item())
+            prediction_class = self.classes[prediction]
+
+        fig, ax = plt.subplots(figsize=(4, 4))
+        ax.imshow(image)
+        ax.axis("off")
+        if show_titles:
+            title = f"Label: {label_class}"
+            if showing_predictions:
+                title += f"\nPrediction: {prediction_class}"
+            ax.set_title(title)
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+        return fig
diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py
index fffcc7f67c2..3d968de12db 100644
--- a/torchgeo/datasets/benin_cashews.py
+++ b/torchgeo/datasets/benin_cashews.py
@@ -8,6 +8,7 @@
 from functools import lru_cache
 from typing import Callable, Dict, Optional, Tuple
 
+import matplotlib.pyplot as plt
 import numpy as np
 import rasterio
 import rasterio.features
@@ -135,7 +136,7 @@ class BeninSmallHolderCashews(VisionDataset):
         "2020_10_30",
     )
 
-    band_names = (
+    ALL_BANDS = (
         "B01",
         "B02",
         "B03",
@@ -150,16 +151,17 @@ class BeninSmallHolderCashews(VisionDataset):
         "B12",
         "CLD",
     )
-
-    class_names = {
-        0: "No data",
-        1: "Well-managed planatation",
-        2: "Poorly-managed planatation",
-        3: "Non-planatation",
-        4: "Residential",
-        5: "Background",
-        6: "Uncertain",
-    }
+    RGB_BANDS = ("B04", "B03", "B02")
+
+    classes = [
+        "No data",
+        "Well-managed planatation",
+        "Poorly-managed planatation",
+        "Non-planatation",
+        "Residential",
+        "Background",
+        "Uncertain",
+    ]
 
     # Same for all tiles
     tile_height = 1186
@@ -170,7 +172,7 @@ def __init__(
         root: str = "data",
         chip_size: int = 256,
         stride: int = 128,
-        bands: Tuple[str, ...] = band_names,
+        bands: Tuple[str, ...] = ALL_BANDS,
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         download: bool = False,
         api_key: Optional[str] = None,
@@ -273,11 +275,11 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None:
         """
         assert isinstance(bands, tuple), "The list of bands must be a tuple"
         for band in bands:
-            if band not in self.band_names:
+            if band not in self.ALL_BANDS:
                 raise ValueError(f"'{band}' is an invalid band name.")
 
     @lru_cache(maxsize=128)
-    def _load_all_imagery(self, bands: Tuple[str, ...] = band_names) -> Tensor:
+    def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor:
         """Load all the imagery (across time) for the dataset.
 
         Optionally allows for subsetting of the bands that are loaded.
@@ -410,3 +412,68 @@ def _download(self, api_key: Optional[str] = None) -> None:
         target_archive_path = os.path.join(self.root, self.target_meta["filename"])
         for fn in [image_archive_path, target_archive_path]:
             extract_archive(fn, self.root)
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        time_step: int = 0,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            time_step: time step at which to access image, beginning with 0
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        Raises:
+            ValueError: if the RGB bands are not included in ``self.bands``
+
+        .. versionadded:: 0.2
+        """
+        rgb_indices = []
+        for band in self.RGB_BANDS:
+            if band in self.bands:
+                rgb_indices.append(self.bands.index(band))
+            else:
+                raise ValueError("Dataset doesn't contain some of the RGB bands")
+
+        num_time_points = sample["image"].shape[0]
+        assert time_step < num_time_points
+
+        image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3)
+        image = np.clip(image / 3000, 0, 1)
+        mask = sample["mask"].numpy()
+
+        num_panels = 2
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            predictions = sample["prediction"].numpy()
+            num_panels += 1
+
+        fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))
+
+        axs[0].imshow(image)
+        axs[0].axis("off")
+        if show_titles:
+            axs[0].set_title(f"t={time_step}")
+
+        axs[1].imshow(mask, vmin=0, vmax=6, interpolation="none")
+        axs[1].axis("off")
+        if show_titles:
+            axs[1].set_title("Mask")
+
+        if showing_predictions:
+            axs[2].imshow(predictions, vmin=0, vmax=6, interpolation="none")
+            axs[2].axis("off")
+            if show_titles:
+                axs[2].set_title("Predictions")
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+        return fig
diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py
index 48d24766553..2ef2f8285ad 100644
--- a/torchgeo/datasets/bigearthnet.py
+++ b/torchgeo/datasets/bigearthnet.py
@@ -8,6 +8,7 @@
 import os
 from typing import Callable, Dict, List, Optional
 
+import matplotlib.pyplot as plt
 import numpy as np
 import rasterio
 import torch
@@ -118,74 +119,77 @@ class BigEarthNet(VisionDataset):
 
     """
 
-    classes_43 = [
-        "Agro-forestry areas",
-        "Airports",
-        "Annual crops associated with permanent crops",
-        "Bare rock",
-        "Beaches, dunes, sands",
-        "Broad-leaved forest",
-        "Burnt areas",
-        "Coastal lagoons",
-        "Complex cultivation patterns",
-        "Coniferous forest",
-        "Construction sites",
-        "Continuous urban fabric",
-        "Discontinuous urban fabric",
-        "Dump sites",
-        "Estuaries",
-        "Fruit trees and berry plantations",
-        "Green urban areas",
-        "Industrial or commercial units",
-        "Inland marshes",
-        "Intertidal flats",
-        "Land principally occupied by agriculture, with significant areas of "
-        "natural vegetation",
-        "Mineral extraction sites",
-        "Mixed forest",
-        "Moors and heathland",
-        "Natural grassland",
-        "Non-irrigated arable land",
-        "Olive groves",
-        "Pastures",
-        "Peatbogs",
-        "Permanently irrigated land",
-        "Port areas",
-        "Rice fields",
-        "Road and rail networks and associated land",
-        "Salines",
-        "Salt marshes",
-        "Sclerophyllous vegetation",
-        "Sea and ocean",
-        "Sparsely vegetated areas",
-        "Sport and leisure facilities",
-        "Transitional woodland/shrub",
-        "Vineyards",
-        "Water bodies",
-        "Water courses",
-    ]
-    classes_19 = [
-        "Urban fabric",
-        "Industrial or commercial units",
-        "Arable land",
-        "Permanent crops",
-        "Pastures",
-        "Complex cultivation patterns",
-        "Land principally occupied by agriculture, with significant areas of natural "
-        "vegetation",
-        "Agro-forestry areas",
-        "Broad-leaved forest",
-        "Coniferous forest",
-        "Mixed forest",
-        "Natural grassland and sparsely vegetated areas",
-        "Moors, heathland and sclerophyllous vegetation",
-        "Transitional woodland, shrub",
-        "Beaches, dunes, sands",
-        "Inland wetlands",
-        "Coastal wetlands",
-        "Inland waters",
-        "Marine waters",
-    ]
+    class_sets = {
+        19: [
+            "Urban fabric",
+            "Industrial or commercial units",
+            "Arable land",
+            "Permanent crops",
+            "Pastures",
+            "Complex cultivation patterns",
+            "Land principally occupied by agriculture, with significant areas of"
+            " natural vegetation",
+            "Agro-forestry areas",
+            "Broad-leaved forest",
+            "Coniferous forest",
+            "Mixed forest",
+            "Natural grassland and sparsely vegetated areas",
+            "Moors, heathland and sclerophyllous vegetation",
+            "Transitional woodland, shrub",
+            "Beaches, dunes, sands",
+            "Inland wetlands",
+            "Coastal wetlands",
+            "Inland waters",
+            "Marine waters",
+        ],
+        43: [
+            "Agro-forestry areas",
+            "Airports",
+            "Annual crops associated with permanent crops",
+            "Bare rock",
+            "Beaches, dunes, sands",
+            "Broad-leaved forest",
+            "Burnt areas",
+            "Coastal lagoons",
+            "Complex cultivation patterns",
+            "Coniferous forest",
+            "Construction sites",
+            "Continuous urban fabric",
+            "Discontinuous urban fabric",
+            "Dump sites",
+            "Estuaries",
+            "Fruit trees and berry plantations",
+            "Green urban areas",
+            "Industrial or commercial units",
+            "Inland marshes",
+            "Intertidal flats",
+            "Land principally occupied by agriculture, with significant areas of"
+            " natural vegetation",
+            "Mineral extraction sites",
+            "Mixed forest",
+            "Moors and heathland",
+            "Natural grassland",
+            "Non-irrigated arable land",
+            "Olive groves",
+            "Pastures",
+            "Peatbogs",
+            "Permanently irrigated land",
+            "Port areas",
+            "Rice fields",
+            "Road and rail networks and associated land",
+            "Salines",
+            "Salt marshes",
+            "Sclerophyllous vegetation",
+            "Sea and ocean",
+            "Sparsely vegetated areas",
+            "Sport and leisure facilities",
+            "Transitional woodland/shrub",
+            "Vineyards",
+            "Water bodies",
+            "Water courses",
+        ],
+    }
+
     label_converter = {
         0: 0,
         1: 0,
@@ -220,6 +224,7 @@ class BigEarthNet(VisionDataset):
         41: 18,
         42: 18,
     }
+
     splits_metadata = {
         "train": {
             "url": "https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/master/splits/train.csv?inline=false",  # noqa: E501
@@ -285,7 +290,7 @@ def __init__(
         self.transforms = transforms
         self.download = download
         self.checksum = checksum
-        self.class2idx = {c: i for i, c in enumerate(self.classes_43)}
+        self.class2idx = {c: i for i, c in enumerate(self.class_sets[43])}
         self._verify()
         self.folders = self._load_folders()
 
@@ -504,3 +509,69 @@ def _extract(self, filepath: str) -> None:
         """
         if not filepath.endswith(".csv"):
             extract_archive(filepath)
+
+    def _onehot_labels_to_names(
+        self, label_mask: "np.typing.NDArray[np.bool_]"
+    ) -> List[str]:
+        """Get a list of class names given a label mask.
+
+        Args:
+            label_mask: a boolean mask corresponding to a set of labels or predictions
+
+        Returns:
+            a list of class names corresponding to the input mask
+        """
+        labels = []
+        for i, mask in enumerate(label_mask):
+            if mask:
+                labels.append(self.class_sets[self.num_classes][i])
+        return labels
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        .. versionadded:: 0.2
+        """
+        if self.bands == "s2":
+            image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3)
+            image = np.clip(image / 2000, 0, 1)
+        elif self.bands == "all":
+            image = np.rollaxis(sample["image"][[5, 4, 3]].numpy(), 0, 3)
+            image = np.clip(image / 2000, 0, 1)
+        elif self.bands == "s1":
+            # No three-band composite is built for "s1"; show the first band alone
+            image = sample["image"][0].numpy()
+
+        label_mask = sample["label"].numpy().astype(np.bool_)
+        labels = self._onehot_labels_to_names(label_mask)
+
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            prediction_mask = sample["prediction"].numpy().astype(np.bool_)
+            predictions = self._onehot_labels_to_names(prediction_mask)
+
+        fig, ax = plt.subplots(figsize=(4, 4))
+        ax.imshow(image)
+        ax.axis("off")
+        if show_titles:
+            title = f"Labels: {', '.join(labels)}"
+            if showing_predictions:
+                title += f"\nPredictions: {', '.join(predictions)}"
+            ax.set_title(title)
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+        return fig