From d1f63eaaa5a3fb5479408ce35339bac9a9c66343 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 15 Dec 2021 19:05:25 +0100 Subject: [PATCH 1/4] plotting method for GID15 dataset --- tests/datasets/test_gid15.py | 5 +++++ torchgeo/datasets/gid15.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index 9674a9847cd..cae6d32bbec 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -65,3 +66,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): GID15(str(tmp_path)) + + def test_plot(self, dataset: GID15) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 73dcfe7f7d9..e7808c61933 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -7,6 +7,7 @@ import os from typing import Callable, Dict, List, Optional +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -239,3 +240,38 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, sample: Dict[str, Tensor], suptitle: Optional[str] = None + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample return by :meth:`__getitem__` + suptitle: optional suptitle to use for figure + + Returns; + a matplotlib Figure with the rendered sample + """ + if self.split != "test": + image, mask = sample["image"], sample["mask"] + ncols = 2 + else: + image = sample["image"] + ncols = 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) + + if self.split != "test": + axs[0].imshow(image.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(mask) + axs[1].axis("off") + else: + axs.imshow(image.permute(1, 2, 0)) + axs.axis("off") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig From f50d8e4aa1dd2509ba6ba4892c37ef82ee7ad836 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 15 Dec 2021 19:24:47 +0100 Subject: [PATCH 2/4] version added --- torchgeo/datasets/gid15.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index e7808c61933..47883fdc207 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -252,6 +252,8 @@ def plot( Returns; a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 """ if self.split != "test": image, mask = sample["image"], sample["mask"] From 1736d2be32a0b5a91ceaff485bfcd5a34ab9c7ef Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 15 Dec 2021 23:57:12 +0100 Subject: [PATCH 3/4] plot predictions --- tests/datasets/test_gid15.py | 11 +++++++++++ torchgeo/datasets/gid15.py | 20 ++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index cae6d32bbec..3c9fc661585 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -70,3 +70,14 @@ def test_not_downloaded(self, tmp_path: Path) -> None: def test_plot(self, dataset: GID15) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() + + if dataset.split != "test": + sample = dataset[0] + sample["predictions"] = torch.clone( # type: ignore[attr-defined] + sample["mask"] + ) + dataset.plot(sample, suptitle="Prediction") + else: + sample = dataset[0] + sample["predictions"] = torch.ones((1, 1)) # type: ignore[attr-defined] + dataset.plot(sample) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 47883fdc207..cdeef24396e 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -262,6 +262,13 @@ def plot( image = sample["image"] ncols = 1 + if "predictions" in sample and "mask" in sample: + ncols = 3 + pred = sample["predictions"] + elif "predictions" in sample and "mask" not in sample: + ncols = 2 + pred = sample["predictions"] + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) if self.split != "test": @@ -269,9 +276,18 @@ def plot( axs[0].axis("off") axs[1].imshow(mask) axs[1].axis("off") + if "predictions" in sample: + axs[2].imshow(pred) + axs[2].axis("off") else: - axs.imshow(image.permute(1, 2, 0)) - axs.axis("off") + if "predictions" in sample: + axs[0].imshow(image.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(pred) + axs[1].axis("off") + else: + axs.imshow(image.permute(1, 2, 0)) + axs.axis("off") if suptitle is not None: plt.suptitle(suptitle) From 4509d69cb0b287b0176a965021a0635d00ffcfca Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 16 Dec 2021 21:35:12 +0100 Subject: [PATCH 4/4] prediction in sample --- torchgeo/datasets/gid15.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index cdeef24396e..684a09e8081 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -262,11 +262,8 @@ def plot( image = sample["image"] ncols = 1 - if "predictions" in sample and "mask" in sample: - ncols = 3 - pred = sample["predictions"] - elif "predictions" in sample and "mask" not in sample: - ncols = 2 + if "predictions" in sample: + ncols += 1 pred = sample["predictions"] fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10))