Skip to content

Commit

Permalink
Upgrade file to 3.10 syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Apr 14, 2023
1 parent d1dc59b commit 18b19f1
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Common dataset utilities."""

# https://github.com/sphinx-doc/sphinx/issues/11327
from __future__ import annotations

import bz2
Expand Down Expand Up @@ -88,7 +89,7 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
pass


def extract_archive(src: str, dst: Optional[str] = None) -> None:
def extract_archive(src: str, dst: str | None = None) -> None:
"""Extract an archive.
Args:
Expand All @@ -101,7 +102,7 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None:
if dst is None:
dst = os.path.dirname(src)

suffix_and_extractor: list[tuple[Union[str, tuple[str, ...]], Any]] = [
suffix_and_extractor: list[tuple[str | tuple[str, ...], Any]] = [
(".rar", _rarfile.RarFile),
(
(".tar", ".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tbz2", ".tbz", ".txz"),
Expand Down Expand Up @@ -135,9 +136,9 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None:
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
extract_root: str | None = None,
filename: str | None = None,
md5: str | None = None,
) -> None:
"""Download and extract an archive.
Expand All @@ -162,7 +163,7 @@ def download_and_extract_archive(


def download_radiant_mlhub_dataset(
dataset_id: str, download_root: str, api_key: Optional[str] = None
dataset_id: str, download_root: str, api_key: str | None = None
) -> None:
"""Download a dataset from Radiant Earth.
Expand All @@ -185,7 +186,7 @@ def download_radiant_mlhub_dataset(


def download_radiant_mlhub_collection(
collection_id: str, download_root: str, api_key: Optional[str] = None
collection_id: str, download_root: str, api_key: str | None = None
) -> None:
"""Download a collection from Radiant Earth.
Expand Down Expand Up @@ -255,7 +256,7 @@ def __getitem__(self, key: int) -> float: # noqa: D105
def __getitem__(self, key: slice) -> list[float]: # noqa: D105
pass

def __getitem__(self, key: Union[int, slice]) -> Union[float, list[float]]:
def __getitem__(self, key: int | slice) -> float | list[float]:
"""Index the (minx, maxx, miny, maxy, mint, maxt) tuple.
Args:
Expand All @@ -277,7 +278,7 @@ def __iter__(self) -> Iterator[float]:
"""
yield from [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt]

def __contains__(self, other: "BoundingBox") -> bool:
def __contains__(self, other: BoundingBox) -> bool:
"""Whether or not other is within the bounds of this bounding box.
Args:
Expand All @@ -297,7 +298,7 @@ def __contains__(self, other: "BoundingBox") -> bool:
and (self.mint <= other.maxt <= self.maxt)
)

def __or__(self, other: "BoundingBox") -> "BoundingBox":
def __or__(self, other: BoundingBox) -> BoundingBox:
"""The union operator.
Args:
Expand All @@ -317,7 +318,7 @@ def __or__(self, other: "BoundingBox") -> "BoundingBox":
max(self.maxt, other.maxt),
)

def __and__(self, other: "BoundingBox") -> "BoundingBox":
def __and__(self, other: BoundingBox) -> BoundingBox:
"""The intersection operator.
Args:
Expand Down Expand Up @@ -369,7 +370,7 @@ def volume(self) -> float:
"""
return self.area * (self.maxt - self.mint)

def intersects(self, other: "BoundingBox") -> bool:
def intersects(self, other: BoundingBox) -> bool:
"""Whether or not two bounding boxes intersect.
Args:
Expand Down Expand Up @@ -627,7 +628,7 @@ def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]:
return _dict_list_to_list_dict(sample)


def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]":
def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]:
"""Load an image file using rasterio.
Args:
Expand All @@ -637,7 +638,7 @@ def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]":
the image
"""
with rasterio.open(path) as f:
array: "np.typing.NDArray[np.int_]" = f.read().astype(np.int32)
array: np.typing.NDArray[np.int_] = f.read().astype(np.int32)
# NonGeoClassificationDataset expects images returned with channels last (HWC)
array = array.transpose(1, 2, 0)
return array
Expand All @@ -656,8 +657,8 @@ def draw_semantic_segmentation_masks(
image: Tensor,
mask: Tensor,
alpha: float = 0.5,
colors: Optional[Sequence[Union[str, tuple[int, int, int]]]] = None,
) -> "np.typing.NDArray[np.uint8]":
colors: Sequence[str | tuple[int, int, int]] | None = None,
) -> np.typing.NDArray[np.uint8]:
"""Overlay a semantic segmentation mask onto an image.
Args:
Expand All @@ -681,8 +682,8 @@ def draw_semantic_segmentation_masks(


def rgb_to_mask(
rgb: "np.typing.NDArray[np.uint8]", colors: list[tuple[int, int, int]]
) -> "np.typing.NDArray[np.uint8]":
rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]]
) -> np.typing.NDArray[np.uint8]:
"""Converts an RGB colormap mask to a integer mask.
Args:
Expand All @@ -696,7 +697,7 @@ def rgb_to_mask(
# we can map is 255

h, w = rgb.shape[:2]
mask: "np.typing.NDArray[np.uint8]" = np.zeros(shape=(h, w), dtype=np.uint8)
mask: np.typing.NDArray[np.uint8] = np.zeros(shape=(h, w), dtype=np.uint8)
for i, c in enumerate(colors):
cmask = rgb == c
# Only update mask if class is present in mask
Expand All @@ -706,11 +707,11 @@ def rgb_to_mask(


def percentile_normalization(
img: "np.typing.NDArray[np.int_]",
img: np.typing.NDArray[np.int_],
lower: float = 2,
upper: float = 98,
axis: Optional[Union[int, Sequence[int]]] = None,
) -> "np.typing.NDArray[np.int_]":
axis: int | Sequence[int] | None = None,
) -> np.typing.NDArray[np.int_]:
"""Applies percentile normalization to an input image.
Specifically, this will rescale the values in the input such that values <= the
Expand All @@ -732,7 +733,7 @@ def percentile_normalization(
assert lower < upper
lower_percentile = np.percentile(img, lower, axis=axis)
upper_percentile = np.percentile(img, upper, axis=axis)
img_normalized: "np.typing.NDArray[np.int_]" = np.clip(
img_normalized: np.typing.NDArray[np.int_] = np.clip(
(img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
)
return img_normalized

0 comments on commit 18b19f1

Please sign in to comment.