Skip to content

Commit

Permalink
Add plotting method for CV4A Kenya Crop Type Dataset (microsoft#312)
Browse files Browse the repository at this point in the history
* Add plotting method for CV4A Kenya Crop Type Dataset

* remove print statements, still fix test_plot

* fix rgb plot test

* fix rgb plot test

* requested changes
  • Loading branch information
nilsleh authored Dec 29, 2021
1 parent 3d8f9f1 commit 9b95643
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/datasets/test_cv4a_kenya_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -113,3 +114,17 @@ def test_invalid_bands(self) -> None:

with pytest.raises(ValueError, match="is an invalid band name."):
CV4AKenyaCropType(bands=("foo", "bar"))

def test_plot(self, dataset: CV4AKenyaCropType) -> None:
dataset.plot(dataset[0], time_step=0, suptitle="Test")
plt.close()

sample = dataset[0]
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample, time_step=0, suptitle="Pred")
plt.close()

def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None:
dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(["B01"]))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], time_step=0, suptitle="Single Band")
68 changes: 68 additions & 0 deletions torchgeo/datasets/cv4a_kenya_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -102,6 +103,8 @@ class CV4AKenyaCropType(VisionDataset):
"CLD",
)

RGB_BANDS = ["B04", "B03", "B02"]

# Same for all tiles
tile_height = 3035
tile_width = 2016
Expand Down Expand Up @@ -400,3 +403,68 @@ def _download(self, api_key: Optional[str] = None) -> None:
target_archive_path = os.path.join(self.root, self.target_meta["filename"])
for fn in [image_archive_path, target_archive_path]:
extract_archive(fn, self.root)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
time_step: time step at which to access image, beginning with 0
suptitle: optional suptitle to use for figure
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

if "prediction" in sample:
prediction = sample["prediction"]
n_cols = 3
else:
n_cols = 2

image, mask = sample["image"], sample["mask"]

assert time_step <= image.shape[0] - 1, (
"The specified time step"
" does not exist, image only contains {} time"
" instances."
).format(image.shape[0])

image = image[time_step, rgb_indices, :, :]

fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5))

axs[0].imshow(image.permute(1, 2, 0))
axs[0].axis("off")
axs[1].imshow(mask)
axs[1].axis("off")

if "prediction" in sample:
axs[2].imshow(prediction)
axs[2].axis("off")
if show_titles:
axs[2].set_title("Prediction")

if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")

if suptitle is not None:
plt.suptitle(suptitle)

return fig

0 comments on commit 9b95643

Please sign in to comment.