Skip to content

Commit

Permalink
Add plotting method and band selection to Sen12ms, replacing #320 (#338)
Browse files Browse the repository at this point in the history
* add plot method to sen12

* tuple
  • Loading branch information
nilsleh authored Dec 31, 2021
1 parent 9e07927 commit 7d90045
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 11 deletions.
19 changes: 19 additions & 0 deletions tests/datasets/test_sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -82,3 +83,21 @@ def test_band_subsets(self) -> None:
ds = SEN12MS(root, bands=bands, checksum=False)
x = ds[0]["image"]
assert x.shape[0] == len(bands)

def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
SEN12MS(bands=("OK", "BK"))

def test_plot(self, dataset: SEN12MS) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()

sample = dataset[0]
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample, suptitle="prediction")
plt.close()

def test_plot_rgb(self, dataset: SEN12MS) -> None:
dataset = SEN12MS(root=dataset.root, bands=("B03",))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], suptitle="Single Band")
149 changes: 138 additions & 11 deletions torchgeo/datasets/sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
"""SEN12MS dataset."""

import os
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from torch import Tensor

from .geo import VisionDataset
from .utils import check_integrity
from .utils import check_integrity, percentile_normalization


class SEN12MS(VisionDataset):
Expand Down Expand Up @@ -62,13 +63,63 @@ class SEN12MS(VisionDataset):
This download will likely take several hours.
""" # noqa: E501

BAND_SETS: Dict[str, List[int]] = {
"all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"s1": [0, 1],
"s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"s2-reduced": [3, 4, 5, 9, 12, 13],
BAND_SETS: Dict[str, Tuple[str, ...]] = {
"all": (
"VV",
"VH",
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
),
"s1": ("VV", "VH"),
"s2-all": (
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
),
"s2-reduced": ("B02", "B03", "B04", "B08", "B10", "B11"),
}

band_names = (
"VV",
"VH",
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
)

RGB_BANDS = ["B04", "B03", "B02"]

filenames = [
"ROIs1158_spring_lc.tar.gz",
"ROIs1158_spring_s1.tar.gz",
Expand Down Expand Up @@ -114,7 +165,7 @@ def __init__(
self,
root: str = "data",
split: str = "train",
bands: List[int] = BAND_SETS["all"],
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
checksum: bool = False,
) -> None:
Expand All @@ -128,7 +179,7 @@ def __init__(
Args:
root: root directory where dataset can be found
split: one of "train" or "test"
bands: a list of band indices to use where the indices correspond to the
bands: a sequence of band indices to use where the indices correspond to the
array index of combined Sentinel 1 and Sentinel 2
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
Expand All @@ -140,9 +191,14 @@ def __init__(
"""
assert split in ["train", "test"]

self._validate_bands(bands)
self.band_indices = torch.tensor( # type: ignore[attr-defined]
[self.band_names.index(b) for b in bands]
).long()
self.bands = bands

self.root = root
self.split = split
self.bands = torch.tensor(bands).long() # type: ignore[attr-defined]
self.transforms = transforms
self.checksum = checksum

Expand Down Expand Up @@ -173,7 +229,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:

image = torch.cat(tensors=[s1, s2], dim=0) # type: ignore[attr-defined]
image = torch.index_select( # type: ignore[attr-defined]
image, dim=0, index=self.bands
image, dim=0, index=self.band_indices
)

sample: Dict[str, Tensor] = {"image": image, "mask": lc}
Expand Down Expand Up @@ -216,6 +272,21 @@ def _load_raster(self, filename: str, source: str) -> Tensor:
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor

def _validate_bands(self, bands: Sequence[str]) -> None:
"""Validate list of bands.
Args:
bands: user-provided sequence of bands to load
Raises:
AssertionError: if ``bands`` is not a sequence
ValueError: if an invalid band name is provided
"""
assert isinstance(bands, tuple), "'bands' must be a sequence"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")

def _check_integrity_light(self) -> bool:
"""Checks the integrity of the dataset structure.
Expand All @@ -239,3 +310,59 @@ def _check_integrity(self) -> bool:
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional suptitle to use for figure
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

image, mask = sample["image"][rgb_indices].numpy(), sample["mask"][0]
image = percentile_normalization(image)
ncols = 2

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = sample["prediction"][0]
ncols += 1

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))

axs[0].imshow(np.transpose(image, (1, 2, 0)))
axs[0].axis("off")
axs[1].imshow(mask)
axs[1].axis("off")

if showing_predictions:
axs[2].imshow(prediction)
axs[2].axis("off")

if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")
if showing_predictions:
axs[2].set_title("Prediction")

if suptitle is not None:
plt.suptitle(suptitle)

return fig

0 comments on commit 7d90045

Please sign in to comment.