Skip to content

Commit

Permalink
Add own plot method and data.py to CBF (#410)
Browse files Browse the repository at this point in the history
* add own plot method and data.py

* clean up data.py

* version changed instead of added

* Update cbf.py

* Any instead of Tensor

* Fix VectorDataset tests

* Plot method in base class no longer needed/tested

* Removing unused imports

* Remove type ignore from openbuildings

* Fix tests

* Black formatting

Co-authored-by: Caleb Robinson <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
3 people authored Mar 20, 2022
1 parent 0a15632 commit b2e178f
Show file tree
Hide file tree
Showing 31 changed files with 118 additions and 124 deletions.
8 changes: 1 addition & 7 deletions tests/data/cbf/Alberta.geojson
Original file line number Diff line number Diff line change
@@ -1,7 +1 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}]}
Binary file modified tests/data/cbf/Alberta.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/BritishColumbia.geojson

This file was deleted.

Binary file removed tests/data/cbf/BritishColumbia.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/Manitoba.geojson

This file was deleted.

Binary file removed tests/data/cbf/Manitoba.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/NewBrunswick.geojson

This file was deleted.

Binary file removed tests/data/cbf/NewBrunswick.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/NewfoundlandAndLabrador.geojson

This file was deleted.

Binary file removed tests/data/cbf/NewfoundlandAndLabrador.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/NorthwestTerritories.geojson

This file was deleted.

Binary file removed tests/data/cbf/NorthwestTerritories.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/NovaScotia.geojson

This file was deleted.

Binary file removed tests/data/cbf/NovaScotia.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/Nunavut.geojson

This file was deleted.

Binary file removed tests/data/cbf/Nunavut.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/Ontario.geojson

This file was deleted.

Binary file removed tests/data/cbf/Ontario.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/PrinceEdwardIsland.geojson

This file was deleted.

Binary file removed tests/data/cbf/PrinceEdwardIsland.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/Quebec.geojson

This file was deleted.

Binary file removed tests/data/cbf/Quebec.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/Saskatchewan.geojson

This file was deleted.

Binary file removed tests/data/cbf/Saskatchewan.zip
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/data/cbf/YukonTerritory.geojson

This file was deleted.

Binary file removed tests/data/cbf/YukonTerritory.zip
Binary file not shown.
53 changes: 53 additions & 0 deletions tests/data/cbf/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import json
import os
import shutil


def create_geojson():
geojson = {
"type": "FeatureCollection",
"crs": {
"type": "name",
"properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"},
},
"features": [
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
],
},
}
],
}
return geojson


if __name__ == "__main__":
filename = "Alberta.zip"
geojson = create_geojson()

with open(filename.replace(".zip", ".geojson"), "w") as f:
json.dump(geojson, f)

# compress single file directly with no directory
shutil.make_archive(
filename.replace(".zip", ""),
"zip",
os.getcwd(),
filename.replace(".zip", ".geojson"),
)

# Compute checksums
with open(filename, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{filename}: {md5}")
30 changes: 13 additions & 17 deletions tests/datasets/test_cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,12 @@ def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path
) -> CanadianBuildingFootprints:
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
md5s = [
"8a4a0a57367f67c69608d1452e30df13",
"1829f4054a9a81bb23871ca797a3895c",
"4358a0076fd43e9a2f436e74348813b0",
"ae3726b1263727d72565ecacfed56fb8",
"6861876d3a3ca7e79b28c61ab5906de4",
"d289c9ea49801bb287ddbde1ea5f31ef",
"3a940288297631b4e6a365266bfb949a",
"6b43b3632b165ff79c1ca0c693a61398",
"36283e0b29088ec281e77c989cbee100",
"773da9d33e3766b7237a1d7db0811832",
"cc833a65137c8a046c8f45bb695092b1",
"067664d066c4152fb96a5c129cbabadf",
"474bc084bc41b124aa4919e7a37a9648",
]
monkeypatch.setattr(CanadianBuildingFootprints, "md5s", md5s)
monkeypatch.setattr(
CanadianBuildingFootprints, "provinces_territories", ["Alberta"]
)
monkeypatch.setattr(
CanadianBuildingFootprints, "md5s", ["25091d1f051baa30d8f2026545cfb696"]
)
url = os.path.join("tests", "data", "cbf") + os.sep
monkeypatch.setattr(CanadianBuildingFootprints, "url", url)
monkeypatch.setattr(plt, "show", lambda *args: None)
Expand Down Expand Up @@ -76,7 +66,13 @@ def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None:
def test_plot(self, dataset: CanadianBuildingFootprints) -> None:
query = dataset.bounds
x = dataset[query]
dataset.plot(x["mask"])
dataset.plot(x, suptitle="Test")

def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
Expand Down
50 changes: 50 additions & 0 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
from rasterio.crs import CRS

from .geo import VectorDataset
Expand Down Expand Up @@ -120,3 +121,52 @@ def _download(self) -> None:
self.root,
md5=md5 if self.checksum else None,
)

def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`VectorDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
.. versionchanged:: 0.3
Method now takes a sample dict, not a Tensor. Additionally, it is possible
to show subplot titles and/or use a custom suptitle.
"""
image = sample["mask"].squeeze(0)
ncols = 1

showing_prediction = "prediction" in sample
if showing_prediction:
pred = sample["prediction"].squeeze(0)
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4))

if showing_prediction:
axs[0].imshow(image)
axs[0].axis("off")
axs[1].imshow(pred)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(image)
axs.axis("off")
if show_titles:
axs.set_title("Mask")

if suptitle is not None:
plt.suptitle(suptitle)

return fig
15 changes: 0 additions & 15 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,21 +677,6 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:

return sample

def plot(self, data: Tensor) -> None:
"""Plot a data sample.
Args:
data: the data to plot
"""
array = data.squeeze().numpy()

# Plot the image
ax = plt.axes()
ax.imshow(array)
ax.axis("off")
plt.show()
plt.close()


class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
"""Abstract base class for datasets lacking geospatial information.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/openbuildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def _verify(self) -> None:
"have manually downloaded the dataset as suggested in the documentation."
)

def plot( # type: ignore[override]
def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
Expand Down

0 comments on commit b2e178f

Please sign in to comment.