From 3ea316e66ad4eda1875193b6dcc09f12914b3521 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Tue, 21 Dec 2021 09:40:39 +0100 Subject: [PATCH 1/6] add plotting method --- tests/data/levircd/LEVIR-CD+.zip | Bin 3798 -> 5370 bytes torchgeo/datasets/levircd.py | 41 +++++++++++++++++++++++++------ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/tests/data/levircd/LEVIR-CD+.zip b/tests/data/levircd/LEVIR-CD+.zip index b51dc099207cccc0e2eac3e6f2e1eca07c2164f8..2449e91702cb3b082a0938fd053fd34b3a466a9b 100644 GIT binary patch literal 5370 zcmc&%TS!z<6utA%nbCq!n;r&^F&a^6Wdxa&*9R58CbOa_aD0@7m^2DOPz@D9M3@v& zA4N1ER3Cmw|MsCo{TO|PkOfr5gFgwCEjr?Y3cXWuiUPDkcC=Vz_G*E(n4HGu|) zvp|E@*LdeZ>`VNSqxrPP`a}C#H`Uf{C_mV?RC7(;8V|l@ANW6bE?2WTKRGll_6SIE z3KDQk5CM%uoF4TjRJKZjIM5v$FbbWSstA6zg6+vJNq{Y)V0(id-Mt2Q*DnM7g-ivv z=GM9pcr8=3#2S3_*0TP-lQE%>F^co-V?yzD$+ytQ55uz^?d>x^fAo9oZ*KTr+x#n^ zeK~Q>zP_ww#pKZS9WOoA+w&cdCx4e*^&M$By!-B8v!_W5G`QSjL(_FTSv^pb;Tcrh znMSo+{#AQ4gVn;}+>fe-eIcc7jbU1)ifK&smhx@ac0*0sM>pZ@GHZ`vZ}Q` zNLkpEmg$CNr7D)yB4A6)^m1fb5rYRMRPWdVSf-~~tgx&wv&Di%Ayh2vSIaKL+2E-@ z!!a0HmSZ;=*GJ-_xy#W_3=_8kqr;3VRX|6F0yyAf!m%iWH?4BTvB+e(a0YO>uwO%} z8y~$X`v&)Ttp)NtJX?T^jCD||Hhs(HSuMDPtA#!7T73F0f@=xmH`Sv%PaPoTU>&So zW7Yl_*KGCzhdhykf3W6sc$%bXX+D?ZLPFw>{7v2%u~eqX|%nBN7B zS%Mcd-zggOULo_mLB)z-v=55CB$+$RR=QI&IkwhBCGD3ot({*N%eW7_;D9^K;4&=fU z8}^Vv`HJexvL$qiiPTUL@RE5TlkyeSdq^HMNyvAm@C?}SwT6t!S0p^4z=n(?-lJSe z{_c}=4!S5MZ9!jfEDi5TkBVH6Yj)+r)}sv5l{a;-aZd)lzCPDHL4zst6(_ z6hYi5MWGA3bmh*CMRZ}iwO|Q|q6jV&MO=7J=FXkD_s-lrY7%Z9i~HUG|7Xsblj?Q3 z>j=J9rXnYbKeP`IX(72xHrICQWPrFiC`|0Cs zVQ6r0_2-X?fOGj?>sv>pY2nMr9p|oC|K`Q~-Tkiv-TN9`PZs}dzukJK|8&=b$-cl* zf>x-C^CeX5(N&zSi3++N$4r6xSQbedW;;B)!XY8$an=7 zUb-VY?)yHtj(0wa&1~D0ooxt(u6KsLuPztDbKT)v^Uof?X`i2t3R&8T;5pVz1dZ~r zD#`?nG71^&w2QJw6=kALh;l0+(1~e#OZQ|IQ6_Cf@F<&WC4xqoGezmeO?4%GX`CjB zB`GOO-X_fV2le}fK4C+l6^azGPlQYqVNnz;>7rEHC`!uU%TgN|@-^}6X=trN@o?TM z(1_S7)(eCxv#er++p4H8%8FE!N>l3=Ye>`3E)*Q;&aB7>uMzp0f{L1_>2Pjw(LAkl zMH3uG_}W0eCGcCDGCXw3Fm=yVFGpEt3=!4EP^^;xP&#cOP&IEbAR$k0OdEl$5CDtQ zhg7F(fDo~D4$U&v$<_j}I8I2hG)e13z_6HPo?)@*0fzH|6h<{#4`v2K>q{Ia*p?h_ zba~1_4mdU2Nqzf)N+pt7ww~NfC0c&ZkXV@j!*krmvh@^aFtpqPGi`Av@*)K!p6ND7 z>zOPf=LDoyae!=AA?5wYOBfO@N8O-OJ8yx7Jwz%`&I%e>BRWf{7$uRevGvGf>M+9v z6H(YZYvRT|%(bOvlw5fNenSE;bgou3N*Z@>?5U#rzXGi?=W6CnFLQ1v8>P;2T>E{D ksnv~=R*vf<0_}gfekU?{ Dict[str, Tensor]: """Return an index within the dataset. @@ -120,23 +122,26 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files( + self, root: str, directory: str, split: str + ) -> List[Dict[str, str]]: """Return the paths of the files in the dataset. Args: root: root dir of dataset + directory: sub directory LEVIR-CD+ split: subset of dataset, one of [train, test] Returns: list of dicts containing paths for each pair of image1, image2, mask """ files = [] - images = glob.glob(os.path.join(root, split, "A", "*.png")) + images = glob.glob(os.path.join(root, directory, split, "A", "*.png")) images = sorted([os.path.basename(image) for image in images]) for image in images: - image1 = os.path.join(root, split, "A", image) - image2 = os.path.join(root, split, "B", image) - mask = os.path.join(root, split, "label", image) + image1 = os.path.join(root, directory, split, "A", image) + image2 = os.path.join(root, directory, split, "B", image) + mask = os.path.join(root, directory, split, "label", image) files.append(dict(image1=image1, image2=image2, mask=mask)) return files @@ -181,7 +186,7 @@ def _check_integrity(self) -> bool: True if the dataset directories and split files are found, else False """ for filename in self.splits: - filepath = os.path.join(self.root, filename) + filepath = os.path.join(self.root, self.directory, filename) if not os.path.exists(filepath): return False return True @@ -202,3 +207,25 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, sample: Dict[str, Tensor], suptitle: Optional[str] = None + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample return by :meth:`__getitem__` + suptitle: optional suptitle to use for figure + + Returns; + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image, mask = sample["image"], sample["mask"] + + fig, axs = plt.subplots(1, 2, figsize=(10, 10)) + axs[0].imshow(image.permute(1, 2, 0)) + axs[1].imshow(mask) + + return fig From f8cd41df3ca953808f0dcb2618bb1b284ebc3410 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 22 Dec 2021 11:05:38 +0100 Subject: [PATCH 2/6] implement test --- tests/data/levircd/LEVIR-CD+.zip | Bin 5370 -> 4790 bytes tests/datasets/test_levircd.py | 10 +++++++ torchgeo/datasets/levircd.py | 45 ++++++++++++++++++++++++++----- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/tests/data/levircd/LEVIR-CD+.zip b/tests/data/levircd/LEVIR-CD+.zip index 2449e91702cb3b082a0938fd053fd34b3a466a9b..9a5fa4e1a7c02fc6a6e81f230915ba98c735c6b9 100644 GIT binary patch literal 4790 zcmc&%TS!zv7@oc8?wXO8bu|pKP@|&N5`s+1zca4iMOwE{bvU10z#GK2L6yF-0CWX}&B(kX*<6F2EWkd3s~ zQW<3R4*x*^nRwx+DDV~f2!OxlehYv8G&bMe)it;HeZXU%87X^bD=&HarT3Pg^owlKPzsOooqX?_ukN9Pb&$uxZG2tvrW6H3Q(f=5Hy}v*0_CHV`m06 z!Vx)wG{QdL+{XspmRc2CS`l0`g)0`pAxmnKEcb=8eA6A*M@M_z1OM#}Md7i+f`Y4c zzU4DeEaHRIPeX(qM|FB$1m^npufpLKsPeeV*}XUksk>*FAv|^VgA>h~^WZ zp_#}q1(xuWjWJ;ot&?IRvjG_%0rEc?ib;s%c4tx~c#B2CzCh5)rLmZ_GjIb-gm@{q zl7=gYYqrq1a#K>&2tHtquowBVG#S(8i(xb-dk|Bs8BzqUS>_AX_+Nan*}V>V?gWQm z{np@}ObX?=q8tinuwvQ)`sTQbsg#3@a%mhn4u01RjwZ5YO4TV?PQ~dEaw=QV$byUl zjRK3*8ff7}5H{sRBQ{f`^jgrlv520G2$YKwBdaK;RohKOD*2Ju0r3n6H%O3dMI*x^ zc+@P5u*PZtSFA}k(tn2%PE`ZGU&=5S~U=PWPt(exCWvE0kXX$+cltgq$T5QF%#{zV8%Pqem z+za5qmk*K`ThYibDF-^a7GDfr5gFgwCEjr?Y3cXWuiUPDkcC=Vz_G*E(n4HGu|) zvp|E@*LdeZ>`VNSqxrPP`a}C#H`Uf{C_mV?RC7(;8V|l@ANW6bE?2WTKRGll_6SIE z3KDQk5CM%uoF4TjRJKZjIM5v$FbbWSstA6zg6+vJNq{Y)V0(id-Mt2Q*DnM7g-ivv z=GM9pcr8=3#2S3_*0TP-lQE%>F^co-V?yzD$+ytQ55uz^?d>x^fAo9oZ*KTr+x#n^ zeK~Q>zP_ww#pKZS9WOoA+w&cdCx4e*^&M$By!-B8v!_W5G`QSjL(_FTSv^pb;Tcrh znMSo+{#AQ4gVn;}+>fe-eIcc7jbU1)ifK&smhx@ac0*0sM>pZ@GHZ`vZ}Q` zNLkpEmg$CNr7D)yB4A6)^m1fb5rYRMRPWdVSf-~~tgx&wv&Di%Ayh2vSIaKL+2E-@ z!!a0HmSZ;=*GJ-_xy#W_3=_8kqr;3VRX|6F0yyAf!m%iWH?4BTvB+e(a0YO>uwO%} z8y~$X`v&)Ttp)NtJX?T^jCD||Hhs(HSuMDPtA#!7T73F0f@=xmH`Sv%PaPoTU>&So zW7Yl_*KGCzhdhykf3W6sc$%bXX+D?ZLPFw>{7v2%u~eqX|%nBN7B zS%Mcd-zggOULo_mLB)z-v=55CB$+$RR=QI&IkwhBCGD3ot({*N%eW7_;D9^K;4&=fU z8}^Vv`HJexvL$qiiPTUL@RE5TlkyeSdq^HMNyvAm@C?}SwT6t!S0p^4z=n(?-lJSe z{_c}=4!S5M None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): LEVIRCDPlus(str(tmp_path)) + + def test_plot(self, dataset: LEVIRCDPlus) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["mask"].clone() + dataset.plot(sample, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 9e1c99617d5..3a9b8e3f9bd 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -209,12 +209,16 @@ def _download(self) -> None: ) def plot( - self, sample: Dict[str, Tensor], suptitle: Optional[str] = None + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, ) -> plt.Figure: """Plot a sample from the dataset. Args: sample: a sample return by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel suptitle: optional suptitle to use for figure Returns; @@ -222,10 +226,39 @@ def plot( .. versionadded:: 0.2 """ - image, mask = sample["image"], sample["mask"] - - fig, axs = plt.subplots(1, 2, figsize=(10, 10)) - axs[0].imshow(image.permute(1, 2, 0)) - axs[1].imshow(mask) + image1, image2, mask = ( + sample["image"][0, ...], + sample["image"][1, ...], + sample["mask"], + ) + ncols = 3 + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = sample["prediction"] + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) + + axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].axis("off") + axs[2].imshow(mask) + axs[2].axis("off") + + if showing_predictions: + axs[3].imshow(prediction) + if show_titles: + axs[3].set_title("Prediction") + axs[3].axis("off") + + if show_titles: + axs[0].set_title("Image 1") + axs[1].set_title("Image 2") + axs[2].set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) return fig From 6c87550f6b46ebf4e2b363dd64b5ff3b1efb7092 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 29 Dec 2021 11:46:18 +0100 Subject: [PATCH 3/6] axis off --- torchgeo/datasets/levircd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 3a9b8e3f9bd..f55ae036014 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -249,9 +249,9 @@ def plot( if showing_predictions: axs[3].imshow(prediction) + axs[3].axis("off") if show_titles: axs[3].set_title("Prediction") - axs[3].axis("off") if show_titles: axs[0].set_title("Image 1") From b8c5e2d794670d6ff637259d6dade6649eff922b Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 29 Dec 2021 12:03:42 +0100 Subject: [PATCH 4/6] prediction flag --- torchgeo/datasets/levircd.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index f55ae036014..082a6da8ee4 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -233,8 +233,7 @@ def plot( ) ncols = 3 - showing_predictions = "prediction" in sample - if showing_predictions: + if "prediction" in sample: prediction = sample["prediction"] ncols += 1 @@ -247,7 +246,7 @@ def plot( axs[2].imshow(mask) axs[2].axis("off") - if showing_predictions: + if "prediction" in sample: axs[3].imshow(prediction) axs[3].axis("off") if show_titles: From 3d914e1f0b593280d1bf556bfe9cab6eda0caee3 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Wed, 29 Dec 2021 19:12:52 +0100 Subject: [PATCH 5/6] requested changes --- tests/datasets/test_levircd.py | 2 +- torchgeo/datasets/levircd.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index 752c510ce61..f61bc241be8 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -32,7 +32,7 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.utils, "download_url", download_url ) - md5 = "b61c300e9fd7146eb2c8e2512c0e9d39" + md5 = "1adf156f628aa32fb2e8fe6cada16c04" monkeypatch.setattr(LEVIRCDPlus, "md5", md5) # type: ignore[attr-defined] url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip") monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined] diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 082a6da8ee4..37b3d14c779 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -217,11 +217,11 @@ def plot( """Plot a sample from the dataset. Args: - sample: a sample return by :meth:`__getitem__` + sample: a sample returned by :meth:`__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional suptitle to use for figure - Returns; + Returns: a matplotlib Figure with the rendered sample .. versionadded:: 0.2 From c69d6f9bd2b332ba8576f716aaac5ae294e9dbd5 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 30 Dec 2021 09:51:39 +0100 Subject: [PATCH 6/6] indexing fix --- torchgeo/datasets/levircd.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 37b3d14c779..9098a23b585 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -226,11 +226,7 @@ def plot( .. versionadded:: 0.2 """ - image1, image2, mask = ( - sample["image"][0, ...], - sample["image"][1, ...], - sample["mask"], - ) + image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"]) ncols = 3 if "prediction" in sample: