diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index 9674a9847cd..3c9fc661585 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -65,3 +66,18 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): GID15(str(tmp_path)) + + def test_plot(self, dataset: GID15) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + if dataset.split != "test": + sample = dataset[0] + sample["predictions"] = torch.clone( # type: ignore[attr-defined] + sample["mask"] + ) + dataset.plot(sample, suptitle="Prediction") + else: + sample = dataset[0] + sample["predictions"] = torch.ones((1, 1)) # type: ignore[attr-defined] + dataset.plot(sample) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 73dcfe7f7d9..684a09e8081 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -7,6 +7,7 @@ import os from typing import Callable, Dict, List, Optional +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -239,3 +240,53 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, sample: Dict[str, Tensor], suptitle: Optional[str] = None + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample return by :meth:`__getitem__` + suptitle: optional suptitle to use for figure + + Returns; + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + if self.split != "test": + image, mask = sample["image"], sample["mask"] + ncols = 2 + else: + image = sample["image"] + ncols = 1 + + if "predictions" in sample: + ncols += 1 + pred = sample["predictions"] + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) + + if self.split != "test": + axs[0].imshow(image.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(mask) + axs[1].axis("off") + if "predictions" in sample: + axs[2].imshow(pred) + axs[2].axis("off") + else: + if "predictions" in sample: + axs[0].imshow(image.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(pred) + axs[1].axis("off") + else: + axs.imshow(image.permute(1, 2, 0)) + axs.axis("off") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig