From 37fde05db83c227bbece5075cbe01ec0bd428473 Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Thu, 30 Dec 2021 19:05:42 +0100 Subject: [PATCH] Add plot method to Levir, and change directory path (#335) * add plotting method * implement test * axis off * prediction flag * requested changes * indexing fix --- tests/data/levircd/LEVIR-CD+.zip | Bin 3798 -> 4790 bytes tests/datasets/test_levircd.py | 12 +++++- torchgeo/datasets/levircd.py | 69 +++++++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/tests/data/levircd/LEVIR-CD+.zip b/tests/data/levircd/LEVIR-CD+.zip index b51dc099207cccc0e2eac3e6f2e1eca07c2164f8..9a5fa4e1a7c02fc6a6e81f230915ba98c735c6b9 100644 GIT binary patch literal 4790 zcmc&%TS!zv7@oc8?wXO8bu|pKP@|&N5`s+1zca4iMOwE{bvU10z#GK2L6yF-0CWX}&B(kX*<6F2EWkd3s~ zQW<3R4*x*^nRwx+DDV~f2!OxlehYv8G&bMe)it;HeZXU%87X^bD=&HarT3Pg^owlKPzsOooqX?_ukN9Pb&$uxZG2tvrW6H3Q(f=5Hy}v*0_CHV`m06 z!Vx)wG{QdL+{XspmRc2CS`l0`g)0`pAxmnKEcb=8eA6A*M@M_z1OM#}Md7i+f`Y4c zzU4DeEaHRIPeX(qM|FB$1m^npufpLKsPeeV*}XUksk>*FAv|^VgA>h~^WZ zp_#}q1(xuWjWJ;ot&?IRvjG_%0rEc?ib;s%c4tx~c#B2CzCh5)rLmZ_GjIb-gm@{q zl7=gYYqrq1a#K>&2tHtquowBVG#S(8i(xb-dk|Bs8BzqUS>_AX_+Nan*}V>V?gWQm z{np@}ObX?=q8tinuwvQ)`sTQbsg#3@a%mhn4u01RjwZ5YO4TV?PQ~dEaw=QV$byUl zjRK3*8ff7}5H{sRBQ{f`^jgrlv520G2$YKwBdaK;RohKOD*2Ju0r3n6H%O3dMI*x^ zc+@P5u*PZtSFA}k(tn2%PE`ZGU&=5S~U=PWPt(exCWvE0kXX$+cltgq$T5QF%#{zV8%Pqem z+za5qmk*K`ThYibDF-^a7GDZ9!jfEDi5TkBVH6Yj)+r)}sv5l{a;-aZd)lzCPDHL4zst6(_ z6hYi5MWGA3bmh*CMRZ}iwO|Q|q6jV&MO=7J=FXkD_s-lrY7%Z9i~HUG|7Xsblj?Q3 z>j=J9rXnYbKeP`IX(72xHrICQWPrFiC`|0Cs zVQ6r0_2-X?fOGj?>sv>pY2nMr9p|oC|K`Q~-Tkiv-TN9`PZs}dzukJK|8&=b$-cl* zf>x-C^CeX5(N&zSi3++N$4r6xSQbedW;;B)!XY8$an=7 zUb-VY?)yHtj(0wa&1~D0ooxt(u6KsLuPztDbKT)v^Uof?X`i2t3R&8T;5pVz1dZ~r zD#`?nG71^&w2QJw6=kALh;l0+(1~e#OZQ|IQ6_Cf@F<&WC4xqoGezmeO?4%GX`CjB zB`GOO-X_fV2le}fK4C+l6^azGPlQYqVNnz;>7rEHC`!uU%TgN|@-^}6X=trN@o?TM z(1_S7)(eCxv#er++p4H8%8FE!N>l3=Ye>`3E)*Q;&aB7>uMzp0f{L1_>2Pjw(LAkl zMH3uG_}W0eCGcCDGCXw3Fm=yVFGpEt3=!4EP^^;xP&#cOP&IEbAR$k0OdEl$5CDtQ zhg7F(fDo~D4$U&v$<_j}I8I2hG)e13z_6HPo?)@*0fzH|6h<{#4`v2K>q{Ia*p?h_ zba~1_4mdU2Nqzf)N+pt7ww~NfC0c&ZkXV@j!*krmvh@^aFtpqPGi`Av@*)K!p6ND7 z>zOPf=LDoyae!=AA?5wYOBfO@N8O-OJ8yx7Jwz%`&I%e>BRWf{7$uRevGvGf>M+9v z6H(YZYvRT|%(bOvlw5fNenSE;bgou3N*Z@>?5U#rzXGi?=W6CnFLQ1v8>P;2T>E{D ksnv~=R*vf<0_}gfekU?{ None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): LEVIRCDPlus(str(tmp_path)) + + def test_plot(self, dataset: LEVIRCDPlus) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["mask"].clone() + dataset.plot(sample, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 24d76ca6594..9098a23b585 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -7,6 +7,7 @@ import os from typing import Callable, Dict, List, Optional +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -47,6 +48,7 @@ class LEVIRCDPlus(VisionDataset): url = "https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81" md5 = "1adf156f628aa32fb2e8fe6cada16c04" filename = "LEVIR-CD+.zip" + directory = "LEVIR-CD+" splits = ["train", "test"] def __init__( @@ -88,7 +90,7 @@ def __init__( + "You can use download=True to download it" ) - self.files = self._load_files(self.root, self.split) + self.files = self._load_files(self.root, self.directory, self.split) def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. @@ -120,23 +122,26 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files( + self, root: str, directory: str, split: str + ) -> List[Dict[str, str]]: """Return the paths of the files in the dataset. Args: root: root dir of dataset + directory: sub directory LEVIR-CD+ split: subset of dataset, one of [train, test] Returns: list of dicts containing paths for each pair of image1, image2, mask """ files = [] - images = glob.glob(os.path.join(root, split, "A", "*.png")) + images = glob.glob(os.path.join(root, directory, split, "A", "*.png")) images = sorted([os.path.basename(image) for image in images]) for image in images: - image1 = os.path.join(root, split, "A", image) - image2 = os.path.join(root, split, "B", image) - mask = os.path.join(root, split, "label", image) + image1 = os.path.join(root, directory, split, "A", image) + image2 = os.path.join(root, directory, split, "B", image) + mask = os.path.join(root, directory, split, "label", image) files.append(dict(image1=image1, image2=image2, mask=mask)) return files @@ -181,7 +186,7 @@ def _check_integrity(self) -> bool: True if the dataset directories and split files are found, else False """ for filename in self.splits: - filepath = os.path.join(self.root, filename) + filepath = os.path.join(self.root, self.directory, filename) if not os.path.exists(filepath): return False return True @@ -202,3 +207,53 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"]) + ncols = 3 + + if "prediction" in sample: + prediction = sample["prediction"] + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) + + axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].axis("off") + axs[2].imshow(mask) + axs[2].axis("off") + + if "prediction" in sample: + axs[3].imshow(prediction) + axs[3].axis("off") + if show_titles: + axs[3].set_title("Prediction") + + if show_titles: + axs[0].set_title("Image 1") + axs[1].set_title("Image 2") + axs[2].set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig