Feature/refactor vector root to paths (#1597)
* Make RasterDataset accept list of files

* Fix check if str

* Use isdir and isfile

* Rename root to paths and update type hint

* Update children of RasterDataset methods using root

* Fix check to cast str to list

* Update conf files for RasterDatasets

* Add initial suggested test

* Add workaround for lists in LandCoverAIBase

* Add method handle_nonlocal_path for users to override

* Raise RuntimeError to support existing tests

* Remove redundant cast to set

* Remove required os.exists for paths

* Revert "Remove required os.exists for paths"

This reverts commit 84bf62b.

* Use arg as positional argument, not kwarg

* Improve comments and logs about arg paths

* Remove misleading comment

* Change type hint of 'paths' to Iterable

* Change type hint of 'paths' to Iterable

* Remove premature handling of non-local paths

* Replace root with paths in docstrings

* Add versionadded to list_files docstring

* Add versionchanged to docstrings

* Update type of paths in children of Raster

* Replace docstring for paths in all raster

* Swap root with paths for conf files for raster

* Add newline before versionchanged

* Revert name to root in conf for ChesapeakeCVPR

* Simplify EUDEM tests

* paths must be a string if you want autodownload support

* Convert list_files to a property

* Fix type hints

* Test with a real empty directory

* Move property `files` up to GeoDataset

* Rename root to paths for VectorDataset

* Fix mypy

* Fix tests

* Delete duplicate code

* Delete duplicate code

* Fix test coverage

* Document name change

---------

Co-authored-by: Adrian Tofting <[email protected]>
Co-authored-by: Adrian Tofting <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
4 people authored Sep 29, 2023
1 parent 6ae0d78 commit 3532f78
Showing 9 changed files with 85 additions and 79 deletions.
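In short: datasets that previously took a single `root: str` argument now take `paths`, which may be a single string or an iterable of strings. A minimal sketch of the new calling convention (the data layouts are hypothetical, and datasets with download or checksum logic, such as CanadianBuildingFootprints, still require a single string root; see the `assert isinstance(self.paths, str)` guards below):

    from torchgeo.datasets import CanadianBuildingFootprints, Sentinel2

    # `paths` may be a single directory, exactly as `root` was before:
    cbf = CanadianBuildingFootprints("data/cbf")

    # Datasets without single-root download logic also accept an iterable
    # of directories to search and/or individual files to load:
    s2 = Sentinel2(["data/sentinel2/2022", "data/sentinel2/2023"])
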
2 changes: 1 addition & 1 deletion tests/datasets/test_cbf.py
@@ -61,7 +61,7 @@ def test_or(self, dataset: CanadianBuildingFootprints) -> None:
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None:
- CanadianBuildingFootprints(root=dataset.root, download=True)
+ CanadianBuildingFootprints(dataset.paths, download=True)

def test_plot(self, dataset: CanadianBuildingFootprints) -> None:
query = dataset.bounds
2 changes: 1 addition & 1 deletion tests/datasets/test_chesapeake.py
@@ -141,7 +141,7 @@ def dataset(
)
monkeypatch.setattr(
ChesapeakeCVPR,
"files",
"_files",
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
)
root = str(tmp_path)
2 changes: 1 addition & 1 deletion tests/datasets/test_enviroatlas.py
@@ -47,7 +47,7 @@ def dataset(
)
monkeypatch.setattr(
EnviroAtlas,
"files",
"_files",
["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"],
)
root = str(tmp_path)
10 changes: 5 additions & 5 deletions tests/datasets/test_openbuildings.py
@@ -37,7 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings:

monkeypatch.setattr(OpenBuildings, "md5s", md5s)
transforms = nn.Identity()
- return OpenBuildings(root=root, transforms=transforms)
+ return OpenBuildings(root, transforms=transforms)

def test_no_shapes_to_rasterize(
self, dataset: OpenBuildings, tmp_path: Path
@@ -61,19 +61,19 @@ def test_no_building_data_found(self, tmp_path: Path) -> None:
with pytest.raises(
RuntimeError, match="have manually downloaded the dataset as suggested "
):
- OpenBuildings(root=false_root)
+ OpenBuildings(false_root)

def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
- OpenBuildings(dataset.root, checksum=True)
+ OpenBuildings(dataset.paths, checksum=True)

def test_no_meta_data_found(self, tmp_path: Path) -> None:
false_root = os.path.join(tmp_path, "empty")
os.makedirs(false_root)
with pytest.raises(FileNotFoundError, match="Meta data file"):
- OpenBuildings(root=false_root)
+ OpenBuildings(false_root)

def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
# change meta data to another 'title_url' so that there is no match found
@@ -85,7 +85,7 @@ def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
json.dump(content, f)

with pytest.raises(FileNotFoundError, match="data was found in"):
- OpenBuildings(dataset.root)
+ OpenBuildings(dataset.paths)

def test_getitem(self, dataset: OpenBuildings) -> None:
x = dataset[dataset.bounds]
21 changes: 13 additions & 8 deletions torchgeo/datasets/cbf.py
@@ -4,7 +4,8 @@
"""Canadian Building Footprints dataset."""

import os
- from typing import Any, Callable, Optional
+ from collections.abc import Iterable
+ from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@@ -60,7 +61,7 @@ class CanadianBuildingFootprints(VectorDataset):

def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.00001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -70,7 +71,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -83,8 +84,11 @@ def __init__(
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` and data is not found, or
``checksum=True`` and checksums don't match
+ .. versionchanged:: 0.5
+    *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.checksum = checksum

if download:
@@ -96,16 +100,17 @@ def __init__(
+ "You can use download=True to download it"
)

- super().__init__(root, crs, res, transforms)
+ super().__init__(paths, crs, res, transforms)

def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Returns:
True if dataset files are found and/or MD5s match, else False
"""
+ assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
- filepath = os.path.join(self.root, prov_terr + ".zip")
+ filepath = os.path.join(self.paths, prov_terr + ".zip")
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True
@@ -115,11 +120,11 @@ def _download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return

+ assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + ".zip",
- self.root,
+ self.paths,
md5=md5 if self.checksum else None,
)

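Note the two `assert isinstance(self.paths, str)` guards added above. As the commit history puts it, paths must be a string if you want autodownload support: `_check_integrity` and `_download` join per-province archive names onto a single root directory, so an iterable of paths leaves no single place to download into. An illustrative sketch (paths hypothetical):

    from torchgeo.datasets import CanadianBuildingFootprints

    # Supported: download and verify archives under one string root.
    ds = CanadianBuildingFootprints("data/cbf", download=True, checksum=True)

    # Not supported: a list of roots would trip the asserts above.
    # ds = CanadianBuildingFootprints(["data/a", "data/b"], download=True)
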
4 changes: 2 additions & 2 deletions torchgeo/datasets/chesapeake.py
@@ -495,7 +495,7 @@ class ChesapeakeCVPR(GeoDataset):
)

# these are used to check the integrity of the dataset
- files = [
+ _files = [
"de_1m_2013_extended-debuffered-test_tiles",
"de_1m_2013_extended-debuffered-train_tiles",
"de_1m_2013_extended-debuffered-val_tiles",
@@ -704,7 +704,7 @@ def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))

# Check if the extracted files already exist
- if all(map(exists, self.files)):
+ if all(map(exists, self._files)):
return

# Check if the zip files have already been downloaded
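The `files` to `_files` rename here (and in EnviroAtlas below) is not cosmetic: `GeoDataset` now exposes a public `files` property, and a plain class attribute named `files` on a subclass sits earlier in the MRO and would shadow it. A toy sketch of the collision, using placeholder classes rather than torchgeo's:

    class Base:
        @property
        def files(self) -> set[str]:
            return {"a.tif"}

    class Sub(Base):
        files = ["integrity_checklist"]  # shadows the inherited property

    print(Sub().files)  # ['integrity_checklist'], the property never runs
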
4 changes: 2 additions & 2 deletions torchgeo/datasets/enviroatlas.py
@@ -80,7 +80,7 @@ class EnviroAtlas(GeoDataset):
)

# these are used to check the integrity of the dataset
- files = [
+ _files = [
"austin_tx-2012_1m-test_tiles-debuffered",
"austin_tx-2012_1m-val5_tiles-debuffered",
"durham_nc-2012_1m-test_tiles-debuffered",
@@ -422,7 +422,7 @@ def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))

# Check if the extracted files already exist
- if all(map(exists, self.files)):
+ if all(map(exists, self._files)):
return

# Check if the zip files have already been downloaded
88 changes: 42 additions & 46 deletions torchgeo/datasets/geo.py
@@ -72,9 +72,17 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
dataset = landsat7 | landsat8
"""

+ paths: Union[str, Iterable[str]]
_crs = CRS.from_epsg(4326)
_res = 0.0

+ #: Glob expression used to search for files.
+ #:
+ #: This expression should be specific enough that it will not pick up files from
+ #: other datasets. It should not include a file extension, as the dataset may be in
+ #: a different file format than what it was originally downloaded as.
+ filename_glob = "*"

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
@@ -269,17 +277,36 @@ def res(self, new_res: float) -> None:
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res

+ @property
+ def files(self) -> set[str]:
+ """A list of all files in the dataset.
+ Returns:
+ All files in the dataset.
+ .. versionadded:: 0.5
+ """
+ # Make iterable
+ if isinstance(self.paths, str):
+ paths: Iterable[str] = [self.paths]
+ else:
+ paths = self.paths
+
+ # Using set to remove any duplicates if directories are overlapping
+ files: set[str] = set()
+ for path in paths:
+ if os.path.isdir(path):
+ pathname = os.path.join(path, "**", self.filename_glob)
+ files |= set(glob.iglob(pathname, recursive=True))
+ else:
+ files.add(path)
+
+ return files


class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""

- #: Glob expression used to search for files.
- #:
- #: This expression should be specific enough that it will not pick up files from
- #: other datasets. It should not include a file extension, as the dataset may be in
- #: a different file format than what it was originally downloaded as.
- filename_glob = "*"

#: Regular expression used to extract date from filename.
#:
#: The expression should use named groups. The expression may contain any number of
@@ -423,32 +450,6 @@ def __init__(
self._crs = cast(CRS, crs)
self._res = cast(float, res)

- @property
- def files(self) -> set[str]:
- """A list of all files in the dataset.
- Returns:
- All files in the dataset.
- .. versionadded:: 0.5
- """
- # Make iterable
- if isinstance(self.paths, str):
- paths: Iterable[str] = [self.paths]
- else:
- paths = self.paths
-
- # Using set to remove any duplicates if directories are overlapping
- files: set[str] = set()
- for path in paths:
- if os.path.isdir(path):
- pathname = os.path.join(path, "**", self.filename_glob)
- files |= set(glob.iglob(pathname, recursive=True))
- else:
- files.add(path)
-
- return files

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@@ -571,16 +572,9 @@ def _load_warp_file(self, filepath: str) -> DatasetReader:
class VectorDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as vector files."""

- #: Glob expression used to search for files.
- #:
- #: This expression should be specific enough that it will not pick up files from
- #: other datasets. It should not include a file extension, as the dataset may be in
- #: a different file format than what it was originally downloaded as.
- filename_glob = "*"

def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -589,7 +583,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -603,16 +597,18 @@
.. versionadded:: 0.4
The *label_name* parameter.
+ .. versionchanged:: 0.5
+    *root* was renamed to *paths*.
"""
super().__init__(transforms)

- self.root = root
+ self.paths = paths
self.label_name = label_name

# Populate the dataset index
i = 0
- pathname = os.path.join(root, "**", self.filename_glob)
- for filepath in glob.iglob(pathname, recursive=True):
+ for filepath in self.files:
try:
with fiona.open(filepath) as src:
if crs is None:
@@ -633,7 +629,7 @@ def __init__(
i += 1

if i == 0:
- msg = f"No {self.__class__.__name__} data was found in `root='{root}'`"
+ msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`"
raise FileNotFoundError(msg)

self._crs = crs
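The `files` property moved up to `GeoDataset` is what makes `paths` flexible: each entry that is a directory is searched recursively with `filename_glob`, every other entry is treated as a literal file path, and a set removes duplicates when directories overlap. A standalone mirror of that logic, for illustration only:

    import glob
    import os
    from collections.abc import Iterable
    from typing import Union

    def resolve_files(paths: Union[str, Iterable[str]], filename_glob: str = "*") -> set[str]:
        """Mimic of the GeoDataset.files property introduced in this commit."""
        path_list = [paths] if isinstance(paths, str) else list(paths)
        files: set[str] = set()  # a set deduplicates overlapping directories
        for path in path_list:
            if os.path.isdir(path):
                # Recursively glob inside directories.
                pathname = os.path.join(path, "**", filename_glob)
                files |= set(glob.iglob(pathname, recursive=True))
            else:
                # Anything else is taken at face value as a file.
                files.add(path)
        return files

    # Hypothetical layout: one directory to search plus one explicit file.
    print(resolve_files(["data", "extra/scene.tif"], "*.tif"))
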

