Skip to content

Commit

Permalink
Add plot method to ADVANCE, BeninSmallHolderCashews, and BigEarthNet …
Browse files Browse the repository at this point in the history
…datasets (microsoft#264)

* Adding plot to ADVANCE dataset

* Adding plot to BeninSmallHolderCashews

* Adding plot to BigEarthNet

* Doctstring adjustment for BigEarthNet plot

* Cleaning up ugly test

* Cleaning up bigearthnet classes

* Added time step plot to benin_cashews

* Formatting

* Update benin cashew tests

* Add S1 plotting and type to np.ndarray

* Trying numpy with quotes
  • Loading branch information
calebrob6 authored Dec 28, 2021
1 parent 29fab18 commit da0c668
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 84 deletions.
11 changes: 11 additions & 0 deletions tests/datasets/test_advance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Any, Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -86,3 +87,13 @@ def test_mock_missing_module(
match="scipy is not installed and is required to use this dataset",
):
dataset[0]

def test_plot(self, dataset: ADVANCE) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()
20 changes: 20 additions & 0 deletions tests/datasets/test_benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -49,9 +50,12 @@ def dataset(
)
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
bands = BeninSmallHolderCashews.ALL_BANDS

return BeninSmallHolderCashews(
root,
transforms=transforms,
bands=bands,
download=True,
api_key="",
checksum=True,
Expand Down Expand Up @@ -87,3 +91,19 @@ def test_invalid_bands(self) -> None:

with pytest.raises(ValueError, match="is an invalid band name."):
BeninSmallHolderCashews(bands=("foo", "bar"))

def test_plot(self, dataset: BeninSmallHolderCashews) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()

def test_failed_plot(self, dataset: BeninSmallHolderCashews) -> None:
single_band_dataset = BeninSmallHolderCashews(root=dataset.root, bands=("B01",))
with pytest.raises(ValueError, match="Dataset doesn't contain"):
x = single_band_dataset[0].copy()
single_band_dataset.plot(x, suptitle="Test")
11 changes: 11 additions & 0 deletions tests/datasets/test_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -148,3 +149,13 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
BigEarthNet(str(tmp_path))

def test_plot(self, dataset: BigEarthNet) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()
43 changes: 42 additions & 1 deletion torchgeo/datasets/advance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import glob
import os
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -229,3 +230,43 @@ def _download(self) -> None:
download_and_extract_archive(
url, self.root, filename=filename, md5=md5 if self.checksum else None
)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
label = cast(int, sample["label"].item())
label_class = self.classes[label]

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = cast(int, sample["prediction"].item())
prediction_class = self.classes[prediction]

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(image)
ax.axis("off")
if show_titles:
title = f"Label: {label_class}"
if showing_predictions:
title += f"\nPrediction: {prediction_class}"
ax.set_title(title)

if suptitle is not None:
plt.suptitle(suptitle)
return fig
95 changes: 81 additions & 14 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functools import lru_cache
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import rasterio.features
Expand Down Expand Up @@ -135,7 +136,7 @@ class BeninSmallHolderCashews(VisionDataset):
"2020_10_30",
)

band_names = (
ALL_BANDS = (
"B01",
"B02",
"B03",
Expand All @@ -150,16 +151,17 @@ class BeninSmallHolderCashews(VisionDataset):
"B12",
"CLD",
)

class_names = {
0: "No data",
1: "Well-managed planatation",
2: "Poorly-managed planatation",
3: "Non-planatation",
4: "Residential",
5: "Background",
6: "Uncertain",
}
RGB_BANDS = ("B04", "B03", "B02")

classes = [
"No data",
"Well-managed planatation",
"Poorly-managed planatation",
"Non-planatation",
"Residential",
"Background",
"Uncertain",
]

# Same for all tiles
tile_height = 1186
Expand All @@ -170,7 +172,7 @@ def __init__(
root: str = "data",
chip_size: int = 256,
stride: int = 128,
bands: Tuple[str, ...] = band_names,
bands: Tuple[str, ...] = ALL_BANDS,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
api_key: Optional[str] = None,
Expand Down Expand Up @@ -273,11 +275,11 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None:
"""
assert isinstance(bands, tuple), "The list of bands must be a tuple"
for band in bands:
if band not in self.band_names:
if band not in self.ALL_BANDS:
raise ValueError(f"'{band}' is an invalid band name.")

@lru_cache(maxsize=128)
def _load_all_imagery(self, bands: Tuple[str, ...] = band_names) -> Tensor:
def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor:
"""Load all the imagery (across time) for the dataset.
Optionally allows for subsetting of the bands that are loaded.
Expand Down Expand Up @@ -410,3 +412,68 @@ def _download(self, api_key: Optional[str] = None) -> None:
target_archive_path = os.path.join(self.root, self.target_meta["filename"])
for fn in [image_archive_path, target_archive_path]:
extract_archive(fn, self.root)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
time_step: time step at which to access image, beginning with 0
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
Raises:
ValueError: if the RGB bands are not included in ``self.bands``
.. versionadded:: 0.2
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

num_time_points = sample["image"].shape[0]
assert time_step < num_time_points

image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3)
image = np.clip(image / 3000, 0, 1)
mask = sample["mask"].numpy()

num_panels = 2
showing_predictions = "prediction" in sample
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1

fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))

axs[0].imshow(image)
axs[0].axis("off")
if show_titles:
axs[0].set_title(f"t={time_step}")

axs[1].imshow(mask, vmin=0, vmax=6, interpolation="none")
axs[1].axis("off")
if show_titles:
axs[1].set_title("Mask")

if showing_predictions:
axs[2].imshow(predictions, vmin=0, vmax=6, interpolation="none")
axs[2].axis("off")
if show_titles:
axs[2].set_title("Predictions")

if suptitle is not None:
plt.suptitle(suptitle)
return fig
Loading

0 comments on commit da0c668

Please sign in to comment.