Skip to content

Commit

Permalink
Move property files up to GeoDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adriantre committed Sep 29, 2023
1 parent 8ec2e93 commit be29b24
Showing 1 changed file with 35 additions and 26 deletions.
61 changes: 35 additions & 26 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
_crs = CRS.from_epsg(4326)
_res = 0.0

#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = "*"

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
Expand All @@ -98,6 +105,9 @@ def __init__(
and returns a transformed version
"""
self.transforms = transforms
self.paths: Union[
str, Iterable[str]
] = "data" # TODO: temp fix when refactoring

# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
Expand Down Expand Up @@ -269,6 +279,31 @@ def res(self, new_res: float) -> None:
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res

@property
def files(self) -> set[str]:
"""A list of all files in the dataset.
Returns:
All files in the dataset.
.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str):
paths: Iterable[str] = [self.paths]
else:
paths = self.paths

# Using set to remove any duplicates if directories are overlapping
files: set[str] = set()
for path in paths:
if os.path.isdir(path):
pathname = os.path.join(path, "**", self.filename_glob)
files |= set(glob.iglob(pathname, recursive=True))
else:
files.add(path)

return files

class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
Expand Down Expand Up @@ -423,32 +458,6 @@ def __init__(
self._crs = cast(CRS, crs)
self._res = cast(float, res)

@property
def files(self) -> set[str]:
"""A list of all files in the dataset.
Returns:
All files in the dataset.
.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str):
paths: Iterable[str] = [self.paths]
else:
paths = self.paths

# Using set to remove any duplicates if directories are overlapping
files: set[str] = set()
for path in paths:
if os.path.isdir(path):
pathname = os.path.join(path, "**", self.filename_glob)
files |= set(glob.iglob(pathname, recursive=True))
else:
files.add(path)

return files

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
Expand Down

0 comments on commit be29b24

Please sign in to comment.