Skip to content

Commit

Permalink
Add plot method to Levir, and change directory path (microsoft#335)
Browse files Browse the repository at this point in the history
* add plotting method

* implement test

* axis off

* prediction flag

* requested changes

* indexing fix
  • Loading branch information
nilsleh authored Dec 30, 2021
1 parent 9b95643 commit 37fde05
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
Binary file modified tests/data/levircd/LEVIR-CD+.zip
Binary file not shown.
12 changes: 11 additions & 1 deletion tests/datasets/test_levircd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand All @@ -31,7 +32,7 @@ def dataset(
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
)
md5 = "b61c300e9fd7146eb2c8e2512c0e9d39"
md5 = "1adf156f628aa32fb2e8fe6cada16c04"
monkeypatch.setattr(LEVIRCDPlus, "md5", md5) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip")
monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined]
Expand Down Expand Up @@ -60,3 +61,12 @@ def test_invalid_split(self) -> None:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
LEVIRCDPlus(str(tmp_path))

def test_plot(self, dataset: LEVIRCDPlus) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()

sample = dataset[0]
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample, suptitle="Prediction")
plt.close()
69 changes: 62 additions & 7 deletions torchgeo/datasets/levircd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from typing import Callable, Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -47,6 +48,7 @@ class LEVIRCDPlus(VisionDataset):
url = "https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81"
md5 = "1adf156f628aa32fb2e8fe6cada16c04"
filename = "LEVIR-CD+.zip"
directory = "LEVIR-CD+"
splits = ["train", "test"]

def __init__(
Expand Down Expand Up @@ -88,7 +90,7 @@ def __init__(
+ "You can use download=True to download it"
)

self.files = self._load_files(self.root, self.split)
self.files = self._load_files(self.root, self.directory, self.split)

def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Expand Down Expand Up @@ -120,23 +122,26 @@ def __len__(self) -> int:
"""
return len(self.files)

def _load_files(self, root: str, split: str) -> List[Dict[str, str]]:
def _load_files(
self, root: str, directory: str, split: str
) -> List[Dict[str, str]]:
"""Return the paths of the files in the dataset.
Args:
root: root dir of dataset
directory: sub directory LEVIR-CD+
split: subset of dataset, one of [train, test]
Returns:
list of dicts containing paths for each pair of image1, image2, mask
"""
files = []
images = glob.glob(os.path.join(root, split, "A", "*.png"))
images = glob.glob(os.path.join(root, directory, split, "A", "*.png"))
images = sorted([os.path.basename(image) for image in images])
for image in images:
image1 = os.path.join(root, split, "A", image)
image2 = os.path.join(root, split, "B", image)
mask = os.path.join(root, split, "label", image)
image1 = os.path.join(root, directory, split, "A", image)
image2 = os.path.join(root, directory, split, "B", image)
mask = os.path.join(root, directory, split, "label", image)
files.append(dict(image1=image1, image2=image2, mask=mask))
return files

Expand Down Expand Up @@ -181,7 +186,7 @@ def _check_integrity(self) -> bool:
True if the dataset directories and split files are found, else False
"""
for filename in self.splits:
filepath = os.path.join(self.root, filename)
filepath = os.path.join(self.root, self.directory, filename)
if not os.path.exists(filepath):
return False
return True
Expand All @@ -202,3 +207,53 @@ def _download(self) -> None:
filename=self.filename,
md5=self.md5 if self.checksum else None,
)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional suptitle to use for figure
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"])
ncols = 3

if "prediction" in sample:
prediction = sample["prediction"]
ncols += 1

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))

axs[0].imshow(image1.permute(1, 2, 0))
axs[0].axis("off")
axs[1].imshow(image2.permute(1, 2, 0))
axs[1].axis("off")
axs[2].imshow(mask)
axs[2].axis("off")

if "prediction" in sample:
axs[3].imshow(prediction)
axs[3].axis("off")
if show_titles:
axs[3].set_title("Prediction")

if show_titles:
axs[0].set_title("Image 1")
axs[1].set_title("Image 2")
axs[2].set_title("Mask")

if suptitle is not None:
plt.suptitle(suptitle)

return fig

0 comments on commit 37fde05

Please sign in to comment.