Skip to content

Commit

Permalink
Move list_files up to GeoDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adriantre committed Sep 29, 2023
1 parent 628801a commit bdb9422
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
_crs = CRS.from_epsg(4326)
_res = 0.0

#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = "*"

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
Expand All @@ -98,6 +105,9 @@ def __init__(
and returns a transformed version
"""
self.transforms = transforms
self.paths: Union[
str, Iterable[str]
] = "data" # TODO: temp fix when refactoring

# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
Expand Down Expand Up @@ -269,6 +279,33 @@ def res(self, new_res: float) -> None:
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res

def list_files(self, filename_glob: Optional[str] = None) -> list[str]:
"""Get list of files matching filename_glob.
Args:
filename_glob: Defaults to self.filename_glob
.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str):
paths: Iterable[str] = [self.paths]
else:
paths = self.paths

filename_glob = filename_glob or self.filename_glob

# using set to remove any duplicates if directories are overlapping
filepaths: set[str] = set()
for dir_or_file in paths:
if os.path.isfile(dir_or_file):
filepaths.add(dir_or_file)
else:
pathname = os.path.join(dir_or_file, "**", filename_glob)
filepaths |= set(glob.iglob(pathname, recursive=True))

return list(filepaths)


class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
Expand Down Expand Up @@ -423,33 +460,6 @@ def __init__(
self._crs = cast(CRS, crs)
self._res = cast(float, res)

def list_files(self, filename_glob: Optional[str] = None) -> list[str]:
"""Get list of files matching filename_glob.
Args:
filename_glob: Defaults to self.filename_glob
.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str):
paths: Iterable[str] = [self.paths]
else:
paths = self.paths

filename_glob = filename_glob or self.filename_glob

# using set to remove any duplicates if directories are overlapping
filepaths: set[str] = set()
for dir_or_file in paths:
if os.path.isfile(dir_or_file):
filepaths.add(dir_or_file)
else:
pathname = os.path.join(dir_or_file, "**", filename_glob)
filepaths |= set(glob.iglob(pathname, recursive=True))

return list(filepaths)

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
Expand Down

0 comments on commit bdb9422

Please sign in to comment.