From be29b24e5cea27d464fbc15495f46496e7972ebe Mon Sep 17 00:00:00 2001 From: Adrian Tofting Date: Fri, 29 Sep 2023 10:49:09 +0200 Subject: [PATCH] Move property `files` up to GeoDataset --- torchgeo/datasets/geo.py | 61 +++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 09564dbaba1..8fdc5b6484c 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -75,6 +75,13 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): _crs = CRS.from_epsg(4326) _res = 0.0 + #: Glob expression used to search for files. + #: + #: This expression should be specific enough that it will not pick up files from + #: other datasets. It should not include a file extension, as the dataset may be in + #: a different file format than what it was originally downloaded as. + filename_glob = "*" + # NOTE: according to the Python docs: # # * https://docs.python.org/3/library/exceptions.html#NotImplementedError @@ -98,6 +105,9 @@ def __init__( and returns a transformed version """ self.transforms = transforms + self.paths: Union[ + str, Iterable[str] + ] = "data" # TODO: temp fix when refactoring # Create an R-tree to index the dataset self.index = Index(interleaved=False, properties=Property(dimension=3)) @@ -269,6 +279,31 @@ def res(self, new_res: float) -> None: print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}") self._res = new_res + @property + def files(self) -> set[str]: + """A list of all files in the dataset. + + Returns: + All files in the dataset. + + .. versionadded:: 0.5 + """ + # Make iterable + if isinstance(self.paths, str): + paths: Iterable[str] = [self.paths] + else: + paths = self.paths + + # Using set to remove any duplicates if directories are overlapping + files: set[str] = set() + for path in paths: + if os.path.isdir(path): + pathname = os.path.join(path, "**", self.filename_glob) + files |= set(glob.iglob(pathname, recursive=True)) + else: + files.add(path) + + return files class RasterDataset(GeoDataset): """Abstract base class for :class:`GeoDataset` stored as raster files.""" @@ -423,32 +458,6 @@ def __init__( self._crs = cast(CRS, crs) self._res = cast(float, res) - @property - def files(self) -> set[str]: - """A list of all files in the dataset. - - Returns: - All files in the dataset. - - .. versionadded:: 0.5 - """ - # Make iterable - if isinstance(self.paths, str): - paths: Iterable[str] = [self.paths] - else: - paths = self.paths - - # Using set to remove any duplicates if directories are overlapping - files: set[str] = set() - for path in paths: - if os.path.isdir(path): - pathname = os.path.join(path, "**", self.filename_glob) - files |= set(glob.iglob(pathname, recursive=True)) - else: - files.add(path) - - return files - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query.