diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 3c36019de7b..09127b2e7ea 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -3,6 +3,7 @@ """Common dataset utilities.""" +# https://github.com/sphinx-doc/sphinx/issues/11327 from __future__ import annotations import bz2 @@ -88,7 +89,7 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: pass -def extract_archive(src: str, dst: Optional[str] = None) -> None: +def extract_archive(src: str, dst: str | None = None) -> None: """Extract an archive. Args: @@ -101,7 +102,7 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None: if dst is None: dst = os.path.dirname(src) - suffix_and_extractor: list[tuple[Union[str, tuple[str, ...]], Any]] = [ + suffix_and_extractor: list[tuple[str | tuple[str, ...], Any]] = [ (".rar", _rarfile.RarFile), ( (".tar", ".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tbz2", ".tbz", ".txz"), @@ -135,9 +136,9 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None: def download_and_extract_archive( url: str, download_root: str, - extract_root: Optional[str] = None, - filename: Optional[str] = None, - md5: Optional[str] = None, + extract_root: str | None = None, + filename: str | None = None, + md5: str | None = None, ) -> None: """Download and extract an archive. @@ -162,7 +163,7 @@ def download_and_extract_archive( def download_radiant_mlhub_dataset( - dataset_id: str, download_root: str, api_key: Optional[str] = None + dataset_id: str, download_root: str, api_key: str | None = None ) -> None: """Download a dataset from Radiant Earth. @@ -185,7 +186,7 @@ def download_radiant_mlhub_dataset( def download_radiant_mlhub_collection( - collection_id: str, download_root: str, api_key: Optional[str] = None + collection_id: str, download_root: str, api_key: str | None = None ) -> None: """Download a collection from Radiant Earth. @@ -255,7 +256,7 @@ def __getitem__(self, key: int) -> float: # noqa: D105 def __getitem__(self, key: slice) -> list[float]: # noqa: D105 pass - def __getitem__(self, key: Union[int, slice]) -> Union[float, list[float]]: + def __getitem__(self, key: int | slice) -> float | list[float]: """Index the (minx, maxx, miny, maxy, mint, maxt) tuple. Args: @@ -277,7 +278,7 @@ def __iter__(self) -> Iterator[float]: """ yield from [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt] - def __contains__(self, other: "BoundingBox") -> bool: + def __contains__(self, other: BoundingBox) -> bool: """Whether or not other is within the bounds of this bounding box. Args: @@ -297,7 +298,7 @@ def __contains__(self, other: "BoundingBox") -> bool: and (self.mint <= other.maxt <= self.maxt) ) - def __or__(self, other: "BoundingBox") -> "BoundingBox": + def __or__(self, other: BoundingBox) -> BoundingBox: """The union operator. Args: @@ -317,7 +318,7 @@ def __or__(self, other: "BoundingBox") -> "BoundingBox": max(self.maxt, other.maxt), ) - def __and__(self, other: "BoundingBox") -> "BoundingBox": + def __and__(self, other: BoundingBox) -> BoundingBox: """The intersection operator. Args: @@ -369,7 +370,7 @@ def volume(self) -> float: """ return self.area * (self.maxt - self.mint) - def intersects(self, other: "BoundingBox") -> bool: + def intersects(self, other: BoundingBox) -> bool: """Whether or not two bounding boxes intersect. Args: @@ -627,7 +628,7 @@ def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: return _dict_list_to_list_dict(sample) -def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]": +def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: """Load an image file using rasterio. Args: @@ -637,7 +638,7 @@ def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]": the image """ with rasterio.open(path) as f: - array: "np.typing.NDArray[np.int_]" = f.read().astype(np.int32) + array: np.typing.NDArray[np.int_] = f.read().astype(np.int32) # NonGeoClassificationDataset expects images returned with channels last (HWC) array = array.transpose(1, 2, 0) return array @@ -656,8 +657,8 @@ def draw_semantic_segmentation_masks( image: Tensor, mask: Tensor, alpha: float = 0.5, - colors: Optional[Sequence[Union[str, tuple[int, int, int]]]] = None, -) -> "np.typing.NDArray[np.uint8]": + colors: Sequence[str | tuple[int, int, int]] | None = None, +) -> np.typing.NDArray[np.uint8]: """Overlay a semantic segmentation mask onto an image. Args: @@ -681,8 +682,8 @@ def draw_semantic_segmentation_masks( def rgb_to_mask( - rgb: "np.typing.NDArray[np.uint8]", colors: list[tuple[int, int, int]] -) -> "np.typing.NDArray[np.uint8]": + rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]] +) -> np.typing.NDArray[np.uint8]: """Converts an RGB colormap mask to a integer mask. Args: @@ -696,7 +697,7 @@ def rgb_to_mask( # we can map is 255 h, w = rgb.shape[:2] - mask: "np.typing.NDArray[np.uint8]" = np.zeros(shape=(h, w), dtype=np.uint8) + mask: np.typing.NDArray[np.uint8] = np.zeros(shape=(h, w), dtype=np.uint8) for i, c in enumerate(colors): cmask = rgb == c # Only update mask if class is present in mask @@ -706,11 +707,11 @@ def rgb_to_mask( def percentile_normalization( - img: "np.typing.NDArray[np.int_]", + img: np.typing.NDArray[np.int_], lower: float = 2, upper: float = 98, - axis: Optional[Union[int, Sequence[int]]] = None, -) -> "np.typing.NDArray[np.int_]": + axis: int | Sequence[int] | None = None, +) -> np.typing.NDArray[np.int_]: """Applies percentile normalization to an input image. Specifically, this will rescale the values in the input such that values <= the @@ -732,7 +733,7 @@ def percentile_normalization( assert lower < upper lower_percentile = np.percentile(img, lower, axis=axis) upper_percentile = np.percentile(img, upper, axis=axis) - img_normalized: "np.typing.NDArray[np.int_]" = np.clip( + img_normalized: np.typing.NDArray[np.int_] = np.clip( (img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1 ) return img_normalized