diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py
index 94c7964340d..68b0a8aafc0 100644
--- a/tests/datasets/test_sen12ms.py
+++ b/tests/datasets/test_sen12ms.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from typing import Generator
 
+import matplotlib.pyplot as plt
 import pytest
 import torch
 import torch.nn as nn
@@ -82,3 +83,21 @@ def test_band_subsets(self) -> None:
         ds = SEN12MS(root, bands=bands, checksum=False)
         x = ds[0]["image"]
         assert x.shape[0] == len(bands)
+
+    def test_invalid_bands(self) -> None:
+        with pytest.raises(ValueError):
+            SEN12MS(bands=("OK", "BK"))
+
+    def test_plot(self, dataset: SEN12MS) -> None:
+        dataset.plot(dataset[0], suptitle="Test")
+        plt.close()
+
+        sample = dataset[0]
+        sample["prediction"] = sample["mask"].clone()
+        dataset.plot(sample, suptitle="prediction")
+        plt.close()
+
+    def test_plot_rgb(self, dataset: SEN12MS) -> None:
+        dataset = SEN12MS(root=dataset.root, bands=("B03",))
+        with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
+            dataset.plot(dataset[0], suptitle="Single Band")
diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py
index 8ff0e9a44b8..5e142382d15 100644
--- a/torchgeo/datasets/sen12ms.py
+++ b/torchgeo/datasets/sen12ms.py
@@ -4,15 +4,16 @@
 """SEN12MS dataset."""
 
 import os
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, Optional, Sequence, Tuple
 
+import matplotlib.pyplot as plt
 import numpy as np
 import rasterio
 import torch
 from torch import Tensor
 
 from .geo import VisionDataset
-from .utils import check_integrity
+from .utils import check_integrity, percentile_normalization
 
 
 class SEN12MS(VisionDataset):
@@ -62,13 +63,63 @@ class SEN12MS(VisionDataset):
 
        This download will likely take several hours.
""" # noqa: E501 - BAND_SETS: Dict[str, List[int]] = { - "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - "s1": [0, 1], - "s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - "s2-reduced": [3, 4, 5, 9, 12, 13], + BAND_SETS: Dict[str, Tuple[str, ...]] = { + "all": ( + "VV", + "VH", + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ), + "s1": ("VV", "VH"), + "s2-all": ( + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ), + "s2-reduced": ("B02", "B03", "B04", "B08", "B10", "B11"), } + band_names = ( + "VV", + "VH", + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ) + + RGB_BANDS = ["B04", "B03", "B02"] + filenames = [ "ROIs1158_spring_lc.tar.gz", "ROIs1158_spring_s1.tar.gz", @@ -114,7 +165,7 @@ def __init__( self, root: str = "data", split: str = "train", - bands: List[int] = BAND_SETS["all"], + bands: Sequence[str] = BAND_SETS["all"], transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: @@ -128,7 +179,7 @@ def __init__( Args: root: root directory where dataset can be found split: one of "train" or "test" - bands: a list of band indices to use where the indices correspond to the + bands: a sequence of band indices to use where the indices correspond to the array index of combined Sentinel 1 and Sentinel 2 transforms: a function/transform that takes input sample and its target as entry and returns a transformed version @@ -140,9 +191,14 @@ def __init__( """ assert split in ["train", "test"] + self._validate_bands(bands) + self.band_indices = torch.tensor( # type: ignore[attr-defined] + [self.band_names.index(b) for b in bands] + ).long() + self.bands = bands + self.root = root self.split = split - self.bands = torch.tensor(bands).long() # type: ignore[attr-defined] self.transforms = transforms self.checksum = checksum @@ -173,7 +229,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: image = torch.cat(tensors=[s1, s2], dim=0) # type: ignore[attr-defined] image = torch.index_select( # type: ignore[attr-defined] - image, dim=0, index=self.bands + image, dim=0, index=self.band_indices ) sample: Dict[str, Tensor] = {"image": image, "mask": lc} @@ -216,6 +272,21 @@ def _load_raster(self, filename: str, source: str) -> Tensor: tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] return tensor + def _validate_bands(self, bands: Sequence[str]) -> None: + """Validate list of bands. + + Args: + bands: user-provided sequence of bands to load + + Raises: + AssertionError: if ``bands`` is not a sequence + ValueError: if an invalid band name is provided + """ + assert isinstance(bands, tuple), "'bands' must be a sequence" + for band in bands: + if band not in self.band_names: + raise ValueError(f"'{band}' is an invalid band name.") + def _check_integrity_light(self) -> bool: """Checks the integrity of the dataset structure. @@ -239,3 +310,59 @@ def _check_integrity(self) -> bool: if not check_integrity(filepath, md5 if self.checksum else None): return False return True + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. 
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional suptitle to use for figure
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        .. versionadded:: 0.2
+        """
+        rgb_indices = []
+        for band in self.RGB_BANDS:
+            if band in self.bands:
+                rgb_indices.append(self.bands.index(band))
+            else:
+                raise ValueError("Dataset doesn't contain some of the RGB bands")
+
+        image, mask = sample["image"][rgb_indices].numpy(), sample["mask"][0]
+        image = percentile_normalization(image)
+        ncols = 2
+
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            prediction = sample["prediction"][0]
+            ncols += 1
+
+        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
+
+        axs[0].imshow(np.transpose(image, (1, 2, 0)))
+        axs[0].axis("off")
+        axs[1].imshow(mask)
+        axs[1].axis("off")
+
+        if showing_predictions:
+            axs[2].imshow(prediction)
+            axs[2].axis("off")
+
+        if show_titles:
+            axs[0].set_title("Image")
+            axs[1].set_title("Mask")
+            if showing_predictions:
+                axs[2].set_title("Prediction")
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+
+        return fig
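
For reviewers, a minimal usage sketch of the band-name API and plot method added by this patch; the root path is a placeholder, and the example assumes the SEN12MS archives have already been downloaded and extracted:

import matplotlib.pyplot as plt

from torchgeo.datasets import SEN12MS

# Request the Sentinel-2 RGB bands by name instead of by integer index
# (root is hypothetical; the dataset is not downloaded automatically).
ds = SEN12MS(root="data/sen12ms", split="train", bands=("B04", "B03", "B02"))

sample = ds[0]
print(sample["image"].shape)  # one channel per requested band, e.g. torch.Size([3, 256, 256])

# plot() draws the percentile-normalized RGB composite next to the land cover mask
fig = ds.plot(sample, suptitle="SEN12MS sample")
plt.show()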