Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow VectorDataset to accept list of files #1597

Merged
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4ade256
Make RasterDataset accept list of files
Jun 22, 2023
a0f985d
Fix check if str
adriantre Jun 22, 2023
a9f6944
Use isdir and isfile
adriantre Jun 23, 2023
b69625b
Rename root to paths and update type hint
adriantre Jun 23, 2023
b7a51bd
Update children of RasterDataset methods using root
adriantre Jun 23, 2023
d6a0919
Fix check to cast str to list
adriantre Jun 28, 2023
ce5f474
Update conf files for RasterDatasets
adriantre Jun 28, 2023
81833e9
Add initial suggested test
adriantre Jun 28, 2023
b247861
Add workaround for lists LandCoverAIBase
adriantre Jun 29, 2023
f569051
Add method handle_nonlocal_path for users to override
adriantre Jun 29, 2023
d4f757c
Raise RuntimeError to support existing tests
adriantre Jun 29, 2023
0414f63
Remove reduntand cast to set
adriantre Jun 29, 2023
2195553
Remove required os.exists for paths
adriantre Jul 3, 2023
61de902
Revert "Remove required os.exists for paths"
adriantre Jul 3, 2023
5bed6f3
Use arg as potitional argument not kwarg
adriantre Sep 28, 2023
8e80458
Improve comments and logs about arg paths
adriantre Sep 28, 2023
b736cef
Remove misleading comment
adriantre Sep 28, 2023
2f7df48
Change type hint of 'paths' to Iterable
adriantre Sep 28, 2023
a6e5fe1
Change type hint of 'paths' to Iterable
adriantre Sep 28, 2023
ca5c4bf
Remove premature handling of non-local paths
adriantre Sep 28, 2023
a228cc2
Replace root with paths in docstrings
adriantre Sep 28, 2023
c22a6a9
Add versionadded to list_files docstring
adriantre Sep 28, 2023
44f6eb5
Add versionchanged to docstrings
adriantre Sep 28, 2023
9dae8c4
Update type of paths in childred of Raster
adriantre Sep 28, 2023
1311957
Replace docstring for paths in all raster
adriantre Sep 28, 2023
697dfd7
Swap root with paths for conf files for raster
adriantre Sep 28, 2023
026ee11
Add newline before versionchanged
adriantre Sep 29, 2023
628801a
Revert name to root in conf for ChesapeakeCVPR
adriantre Sep 29, 2023
eae2992
Simplify EUDEM tests
adamjstewart Sep 29, 2023
2bc82c8
paths must be a string if you want autodownload support
adamjstewart Sep 29, 2023
d391079
Convert list_files to a property
adamjstewart Sep 29, 2023
66f2f02
Fix type hints
adamjstewart Sep 29, 2023
8ec2e93
Test with a real empty directory
adamjstewart Sep 29, 2023
be29b24
Move property `files` up to GeoDataset
adriantre Sep 29, 2023
0b0ade4
Rename root to paths for VectorDataset
adriantre Sep 29, 2023
a02c9b1
Merge remote-tracking branch 'origin/main' into feature/refactor_vect…
adriantre Sep 29, 2023
1a0d877
Fix mypy
adriantre Sep 29, 2023
55fbf83
Fix tests
adriantre Sep 29, 2023
3504673
Delete duplicate code
adamjstewart Sep 29, 2023
89df751
Delete duplicate code
adamjstewart Sep 29, 2023
10fd27a
Fix test coverage
adamjstewart Sep 29, 2023
110d695
Document name change
adamjstewart Sep 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""Canadian Building Footprints dataset."""

import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
Expand Down Expand Up @@ -60,7 +61,7 @@ class CanadianBuildingFootprints(VectorDataset):

def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.00001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Expand All @@ -70,7 +71,7 @@ def __init__(
"""Initialize a new Dataset instance.

Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
Expand All @@ -83,8 +84,11 @@ def __init__(
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` and data is not found, or
``checksum=True`` and checksums don't match

.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.checksum = checksum

if download:
Expand All @@ -96,16 +100,17 @@ def __init__(
+ "You can use download=True to download it"
)

super().__init__(root, crs, res, transforms)
super().__init__(paths, crs, res, transforms)

def _check_integrity(self) -> bool:
"""Check integrity of dataset.

Returns:
True if dataset files are found and/or MD5s match, else False
"""
assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
filepath = os.path.join(self.root, prov_terr + ".zip")
filepath = os.path.join(self.paths, prov_terr + ".zip")
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True
Expand All @@ -115,11 +120,11 @@ def _download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return

assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + ".zip",
self.root,
self.paths,
md5=md5 if self.checksum else None,
)

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ class ChesapeakeCVPR(GeoDataset):
)

# these are used to check the integrity of the dataset
files = [
_files = [
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"de_1m_2013_extended-debuffered-test_tiles",
"de_1m_2013_extended-debuffered-train_tiles",
"de_1m_2013_extended-debuffered-val_tiles",
Expand Down Expand Up @@ -704,7 +704,7 @@ def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))

# Check if the extracted files already exist
if all(map(exists, self.files)):
if all(map(exists, self._files)):
return
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

# Check if the zip files have already been downloaded
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class EnviroAtlas(GeoDataset):
)

# these are used to check the integrity of the dataset
files = [
_files = [
"austin_tx-2012_1m-test_tiles-debuffered",
"austin_tx-2012_1m-val5_tiles-debuffered",
"durham_nc-2012_1m-test_tiles-debuffered",
Expand Down Expand Up @@ -422,7 +422,7 @@ def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))

# Check if the extracted files already exist
if all(map(exists, self.files)):
if all(map(exists, self._files)):
return

# Check if the zip files have already been downloaded
Expand Down
47 changes: 41 additions & 6 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
_crs = CRS.from_epsg(4326)
_res = 0.0

#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = "*"
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
Expand All @@ -98,6 +105,9 @@ def __init__(
and returns a transformed version
"""
self.transforms = transforms
self.paths: Union[
str, Iterable[str]
] = "data" # TODO: temp fix when refactoring

# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
Expand Down Expand Up @@ -269,6 +279,32 @@ def res(self, new_res: float) -> None:
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res

@property
def files(self) -> set[str]:
"""A list of all files in the dataset.

Returns:
All files in the dataset.

.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str):
paths: Iterable[str] = [self.paths]
else:
paths = self.paths
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

# Using set to remove any duplicates if directories are overlapping
files: set[str] = set()
for path in paths:
if os.path.isdir(path):
pathname = os.path.join(path, "**", self.filename_glob)
files |= set(glob.iglob(pathname, recursive=True))
else:
files.add(path)
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

return files


class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
Expand Down Expand Up @@ -580,7 +616,7 @@ class VectorDataset(GeoDataset):

def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Expand All @@ -589,7 +625,7 @@ def __init__(
"""Initialize a new Dataset instance.

Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
Expand All @@ -606,13 +642,12 @@ def __init__(
"""
super().__init__(transforms)

self.root = root
self.paths = paths
self.label_name = label_name

# Populate the dataset index
i = 0
pathname = os.path.join(root, "**", self.filename_glob)
for filepath in glob.iglob(pathname, recursive=True):
for filepath in self.files:
try:
with fiona.open(filepath) as src:
if crs is None:
Expand All @@ -633,7 +668,7 @@ def __init__(
i += 1

if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{root}'`"
msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`"
raise FileNotFoundError(msg)

self._crs = crs
Expand Down
24 changes: 18 additions & 6 deletions torchgeo/datasets/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import hashlib
import os
from functools import lru_cache
from typing import Any, Callable, Optional, cast
from typing import Any, Callable, Optional, Union, cast

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -71,7 +71,10 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
}

def __init__(
self, root: str = "data", download: bool = False, checksum: bool = False
self,
root: Union[str, list[str]] = "data",
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new LandCover.ai dataset instance.

Expand All @@ -87,6 +90,10 @@ def __init__(
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
if isinstance(root, list):
# TODO: Workaround until list are fully implemented
root = root[0]
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

self.root = root
self.download = download
self.checksum = checksum
Expand Down Expand Up @@ -211,7 +218,7 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset):

def __init__(
self,
root: str = "data",
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Expand All @@ -222,7 +229,7 @@ def __init__(
"""Initialize a new LandCover.ai NonGeo dataset instance.

Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
Expand All @@ -236,9 +243,14 @@ def __init__(
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match

.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
LandCoverAIBase.__init__(self, root, download, checksum)
RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache)
LandCoverAIBase.__init__(self, paths, download, checksum)
RasterDataset.__init__(
self, paths, crs, res, transforms=transforms, cache=cache
)

def _verify_data(self) -> bool:
"""Verify if the images and masks are present."""
Expand Down
34 changes: 21 additions & 13 deletions torchgeo/datasets/openbuildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import json
import os
import sys
from typing import Any, Callable, Optional, cast
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union, cast

import fiona
import fiona.transform
Expand Down Expand Up @@ -205,7 +206,7 @@ class OpenBuildings(VectorDataset):

def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Expand All @@ -214,7 +215,7 @@ def __init__(
"""Initialize a new Dataset instance.

Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
Expand All @@ -224,11 +225,13 @@ def __init__(

Raises:
FileNotFoundError: if no files are found in ``root``

.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.res = res
self.checksum = checksum
self.root = root
self.res = res
self.transforms = transforms

Expand All @@ -237,15 +240,17 @@ def __init__(
# Create an R-tree to index the dataset using the polygon centroid as bounds
self.index = Index(interleaved=False, properties=Property(dimension=3))

with open(os.path.join(root, "tiles.geojson")) as f:
assert isinstance(self.paths, str)
with open(os.path.join(self.paths, "tiles.geojson")) as f:
data = json.load(f)

features = data["features"]
features_filenames = [
feature["properties"]["tile_url"].split("/")[-1] for feature in features
] # get csv filename

polygon_files = glob.glob(os.path.join(self.root, self.zipfile_glob))
assert isinstance(self.paths, str)
polygon_files = glob.glob(os.path.join(self.paths, self.zipfile_glob))
polygon_filenames = [f.split(os.sep)[-1] for f in polygon_files]

matched_features = [
Expand Down Expand Up @@ -273,15 +278,16 @@ def __init__(
maxt = sys.maxsize
coords = (minx, maxx, miny, maxy, mint, maxt)

assert isinstance(self.paths, str)
filepath = os.path.join(
self.root, feature["properties"]["tile_url"].split("/")[-1]
self.paths, feature["properties"]["tile_url"].split("/")[-1]
)
self.index.insert(i, coords, filepath)
i += 1

if i == 0:
raise FileNotFoundError(
f"No {self.__class__.__name__} data was found in '{self.root}'"
f"No {self.__class__.__name__} data was found in '{self.paths}'"
)

self._crs = crs
Expand Down Expand Up @@ -398,7 +404,8 @@ def _verify(self) -> None:
FileNotFoundError: if metadata file is not found in root
"""
# Check if the zip files have already been downloaded and checksum
pathname = os.path.join(self.root, self.zipfile_glob)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, self.zipfile_glob)
i = 0
for zipfile in glob.iglob(pathname):
filename = os.path.basename(zipfile)
Expand All @@ -410,14 +417,15 @@ def _verify(self) -> None:
return

# check if the metadata file has been downloaded
if not os.path.exists(os.path.join(self.root, self.meta_data_filename)):
assert isinstance(self.paths, str)
if not os.path.exists(os.path.join(self.paths, self.meta_data_filename)):
raise FileNotFoundError(
f"Meta data file {self.meta_data_filename} "
f"not found in in `root={self.root}`."
f"not found in in `root={self.paths}`."
)

raise RuntimeError(
f"Dataset not found in `root={self.root}` "
f"Dataset not found in `root={self.paths}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as suggested in the documentation."
)
Expand Down
Loading