From 71bc7bce73df838dbf16b2ce7fba0b8edbf50b7a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 15 Apr 2023 20:27:51 -0500 Subject: [PATCH] Drop Python 3.8 support (#1246) --- .github/workflows/style.yaml | 2 +- .github/workflows/tests.yaml | 4 +- .pre-commit-config.yaml | 2 +- .readthedocs.yaml | 4 +- docs/user/contributing.rst | 6 +- environment.yml | 26 +++--- evaluate.py | 4 +- experiments/download_ssl4eo.py | 58 ++++++------- experiments/run_benchmarks_experiments.py | 3 +- experiments/run_so2sat_byol_experiments.py | 3 +- experiments/sample_ssl4eo.py | 7 +- pyproject.toml | 4 +- requirements/min-reqs.old | 24 +++--- setup.cfg | 52 ++++++------ tests/data/enviroatlas/data.py | 6 +- tests/data/l7irish/data.py | 4 +- tests/data/l8biome/data.py | 4 +- tests/data/reforestree/data.py | 3 +- tests/data/sentinel1/data.py | 4 +- tests/data/sentinel2/data.py | 4 +- tests/data/spacenet/data.py | 6 +- tests/data/ssl4eo/s12/data.py | 4 +- tests/datamodules/test_geo.py | 6 +- tests/datasets/test_eurosat.py | 3 +- tests/datasets/test_geo.py | 7 +- tests/datasets/test_sentinel.py | 7 +- tests/datasets/test_splits.py | 9 +- tests/datasets/test_utils.py | 42 +++++----- tests/models/test_resnet.py | 6 +- tests/models/test_vit.py | 6 +- tests/samplers/test_batch.py | 6 +- tests/samplers/test_single.py | 4 +- tests/samplers/test_utils.py | 4 +- tests/trainers/conftest.py | 5 +- tests/trainers/test_byol.py | 22 ++--- tests/trainers/test_classification.py | 40 ++++----- tests/trainers/test_detection.py | 18 ++-- tests/trainers/test_regression.py | 26 +++--- tests/trainers/test_segmentation.py | 18 ++-- tests/transforms/test_indices.py | 16 ++-- tests/transforms/test_transforms.py | 20 ++--- torchgeo/datamodules/chesapeake.py | 14 ++-- torchgeo/datamodules/deepglobelandcover.py | 4 +- torchgeo/datamodules/etci2021.py | 6 +- torchgeo/datamodules/geo.py | 60 +++++++------- torchgeo/datamodules/gid15.py | 4 +- torchgeo/datamodules/inria.py | 4 +- torchgeo/datamodules/l7irish.py | 4 +- torchgeo/datamodules/l8biome.py | 4 +- torchgeo/datamodules/naip.py | 4 +- torchgeo/datamodules/nasa_marine_debris.py | 6 +- torchgeo/datamodules/oscd.py | 4 +- torchgeo/datamodules/potsdam.py | 4 +- torchgeo/datamodules/sen12ms.py | 6 +- torchgeo/datamodules/spacenet.py | 6 +- torchgeo/datamodules/utils.py | 4 +- torchgeo/datamodules/vaihingen.py | 4 +- torchgeo/datasets/advance.py | 12 +-- torchgeo/datasets/agb_live_woody_density.py | 6 +- torchgeo/datasets/astergdem.py | 6 +- torchgeo/datasets/benin_cashews.py | 20 ++--- torchgeo/datasets/bigearthnet.py | 16 ++-- torchgeo/datasets/cbf.py | 6 +- torchgeo/datasets/cdl.py | 6 +- torchgeo/datasets/chesapeake.py | 15 ++-- torchgeo/datasets/cloud_cover.py | 17 ++-- torchgeo/datasets/cms_mangrove_canopy.py | 6 +- torchgeo/datasets/cowc.py | 12 +-- torchgeo/datasets/cv4a_kenya_crop_type.py | 20 ++--- torchgeo/datasets/cyclone.py | 14 ++-- torchgeo/datasets/deepglobelandcover.py | 8 +- torchgeo/datasets/dfc2022.py | 11 +-- torchgeo/datasets/eddmaps.py | 4 +- torchgeo/datasets/enviroatlas.py | 11 +-- torchgeo/datasets/esri2020.py | 6 +- torchgeo/datasets/etci2021.py | 10 +-- torchgeo/datasets/eudem.py | 6 +- torchgeo/datasets/eurosat.py | 9 +- torchgeo/datasets/fair1m.py | 14 ++-- torchgeo/datasets/forestdamage.py | 18 ++-- torchgeo/datasets/gbif.py | 6 +- torchgeo/datasets/geo.py | 57 ++++++------- torchgeo/datasets/gid15.py | 10 +-- torchgeo/datasets/globbiomass.py | 10 +-- torchgeo/datasets/idtrees.py | 30 +++---- torchgeo/datasets/inaturalist.py | 4 +- torchgeo/datasets/inria.py | 10 +-- torchgeo/datasets/l7irish.py | 13 +-- torchgeo/datasets/l8biome.py | 13 +-- torchgeo/datasets/landcoverai.py | 18 ++-- torchgeo/datasets/landsat.py | 9 +- torchgeo/datasets/levircd.py | 10 +-- torchgeo/datasets/loveda.py | 12 +-- torchgeo/datasets/millionaid.py | 12 +-- torchgeo/datasets/naip.py | 4 +- torchgeo/datasets/nasa_marine_debris.py | 10 +-- torchgeo/datasets/openbuildings.py | 18 ++-- torchgeo/datasets/oscd.py | 15 ++-- torchgeo/datasets/patternnet.py | 6 +- torchgeo/datasets/potsdam.py | 8 +- torchgeo/datasets/reforestree.py | 14 ++-- torchgeo/datasets/resisc45.py | 6 +- torchgeo/datasets/seco.py | 10 +-- torchgeo/datasets/sen12ms.py | 13 +-- torchgeo/datasets/sentinel.py | 11 +-- torchgeo/datasets/so2sat.py | 9 +- torchgeo/datasets/spacenet.py | 58 ++++++------- torchgeo/datasets/splits.py | 17 ++-- torchgeo/datasets/ssl4eo.py | 8 +- torchgeo/datasets/ucmerced.py | 6 +- torchgeo/datasets/usavars.py | 11 +-- torchgeo/datasets/utils.py | 92 ++++++++++----------- torchgeo/datasets/vaihingen.py | 8 +- torchgeo/datasets/vhr10.py | 16 ++-- torchgeo/datasets/xview.py | 10 +-- torchgeo/datasets/zuericrop.py | 11 +-- torchgeo/models/api.py | 4 +- torchgeo/models/changestar.py | 10 +-- torchgeo/models/farseg.py | 10 +-- torchgeo/models/fcsiam.py | 3 +- torchgeo/samplers/batch.py | 11 +-- torchgeo/samplers/single.py | 9 +- torchgeo/samplers/utils.py | 16 ++-- torchgeo/trainers/byol.py | 10 +-- torchgeo/trainers/classification.py | 6 +- torchgeo/trainers/detection.py | 10 +-- torchgeo/trainers/regression.py | 6 +- torchgeo/trainers/segmentation.py | 6 +- torchgeo/trainers/utils.py | 10 +-- torchgeo/transforms/indices.py | 10 +-- torchgeo/transforms/transforms.py | 26 +++--- train.py | 12 +-- 132 files changed, 811 insertions(+), 813 deletions(-) diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index 783885facbc..6da370e8ed9 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -123,7 +123,7 @@ jobs: pip install -r requirements/style.txt pip list - name: Run pyupgrade checks - run: pyupgrade --py38-plus $(find . -path ./docs/src -prune -o -name "*.py" -print) + run: pyupgrade --py39-plus $(find . -path ./docs/src -prune -o -name "*.py" -print) concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 60daeb7fd50..1823649c088 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -40,7 +40,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.8', '3.9', '3.10'] + python-version: ['3.9', '3.10'] steps: - name: Clone repo uses: actions/checkout@v3.5.0 @@ -99,7 +99,7 @@ jobs: - name: Set up python uses: actions/setup-python@v4.5.0 with: - python-version: '3.8' + python-version: '3.9' - name: Cache dependencies uses: actions/cache@v3.3.1 id: cache diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 813ae060051..c9bbd99ab4a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: v3.3.1 hooks: - id: pyupgrade - args: [--py38-plus] + args: [--py39-plus] - repo: https://github.com/pycqa/isort rev: 5.12.0 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 90c9e30d0fe..e564d336f8b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,9 +6,9 @@ version: 2 # Set the version of Python build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # Configuration of the Python environment to be used python: diff --git a/docs/user/contributing.rst b/docs/user/contributing.rst index 4dbe51c8a24..50e4d251579 100644 --- a/docs/user/contributing.rst +++ b/docs/user/contributing.rst @@ -58,14 +58,14 @@ For example, if you add a new dataset in ``torchgeo/datasets/foo.py``, you'll ne $ pytest --cov=torchgeo/datasets --cov-report=term-missing tests/datasets/test_foo.py ========================= test session starts ========================= - platform darwin -- Python 3.8.11, pytest-6.2.4, py-1.9.0, pluggy-0.13.0 + platform darwin -- Python 3.10.11, pytest-6.2.4, py-1.9.0, pluggy-0.13.0 rootdir: ~/torchgeo, configfile: pyproject.toml plugins: mock-1.11.1, anyio-3.2.1, cov-2.8.1, nbmake-0.5 collected 7 items tests/datasets/test_foo.py ....... [100%] - ---------- coverage: platform darwin, python 3.8.11-final-0 ----------- + ---------- coverage: platform darwin, python 3.10.11-final-0 ----------- Name Stmts Miss Cover Missing ----------------------------------------------------------------------- torchgeo/datasets/__init__.py 26 0 100% @@ -103,7 +103,7 @@ All of these tools should be used from the root of the project to ensure that ou $ black . $ isort . - $ pyupgrade --py38-plus $(find . -name "*.py") + $ pyupgrade --py39-plus $(find . -name "*.py") Flake8, pydocstyle, and mypy won't format your code for you, but they will warn you about potential issues with your code or docstrings: diff --git a/environment.yml b/environment.yml index f8978808141..602c1913265 100644 --- a/environment.yml +++ b/environment.yml @@ -4,19 +4,19 @@ channels: - conda-forge dependencies: - einops>=0.3 - - fiona>=1.8.12 - - h5py>=2.9 - - matplotlib>=3.3 - - numpy>=1.17.3 + - fiona>=1.8.19 + - h5py>=3 + - matplotlib>=3.3.3 + - numpy>=1.19.3 - pip - pycocotools>=2.0.1 - - pyproj>=2.4.1 - - python>=3.8 + - pyproj>=3 + - python>=3.9 - pytorch>=1.12 - pyvista>=0.25.2 - rarfile>=4 - - rasterio>=1.1.1 - - shapely>=1.7 + - rasterio>=1.2 + - shapely>=1.7.1 - torchvision>=0.13 - pip: - black[jupyter]>=21.8 @@ -30,18 +30,18 @@ dependencies: - nbmake>=1.3.3 - nbsphinx>=0.8.5 - omegaconf>=2.1 - - opencv-python>=4.1.2 - - pandas>=0.25.2 - - pillow>=6.2.1 + - opencv-python>=4.4.0.46 + - pandas>=1.1.3 + - pillow>=8 - pydocstyle[toml]>=6.1 - pytest>=6.1.2 - pytest-cov>=2.4 - git+https://github.com/pytorch/pytorch_sphinx_theme - - pyupgrade>=2.4 + - pyupgrade>=2.8 - radiant-mlhub>=0.3 - rtree>=1 - scikit-image>=0.18 - - scikit-learn>=0.22 + - scikit-learn>=0.24 - scipy>=1.6.2 - segmentation-models-pytorch>=0.2 - setuptools>=42 diff --git a/evaluate.py b/evaluate.py index f8d9cfcc76c..beab7c25a8f 100755 --- a/evaluate.py +++ b/evaluate.py @@ -8,7 +8,7 @@ import argparse import csv import os -from typing import Any, Dict, Union, cast +from typing import Any, Union, cast import lightning.pytorch as pl import torch @@ -155,7 +155,7 @@ def main(args: argparse.Namespace) -> None: dm.setup("validate") # Record model hyperparameters - hparams = cast(Dict[str, Union[str, float]], model.hparams) + hparams = cast(dict[str, Union[str, float]], model.hparams) if issubclass(TASK, ClassificationTask): val_row = { "split": "val", diff --git a/experiments/download_ssl4eo.py b/experiments/download_ssl4eo.py index ec80d0a7296..78526b18144 100644 --- a/experiments/download_ssl4eo.py +++ b/experiments/download_ssl4eo.py @@ -32,7 +32,7 @@ from collections import OrderedDict from datetime import date, datetime, timedelta from multiprocessing.dummy import Lock, Pool -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import ee import numpy as np @@ -79,7 +79,7 @@ def date2str(date: datetime) -> str: return date.strftime("%Y-%m-%d") -def get_period(date: datetime, days: int = 5) -> Tuple[str, str]: +def get_period(date: datetime, days: int = 5) -> tuple[str, str]: date1 = date - timedelta(days=days / 2) date2 = date + timedelta(days=days / 2) return date2str(date1), date2str(date2) @@ -121,7 +121,7 @@ def get_collection_s1() -> ee.ImageCollection: def filter_collection( - collection: ee.ImageCollection, coords: List[float], period: Tuple[str, str] + collection: ee.ImageCollection, coords: list[float], period: tuple[str, str] ) -> ee.ImageCollection: filtered = collection if period is not None: @@ -137,7 +137,7 @@ def filter_collection( def filter_collection_s1( - collection: ee.ImageCollection, coords: List[float], period: Tuple[str, str] + collection: ee.ImageCollection, coords: list[float], period: tuple[str, str] ) -> ee.ImageCollection: filtered = collection if period is not None: @@ -161,7 +161,7 @@ def filter_collection_s1( def center_crop( - img: np.ndarray[Any, np.dtype[Any]], out_size: Tuple[int, int] + img: np.ndarray[Any, np.dtype[Any]], out_size: tuple[int, int] ) -> np.ndarray[Any, np.dtype[Any]]: image_height, image_width = img.shape[:2] crop_height, crop_width = out_size @@ -171,8 +171,8 @@ def center_crop( def adjust_coords( - coords: List[List[float]], old_size: Tuple[int, int], new_size: Tuple[int, int] -) -> List[List[float]]: + coords: list[list[float]], old_size: tuple[int, int], new_size: tuple[int, int] +) -> list[list[float]]: xres = (coords[1][0] - coords[0][0]) / old_size[1] yres = (coords[0][1] - coords[1][1]) / old_size[0] xoff = int((old_size[1] - new_size[1] + 1) * 0.5) @@ -192,11 +192,11 @@ def get_properties(image: ee.Image) -> Any: def get_patch_s1( collection: ee.ImageCollection, - center_coord: List[float], + center_coord: list[float], radius: float, - bands: List[str], - crop: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + bands: list[str], + crop: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: image = collection.sort("system:time_start", False).first() # get most recent region = ( ee.Geometry.Point(center_coord).buffer(radius).bounds() @@ -233,11 +233,11 @@ def get_patch_s1( def get_patch_s2( collection: ee.ImageCollection, - center_coord: List[float], + center_coord: list[float], radius: float, - bands: Optional[List[str]] = None, - crop: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + bands: Optional[list[str]] = None, + crop: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: if bands is None: bands = RGB_BANDS @@ -277,18 +277,18 @@ def get_patch_s2( def get_random_patches_match( idx: int, - collections: Dict[str, Any], - bands: Dict[str, Any], - crops: Dict[str, Any], - dates: List[Any], + collections: dict[str, Any], + bands: dict[str, Any], + crops: dict[str, Any], + dates: list[Any], radius: int, debug: bool = False, - match_coords: Dict[str, Any] = {}, -) -> Tuple[ - Optional[List[Dict[str, Any]]], - Optional[List[Dict[str, Any]]], - Optional[List[Dict[str, Any]]], - List[float], + match_coords: dict[str, Any] = {}, +) -> tuple[ + Optional[list[dict[str, Any]]], + Optional[list[dict[str, Any]]], + Optional[list[dict[str, Any]]], + list[float], ]: # (lon,lat) of idx patch coords = match_coords[str(idx)] @@ -343,7 +343,7 @@ def get_random_patches_match( def save_geotiff( - img: np.ndarray[Any, np.dtype[Any]], coords: List[List[float]], filename: str + img: np.ndarray[Any, np.dtype[Any]], coords: list[list[float]], filename: str ) -> None: height, width, channels = img.shape xres = (coords[1][0] - coords[0][0]) / width @@ -366,9 +366,9 @@ def save_geotiff( def save_patch( - raster: Dict[str, Any], - coords: List[List[float]], - metadata: Dict[str, Any], + raster: dict[str, Any], + coords: list[list[float]], + metadata: dict[str, Any], path: str, ) -> None: patch_id = metadata["properties"]["system:index"] diff --git a/experiments/run_benchmarks_experiments.py b/experiments/run_benchmarks_experiments.py index 7f3a95c099b..8f10c402e44 100755 --- a/experiments/run_benchmarks_experiments.py +++ b/experiments/run_benchmarks_experiments.py @@ -8,7 +8,6 @@ import os import subprocess import time -from typing import List EPOCH_SIZE = 4096 @@ -34,7 +33,7 @@ ): print(f"\n{i}/{total_num_experiments} -- {time.time() - tic}") tic = time.time() - command: List[str] = [ + command: list[str] = [ "python", "benchmark.py", "--landsat-root", diff --git a/experiments/run_so2sat_byol_experiments.py b/experiments/run_so2sat_byol_experiments.py index 7b4cbf0c98e..0b527d4c927 100755 --- a/experiments/run_so2sat_byol_experiments.py +++ b/experiments/run_so2sat_byol_experiments.py @@ -7,7 +7,6 @@ import os import subprocess from multiprocessing import Process, Queue -from typing import List # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0, 1, 2, 3, 3] @@ -18,7 +17,7 @@ model_options = ["resnet50"] lr_options = [1e-4] loss_options = ["ce"] -weight_options: List[str] = [] # set paths to checkpoint files +weight_options: list[str] = [] # set paths to checkpoint files bands_options = ["s2"] diff --git a/experiments/sample_ssl4eo.py b/experiments/sample_ssl4eo.py index 8e3b7917da9..8a2e31df050 100644 --- a/experiments/sample_ssl4eo.py +++ b/experiments/sample_ssl4eo.py @@ -37,7 +37,6 @@ import os import time import warnings -from typing import Tuple import numpy as np import pandas as pd @@ -69,7 +68,7 @@ def km2deg(kms: float, radius: float = 6371) -> float: return kms / (2.0 * radius * np.pi / 360.0) -def sample_point(cities: pd.DataFrame, std: float) -> Tuple[float, float]: +def sample_point(cities: pd.DataFrame, std: float) -> tuple[float, float]: city = cities.sample() point = (float(city["lng"]), float(city["lat"])) std = km2deg(std) @@ -78,8 +77,8 @@ def sample_point(cities: pd.DataFrame, std: float) -> Tuple[float, float]: def create_bbox( - coords: Tuple[float, float], bbox_size_degree: float -) -> Tuple[float, float, float, float]: + coords: tuple[float, float], bbox_size_degree: float +) -> tuple[float, float, float, float]: lon, lat = coords bbox = ( lon - bbox_size_degree, diff --git a/pyproject.toml b/pyproject.toml index 13452391eeb..8035d71efc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ build-backend = "setuptools.build_meta" [tool.black] -target-version = ["py38", "py39", "py310"] +target-version = ["py39", "py310"] color = true skip_magic_trailing_comma = true @@ -25,7 +25,7 @@ skip_gitignore = true color_output = true [tool.mypy] -python_version = "3.8" +python_version = "3.9" ignore_missing_imports = true show_error_codes = true exclude = "(build|data|dist|docs/src|images|logo|logs|output)/" diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 98233acbfea..0c4b38e5d78 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -3,28 +3,28 @@ setuptools==42.0.0 # install einops==0.3.0 -fiona==1.8.12 +fiona==1.8.19 kornia==0.6.5 lightning==1.8.0 -matplotlib==3.3.0 -numpy==1.17.3 -pillow==6.2.1 -pyproj==2.4.1 -rasterio==1.1.1 +matplotlib==3.3.3 +numpy==1.19.3 +pillow==8.0.0 +pyproj==3.0.0 +rasterio==1.2.0 rtree==1.0.0 -scikit-learn==0.22 +scikit-learn==0.24 segmentation-models-pytorch==0.2.0 -shapely==1.7.0 +shapely==1.7.1 timm==0.4.12 torch==1.12.0 torchmetrics==0.10.0 torchvision==0.13.0 # datasets -h5py==2.9.0 +h5py==3.0.0 laspy==2.0.0 -opencv-python==4.1.2.30 -pandas==0.25.2 +opencv-python==4.4.0.46 +pandas==1.1.3 pycocotools==2.0.1 pyvista==0.25.2 radiant-mlhub==0.3.0 @@ -43,7 +43,7 @@ black[jupyter]==21.8b0 flake8==3.8.0 isort[colors]==5.8.0 pydocstyle[toml]==6.1.0 -pyupgrade==2.4.0 +pyupgrade==2.8.0 # tests mypy==0.900 diff --git a/setup.cfg b/setup.cfg index 7a1b538239d..4815f71c8bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,6 @@ classifiers = Development Status :: 3 - Alpha Intended Audience :: Science/Research Programming Language :: Python :: 3 - Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 License :: OSI Approved :: MIT License @@ -26,30 +25,31 @@ keywords = pytorch, deep learning, machine learning, remote sensing, satellite i install_requires = # einops 0.3+ required for einops.repeat einops>=0.3,<0.7 - # fiona 1.8.12+ required for Python 3.8 wheels - fiona>=1.8.12,<2 + # fiona 1.8.19+ required to fix erroneous warning + # https://github.com/Toblerity/Fiona/issues/986 + fiona>=1.8.19,<2 # kornia 0.6.5+ required due to change in kornia.augmentation API kornia>=0.6.5,<0.7 # lightning 1.8+ is first release lightning>=1.8,<3 - # matplotlib 3.3+ required for (H, W, 1) image support in plt.imshow - matplotlib>=3.3,<4 - # numpy 1.17.3+ required by Python 3.8 wheels - numpy>=1.17.3,<2 - # pillow 6.2.1+ required for Python 3.8 wheels - pillow>=6.2.1,<10 - # pyproj 2.4.1+ required for Python 3.8 wheels - pyproj>=2.4.1,<4 - # rasterio 1.1.1+ required for Python 3.8 wheels - rasterio>=1.1.1,<2 + # matplotlib 3.3.3+ required for Python 3.9 wheels + matplotlib>=3.3.3,<4 + # numpy 1.19.3+ required by Python 3.9 wheels + numpy>=1.19.3,<2 + # pillow 8+ required for Python 3.9 wheels + pillow>=8,<10 + # pyproj 3+ required for Python 3.9 wheels + pyproj>=3,<4 + # rasterio 1.2+ required for Python 3.9 wheels + rasterio>=1.2,<2 # rtree 1+ required for len(index), index & index, index | index rtree>=1,<2 - # scikit-learn 0.22+ required for Python 3.8 wheels - scikit-learn>=0.22,<2 + # scikit-learn 0.24+ required for Python 3.9 wheels + scikit-learn>=0.24,<2 # segmentation-models-pytorch 0.2+ required for smp.losses module segmentation-models-pytorch>=0.2,<0.4 - # shapely 1.7+ required for Python 3.8 wheels - shapely>=1.7,<3 + # shapely 1.7.1+ required for Python 3.9 wheels + shapely>=1.7.1,<3 # timm 0.4.12 required by segmentation-models-pytorch timm>=0.4.12,<0.7 # torch 1.12+ required by torchvision @@ -58,7 +58,7 @@ install_requires = torchmetrics>=0.10,<0.12 # torchvision 0.13+ required for torchvision.models._api.WeightsEnum torchvision>=0.13,<0.16 -python_requires = >=3.8,<4 +python_requires = >=3.9,<4 packages = find: [options.package_data] @@ -69,14 +69,14 @@ include = torchgeo* [options.extras_require] datasets = - # h5py 2.9+ required for Python 3.8 wheels - h5py>=2.9,<4 + # h5py 3+ required for Python 3.9 wheels + h5py>=3,<4 # laspy 2+ required for laspy.read laspy>=2,<3 - # opencv-python 4.1.2+ required for Python 3.8 wheels - opencv-python>=4.1.2,<5 - # pandas 0.25.2+ required for Python 3.8 wheels - pandas>=0.25.2,<3 + # opencv-python 4.4.0.46+ required for Python 3.9 wheels + opencv-python>=4.4.0.46,<5 + # pandas 1.1.3+ required for Python 3.9 wheels + pandas>=1.1.3,<3 # pycocotools 2.0.1+ required for proper dependency declaration pycocotools>=2.0.1,<3 # pyvista 0.25.2 required for wheels @@ -114,8 +114,8 @@ style = isort[colors]>=5.8,<6 # pydocstyle 6.1+ required for pyproject.toml support pydocstyle[toml]>=6.1,<7 - # pyupgrade 2.4+ required for --py38-plus flag - pyupgrade>=2.4,<4 + # pyupgrade 2.8+ required for --py39-plus flag + pyupgrade>=2.8,<4 tests = # mypy 0.900+ required for pyproject.toml support mypy>=0.900,<2 diff --git a/tests/data/enviroatlas/data.py b/tests/data/enviroatlas/data.py index 8a006615bf7..f2816e193a7 100755 --- a/tests/data/enviroatlas/data.py +++ b/tests/data/enviroatlas/data.py @@ -5,7 +5,7 @@ import os import shutil -from typing import Any, Dict +from typing import Any import fiona import fiona.transform @@ -29,7 +29,7 @@ "prior_from_cooccurrences_101_31_no_osm_no_buildings": "prior_no_osm_no_buildings", } -layer_data_profiles: Dict[str, Dict[Any, Any]] = { +layer_data_profiles: dict[str, dict[Any, Any]] = { "a_naip": { "profile": { "driver": "GTiff", @@ -198,7 +198,7 @@ ] -def write_data(path: str, profile: Dict[Any, Any], data_type: Any, vals: Any) -> None: +def write_data(path: str, profile: dict[Any, Any], data_type: Any, vals: Any) -> None: assert all(key in profile for key in ("count", "height", "width", "dtype")) with rasterio.open(path, "w", **profile) as dst: size = (profile["count"], profile["height"], profile["width"]) diff --git a/tests/data/l7irish/data.py b/tests/data/l7irish/data.py index 9c7869d01b0..0daf87ea6fe 100755 --- a/tests/data/l7irish/data.py +++ b/tests/data/l7irish/data.py @@ -6,7 +6,7 @@ import hashlib import os import shutil -from typing import Dict, List, Union +from typing import Union import numpy as np import rasterio @@ -17,7 +17,7 @@ np.random.seed(0) -FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]] +FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]] bands = [ "B10.TIF", diff --git a/tests/data/l8biome/data.py b/tests/data/l8biome/data.py index ccca7814c20..7150a1e142c 100755 --- a/tests/data/l8biome/data.py +++ b/tests/data/l8biome/data.py @@ -6,7 +6,7 @@ import hashlib import os import shutil -from typing import Dict, List, Union +from typing import Union import numpy as np import rasterio @@ -17,7 +17,7 @@ np.random.seed(0) -FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]] +FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]] bands = [ "B1.TIF", diff --git a/tests/data/reforestree/data.py b/tests/data/reforestree/data.py index 27573cb6191..ee5515e49b1 100755 --- a/tests/data/reforestree/data.py +++ b/tests/data/reforestree/data.py @@ -7,7 +7,6 @@ import hashlib import os import shutil -from typing import List import numpy as np from PIL import Image @@ -25,7 +24,7 @@ } -def create_annotation(path: str, img_paths: List[str]) -> None: +def create_annotation(path: str, img_paths: list[str]) -> None: cols = ["img_path", "xmin", "ymin", "xmax", "ymax", "group", "AGB"] data = [] for img_path in img_paths: diff --git a/tests/data/sentinel1/data.py b/tests/data/sentinel1/data.py index 9e0beef8e1a..dc7bd58f16f 100755 --- a/tests/data/sentinel1/data.py +++ b/tests/data/sentinel1/data.py @@ -4,7 +4,7 @@ # Licensed under the MIT License. import os -from typing import Dict, List, Union +from typing import Union import numpy as np import rasterio @@ -15,7 +15,7 @@ np.random.seed(0) -FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]] +FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]] filenames: FILENAME_HIERARCHY = { # ASF DAAC diff --git a/tests/data/sentinel2/data.py b/tests/data/sentinel2/data.py index 2e6ac5dac04..83f50d1d225 100755 --- a/tests/data/sentinel2/data.py +++ b/tests/data/sentinel2/data.py @@ -4,7 +4,7 @@ # Licensed under the MIT License. import os -from typing import Dict, List, Union +from typing import Union import numpy as np import rasterio @@ -15,7 +15,7 @@ np.random.seed(0) -FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]] +FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]] filenames: FILENAME_HIERARCHY = { # USGS Earth Explorer diff --git a/tests/data/spacenet/data.py b/tests/data/spacenet/data.py index 0de4a51494b..f51485ce8d5 100644 --- a/tests/data/spacenet/data.py +++ b/tests/data/spacenet/data.py @@ -6,7 +6,7 @@ import os import shutil from collections import OrderedDict -from typing import List, cast +from typing import cast import fiona import numpy as np @@ -63,7 +63,7 @@ datasets = [SpaceNet1, SpaceNet2, SpaceNet3, SpaceNet4, SpaceNet5, SpaceNet6, SpaceNet7] -def create_test_image(img_dir: str, imgs: List[str]) -> List[List[float]]: +def create_test_image(img_dir: str, imgs: list[str]) -> list[list[float]]: """Create test image Args: @@ -99,7 +99,7 @@ def create_test_image(img_dir: str, imgs: List[str]) -> List[List[float]]: def create_test_label( lbldir: str, lblname: str, - coords: List[List[float]], + coords: list[list[float]], det_type: str, empty: bool = False, diff_crs: bool = False, diff --git a/tests/data/ssl4eo/s12/data.py b/tests/data/ssl4eo/s12/data.py index a1798f60756..975855f4be2 100755 --- a/tests/data/ssl4eo/s12/data.py +++ b/tests/data/ssl4eo/s12/data.py @@ -6,7 +6,7 @@ import hashlib import os import shutil -from typing import Dict, List, Union +from typing import Union import numpy as np import rasterio @@ -17,7 +17,7 @@ np.random.seed(0) -FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]] +FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]] s1 = ["VH.tif", "VV.tif"] s2c = [ diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index f283aa2a234..b3f1bad3f54 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import Any, Dict +from typing import Any import matplotlib.pyplot as plt import pytest @@ -26,7 +26,7 @@ def __init__(self, split: str = "train", download: bool = False) -> None: self.index.insert(0, (0, 1, 2, 3, 4, 5)) self.res = 1 - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: image = torch.arange(3 * 2 * 2).view(3, 2, 2) return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query} @@ -61,7 +61,7 @@ class CustomNonGeoDataset(NonGeoDataset): def __init__(self, split: str = "train", download: bool = False) -> None: pass - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} def __len__(self) -> int: diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index a69ece9f9b4..ccfe8057caa 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -5,7 +5,6 @@ import shutil from itertools import product from pathlib import Path -from typing import Type import matplotlib.pyplot as plt import pytest @@ -28,7 +27,7 @@ class TestEuroSAT: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> EuroSAT: - base_class: Type[EuroSAT] = request.param[0] + base_class: type[EuroSAT] = request.param[0] split: str = request.param[1] monkeypatch.setattr(torchgeo.datasets.eurosat, "download_url", download_url) md5 = "aa051207b0547daba0ac6af57808d68e" diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index ff0b35c3c3a..259c583ebe2 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -4,7 +4,6 @@ import os import pickle from pathlib import Path -from typing import Dict, List import pytest import torch @@ -39,7 +38,7 @@ def __init__( self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) bounds = BoundingBox(*hit.bounds) @@ -51,12 +50,12 @@ class CustomVectorDataset(VectorDataset): class CustomSentinelDataset(Sentinel2): - all_bands: List[str] = [] + all_bands: list[str] = [] separate_files = False class CustomNonGeoDataset(NonGeoDataset): - def __getitem__(self, index: int) -> Dict[str, int]: + def __getitem__(self, index: int) -> dict[str, int]: return {"index": index} def __len__(self) -> int: diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py index 38a2f6d3e16..2d6c42aa89e 100644 --- a/tests/datasets/test_sentinel.py +++ b/tests/datasets/test_sentinel.py @@ -3,7 +3,6 @@ import os from pathlib import Path -from typing import List import matplotlib.pyplot as plt import pytest @@ -73,19 +72,19 @@ def test_empty_bands(self) -> None: Sentinel1(bands=[]) @pytest.mark.parametrize("bands", [["HH", "HH"], ["HH", "HV", "HH"]]) - def test_duplicate_bands(self, bands: List[str]) -> None: + def test_duplicate_bands(self, bands: list[str]) -> None: with pytest.raises(AssertionError, match="'bands' contains duplicate bands"): Sentinel1(bands=bands) @pytest.mark.parametrize("bands", [["HH_HV"], ["HH", "HV", "HH_HV"]]) - def test_invalid_bands(self, bands: List[str]) -> None: + def test_invalid_bands(self, bands: list[str]) -> None: with pytest.raises(AssertionError, match="invalid band 'HH_HV'"): Sentinel1(bands=bands) @pytest.mark.parametrize( "bands", [["HH", "VV"], ["HH", "VH"], ["VV", "HV"], ["HH", "HV", "VV", "VH"]] ) - def test_dual_transmit(self, bands: List[str]) -> None: + def test_dual_transmit(self, bands: list[str]) -> None: with pytest.raises(AssertionError, match="'bands' cannot contain both "): Sentinel1(bands=bands) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 6e7dfe31406..f2f75b56659 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from collections.abc import Sequence from math import floor, isclose -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, Union import pytest from rasterio.crs import CRS @@ -38,7 +39,7 @@ def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool: class CustomGeoDataset(GeoDataset): def __init__( self, - items: List[Tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), "")], + items: list[tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), "")], crs: CRS = CRS.from_epsg(3005), res: float = 1, ) -> None: @@ -48,7 +49,7 @@ def __init__( self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) return {"content": hit.object} @@ -254,7 +255,7 @@ def test_roi_split() -> None: ], ) def test_time_series_split( - lengths: Sequence[Union[Tuple[int, int], int, float]], + lengths: Sequence[Union[tuple[int, int], int, float]], expected_lengths: Sequence[int], ) -> None: ds = CustomGeoDataset( diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index e89794a9726..877caedf1fc 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -11,7 +11,7 @@ import sys from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any import numpy as np import pytest @@ -221,7 +221,7 @@ def test_iter(self) -> None: ) def test_contains( self, - test_input: Tuple[float, float, float, float, float, float], + test_input: tuple[float, float, float, float, float, float], expected: bool, ) -> None: bbox1 = BoundingBox(0, 1, 0, 1, 0, 1) @@ -256,8 +256,8 @@ def test_contains( ) def test_or( self, - test_input: Tuple[float, float, float, float, float, float], - expected: Tuple[float, float, float, float, float, float], + test_input: tuple[float, float, float, float, float, float], + expected: tuple[float, float, float, float, float, float], ) -> None: bbox1 = BoundingBox(0, 1, 0, 1, 0, 1) bbox2 = BoundingBox(*test_input) @@ -290,8 +290,8 @@ def test_or( ) def test_and_intersection( self, - test_input: Tuple[float, float, float, float, float, float], - expected: Tuple[float, float, float, float, float, float], + test_input: tuple[float, float, float, float, float, float], + expected: tuple[float, float, float, float, float, float], ) -> None: bbox1 = BoundingBox(0, 1, 0, 1, 0, 1) bbox2 = BoundingBox(*test_input) @@ -309,7 +309,7 @@ def test_and_intersection( ], ) def test_and_no_intersection( - self, test_input: Tuple[float, float, float, float, float, float] + self, test_input: tuple[float, float, float, float, float, float] ) -> None: bbox1 = BoundingBox(0, 1, 0, 1, 0, 1) bbox2 = BoundingBox(*test_input) @@ -334,7 +334,7 @@ def test_and_no_intersection( ], ) def test_area( - self, test_input: Tuple[float, float, float, float, float, float], expected: int + self, test_input: tuple[float, float, float, float, float, float], expected: int ) -> None: bbox = BoundingBox(*test_input) assert bbox.area == expected @@ -354,7 +354,7 @@ def test_area( ], ) def test_volume( - self, test_input: Tuple[float, float, float, float, float, float], expected: int + self, test_input: tuple[float, float, float, float, float, float], expected: int ) -> None: bbox = BoundingBox(*test_input) assert bbox.volume == expected @@ -387,7 +387,7 @@ def test_volume( ) def test_intersects( self, - test_input: Tuple[float, float, float, float, float, float], + test_input: tuple[float, float, float, float, float, float], expected: bool, ) -> None: bbox1 = BoundingBox(0, 1, 0, 1, 0, 1) @@ -405,9 +405,9 @@ def test_split( self, proportion: float, horizontal: bool, - expected: Tuple[ - Tuple[float, float, float, float, float, float], - Tuple[float, float, float, float, float, float], + expected: tuple[ + tuple[float, float, float, float, float, float], + tuple[float, float, float, float, float, float], ], ) -> None: bbox = BoundingBox(0, 1, 0, 1, 0, 1) @@ -512,13 +512,13 @@ def test_disambiguate_timestamp( class TestCollateFunctionsMatchingKeys: @pytest.fixture(scope="class") - def samples(self) -> List[Dict[str, Any]]: + def samples(self) -> list[dict[str, Any]]: return [ {"image": torch.tensor([1, 2, 0]), "crs": CRS.from_epsg(2000)}, {"image": torch.tensor([0, 0, 3]), "crs": CRS.from_epsg(2001)}, ] - def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: sample = stack_samples(samples) assert sample["image"].size() == torch.Size([2, 3]) assert torch.allclose(sample["image"], torch.tensor([[1, 2, 0], [0, 0, 3]])) @@ -529,13 +529,13 @@ def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None: assert torch.allclose(samples[i]["image"], new_samples[i]["image"]) assert samples[i]["crs"] == new_samples[i]["crs"] - def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: sample = concat_samples(samples) assert sample["image"].size() == torch.Size([6]) assert torch.allclose(sample["image"], torch.tensor([1, 2, 0, 0, 0, 3])) assert sample["crs"] == CRS.from_epsg(2000) - def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: sample = merge_samples(samples) assert sample["image"].size() == torch.Size([3]) assert torch.allclose(sample["image"], torch.tensor([1, 2, 3])) @@ -544,13 +544,13 @@ def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None: class TestCollateFunctionsDifferingKeys: @pytest.fixture(scope="class") - def samples(self) -> List[Dict[str, Any]]: + def samples(self) -> list[dict[str, Any]]: return [ {"image": torch.tensor([1, 2, 0]), "crs1": CRS.from_epsg(2000)}, {"mask": torch.tensor([0, 0, 3]), "crs2": CRS.from_epsg(2001)}, ] - def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: sample = stack_samples(samples) assert sample["image"].size() == torch.Size([1, 3]) assert sample["mask"].size() == torch.Size([1, 3]) @@ -565,7 +565,7 @@ def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None: assert torch.allclose(samples[1]["mask"], new_samples[0]["mask"]) assert samples[1]["crs2"] == new_samples[0]["crs2"] - def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: sample = concat_samples(samples) assert sample["image"].size() == torch.Size([3]) assert sample["mask"].size() == torch.Size([3]) @@ -574,7 +574,7 @@ def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None: assert sample["crs1"] == CRS.from_epsg(2000) assert sample["crs2"] == CRS.from_epsg(2001) - def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: sample = merge_samples(samples) assert sample["image"].size() == torch.Size([3]) assert sample["mask"].size() == torch.Size([3]) diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index f17b5f19b23..c37de00db81 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from pathlib import Path -from typing import Any, Dict +from typing import Any import pytest import timm @@ -15,8 +15,8 @@ from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 -def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: - state_dict: Dict[str, Any] = torch.load(url) +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) return state_dict diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py index 88124584488..5bee25ff1d0 100644 --- a/tests/models/test_vit.py +++ b/tests/models/test_vit.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from pathlib import Path -from typing import Any, Dict +from typing import Any import pytest import timm @@ -15,8 +15,8 @@ from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224 -def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: - state_dict: Dict[str, Any] = torch.load(url) +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) return state_dict diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 9dc32394d3d..05b0cf0b3a8 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. import math +from collections.abc import Iterator from itertools import product -from typing import Dict, Iterator, List import pytest from _pytest.fixtures import SubRequest @@ -18,7 +18,7 @@ class CustomBatchGeoSampler(BatchGeoSampler): def __init__(self) -> None: pass - def __iter__(self) -> Iterator[List[BoundingBox]]: + def __iter__(self) -> Iterator[list[BoundingBox]]: for i in range(len(self)): yield [BoundingBox(j, j, j, j, j, j) for j in range(len(self))] @@ -32,7 +32,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: return {"index": query} diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 229e81dcb2e..c17dd7da34a 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. import math +from collections.abc import Iterator from itertools import product -from typing import Dict, Iterator import pytest from _pytest.fixtures import SubRequest @@ -39,7 +39,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: return {"index": query} diff --git a/tests/samplers/test_utils.py b/tests/samplers/test_utils.py index 9905f99a3ed..20d828827ab 100644 --- a/tests/samplers/test_utils.py +++ b/tests/samplers/test_utils.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import math -from typing import Optional, Tuple, Union +from typing import Optional, Union import pytest @@ -10,7 +10,7 @@ from torchgeo.samplers import tile_to_chips from torchgeo.samplers.utils import _to_tuple -MAYBE_TUPLE = Union[float, Tuple[float, float]] +MAYBE_TUPLE = Union[float, tuple[float, float]] @pytest.mark.parametrize( diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index 1015f527ced..614babe7102 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -4,7 +4,6 @@ import os from collections import OrderedDict from pathlib import Path -from typing import Dict import pytest import torch @@ -29,13 +28,13 @@ def model() -> Module: @pytest.fixture(scope="package") -def state_dict(model: Module) -> Dict[str, Tensor]: +def state_dict(model: Module) -> dict[str, Tensor]: return model.state_dict() @pytest.fixture(params=["model", "backbone"]) def checkpoint( - state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path + state_dict: dict[str, Tensor], request: SubRequest, tmp_path: Path ) -> str: if request.param == "model": state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()}) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 7fdf90f7a9b..67bfef7d5f1 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any, Dict, Type, cast +from typing import Any, cast import pytest import timm @@ -30,8 +30,8 @@ from .test_segmentation import SegmentationTestModel -def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: - state_dict: Dict[str, Any] = torch.load(url) +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) return state_dict @@ -67,12 +67,12 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, - classname: Type[LightningDataModule], + classname: type[LightningDataModule], fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) + conf_dict = cast(dict[str, dict[str, Any]], conf_dict) if name.startswith("seco"): monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) @@ -100,7 +100,7 @@ def test_trainer( trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[str, Any]: + def model_kwargs(self) -> dict[str, Any]: return { "backbone": "resnet18", "in_channels": 13, @@ -133,13 +133,13 @@ def mocked_weights( monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) return weights - def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): BYOLTask(**model_kwargs) def test_weight_enum( - self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum ) -> None: model_kwargs["backbone"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] @@ -147,7 +147,7 @@ def test_weight_enum( BYOLTask(**model_kwargs) def test_weight_str( - self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum ) -> None: model_kwargs["backbone"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] @@ -156,7 +156,7 @@ def test_weight_str( @pytest.mark.slow def test_weight_enum_download( - self, model_kwargs: Dict[str, Any], weights: WeightsEnum + self, model_kwargs: dict[str, Any], weights: WeightsEnum ) -> None: model_kwargs["backbone"] = weights.meta["model"] model_kwargs["in_channels"] = weights.meta["in_chans"] @@ -165,7 +165,7 @@ def test_weight_enum_download( @pytest.mark.slow def test_weight_str_download( - self, model_kwargs: Dict[str, Any], weights: WeightsEnum + self, model_kwargs: dict[str, Any], weights: WeightsEnum ) -> None: model_kwargs["backbone"] = weights.meta["model"] model_kwargs["in_channels"] = weights.meta["in_chans"] diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 98f6fff1d20..8a45ff7ce46 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any, Dict, Type, cast +from typing import Any, cast import pytest import timm @@ -62,8 +62,8 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: - state_dict: Dict[str, Any] = torch.load(url) +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) return state_dict @@ -88,7 +88,7 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, - classname: Type[LightningDataModule], + classname: type[LightningDataModule], fast_dev_run: bool, ) -> None: if name.startswith("so2sat"): @@ -96,7 +96,7 @@ def test_trainer( conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) + conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -125,7 +125,7 @@ def test_trainer( pass @pytest.fixture - def model_kwargs(self) -> Dict[str, Any]: + def model_kwargs(self) -> dict[str, Any]: return { "model": "resnet18", "in_channels": 13, @@ -158,13 +158,13 @@ def mocked_weights( monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) return weights - def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) def test_weight_enum( - self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum ) -> None: model_kwargs["model"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] @@ -173,7 +173,7 @@ def test_weight_enum( ClassificationTask(**model_kwargs) def test_weight_str( - self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum ) -> None: model_kwargs["model"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] @@ -183,7 +183,7 @@ def test_weight_str( @pytest.mark.slow def test_weight_enum_download( - self, model_kwargs: Dict[str, Any], weights: WeightsEnum + self, model_kwargs: dict[str, Any], weights: WeightsEnum ) -> None: model_kwargs["model"] = weights.meta["model"] model_kwargs["in_channels"] = weights.meta["in_chans"] @@ -192,21 +192,21 @@ def test_weight_enum_download( @pytest.mark.slow def test_weight_str_download( - self, model_kwargs: Dict[str, Any], weights: WeightsEnum + self, model_kwargs: dict[str, Any], weights: WeightsEnum ) -> None: model_kwargs["model"] = weights.meta["model"] model_kwargs["in_channels"] = weights.meta["in_chans"] model_kwargs["weights"] = str(weights) ClassificationTask(**model_kwargs) - def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: + def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool + self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(EuroSATDataModule, "plot", plot) datamodule = EuroSATDataModule( @@ -221,7 +221,7 @@ def test_no_rgb( ) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictClassificationDataModule( root="tests/data/eurosat", batch_size=1, num_workers=0 ) @@ -248,12 +248,12 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, - classname: Type[LightningDataModule], + classname: type[LightningDataModule], fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) + conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -282,7 +282,7 @@ def test_trainer( pass @pytest.fixture - def model_kwargs(self) -> Dict[str, Any]: + def model_kwargs(self) -> dict[str, Any]: return { "model": "resnet18", "in_channels": 14, @@ -291,14 +291,14 @@ def model_kwargs(self) -> Dict[str, Any]: "weights": None, } - def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: + def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool + self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(BigEarthNetDataModule, "plot", plot) datamodule = BigEarthNetDataModule( @@ -313,7 +313,7 @@ def test_no_rgb( ) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictMultiLabelClassificationDataModule( root="tests/data/bigearthnet", batch_size=1, num_workers=0 ) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 34262829a47..48b8b0d4579 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Type, cast +from typing import Any, cast import pytest import torch @@ -65,13 +65,13 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, - classname: Type[LightningDataModule], + classname: type[LightningDataModule], model_name: str, fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(dict[Any, dict[Any, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -109,27 +109,27 @@ def test_trainer( pass @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> dict[Any, Any]: return {"model": "faster-rcnn", "backbone": "resnet18", "num_classes": 2} - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=match): ObjectDetectionTask(**model_kwargs) - def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_backbone(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["backbone"] = "invalid_backbone" match = "Backbone type 'invalid_backbone' is not valid." with pytest.raises(ValueError, match=match): ObjectDetectionTask(**model_kwargs) - def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: + def test_non_pretrained_backbone(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["pretrained"] = False ObjectDetectionTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool + self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot) datamodule = NASAMarineDebrisDataModule( @@ -144,7 +144,7 @@ def test_no_rgb( ) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictObjectDetectionDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 ) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 4b981c37ce0..e3c53d19f80 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any, Dict, Type, cast +from typing import Any, cast import pytest import timm @@ -37,8 +37,8 @@ def setup(self, stage: str) -> None: self.predict_dataset = TropicalCyclone(split="test", **self.kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: - state_dict: Dict[str, Any] = torch.load(url) +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) return state_dict @@ -55,11 +55,11 @@ class TestRegressionTask: ], ) def test_trainer( - self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + self, name: str, classname: type[LightningDataModule], fast_dev_run: bool ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) + conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -89,7 +89,7 @@ def test_trainer( pass @pytest.fixture - def model_kwargs(self) -> Dict[str, Any]: + def model_kwargs(self) -> dict[str, Any]: return { "model": "resnet18", "weights": None, @@ -121,13 +121,13 @@ def mocked_weights( monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) return weights - def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): RegressionTask(**model_kwargs) def test_weight_enum( - self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum ) -> None: model_kwargs["model"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] @@ -136,7 +136,7 @@ def test_weight_enum( RegressionTask(**model_kwargs) def test_weight_str( - self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum ) -> None: model_kwargs["model"] = mocked_weights.meta["model"] model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] @@ -146,7 +146,7 @@ def test_weight_str( @pytest.mark.slow def test_weight_enum_download( - self, model_kwargs: Dict[str, Any], weights: WeightsEnum + self, model_kwargs: dict[str, Any], weights: WeightsEnum ) -> None: model_kwargs["model"] = weights.meta["model"] model_kwargs["in_channels"] = weights.meta["in_chans"] @@ -155,7 +155,7 @@ def test_weight_enum_download( @pytest.mark.slow def test_weight_str_download( - self, model_kwargs: Dict[str, Any], weights: WeightsEnum + self, model_kwargs: dict[str, Any], weights: WeightsEnum ) -> None: model_kwargs["model"] = weights.meta["model"] model_kwargs["in_channels"] = weights.meta["in_chans"] @@ -163,7 +163,7 @@ def test_weight_str_download( RegressionTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool + self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot) datamodule = TropicalCycloneDataModule( @@ -178,7 +178,7 @@ def test_no_rgb( ) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictRegressionDataModule( root="tests/data/cyclone", batch_size=1, num_workers=0 ) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 0a6a7f7aa07..80a3404e68a 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Type, cast +from typing import Any, cast import pytest import segmentation_models_pytorch as smp @@ -83,7 +83,7 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, - classname: Type[LightningDataModule], + classname: type[LightningDataModule], fast_dev_run: bool, ) -> None: if name == "naipchesapeake": @@ -95,7 +95,7 @@ def test_trainer( conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(dict[Any, dict[Any, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -125,7 +125,7 @@ def test_trainer( pass @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> dict[Any, Any]: return { "model": "unet", "backbone": "resnet18", @@ -136,25 +136,25 @@ def model_kwargs(self) -> Dict[Any, Any]: "ignore_index": 0, } - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=match): SemanticSegmentationTask(**model_kwargs) - def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_loss(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): SemanticSegmentationTask(**model_kwargs) - def test_invalid_ignoreindex(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_ignoreindex(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["ignore_index"] = "0" match = "ignore_index must be an int or None" with pytest.raises(ValueError, match=match): SemanticSegmentationTask(**model_kwargs) - def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: + def test_ignoreindex_with_jaccard(self, model_kwargs: dict[Any, Any]) -> None: model_kwargs["loss"] = "jaccard" model_kwargs["ignore_index"] = 0 match = "ignore_index has no effect on training when loss='jaccard'" @@ -162,7 +162,7 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: SemanticSegmentationTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool + self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool ) -> None: model_kwargs["in_channels"] = 15 monkeypatch.setattr(SEN12MSDataModule, "plot", plot) diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py index c58cf0887d5..1c1f74e2561 100644 --- a/tests/transforms/test_indices.py +++ b/tests/transforms/test_indices.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import Dict - import pytest import torch from torch import Tensor @@ -27,7 +25,7 @@ @pytest.fixture -def sample() -> Dict[str, Tensor]: +def sample() -> dict[str, Tensor]: return { "image": torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4), "mask": torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4), @@ -35,14 +33,14 @@ def sample() -> Dict[str, Tensor]: @pytest.fixture -def batch() -> Dict[str, Tensor]: +def batch() -> dict[str, Tensor]: return { "image": torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4), "mask": torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4), } -def test_append_index_sample(sample: Dict[str, Tensor]) -> None: +def test_append_index_sample(sample: dict[str, Tensor]) -> None: c, h, w = sample["image"].shape aug = AugmentationSequential( AppendNormalizedDifferenceIndex(index_a=0, index_b=1), @@ -52,7 +50,7 @@ def test_append_index_sample(sample: Dict[str, Tensor]) -> None: assert output["image"].shape == (1, c + 1, h, w) -def test_append_index_batch(batch: Dict[str, Tensor]) -> None: +def test_append_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch["image"].shape aug = AugmentationSequential( AppendNormalizedDifferenceIndex(index_a=0, index_b=1), @@ -62,7 +60,7 @@ def test_append_index_batch(batch: Dict[str, Tensor]) -> None: assert output["image"].shape == (b, c + 1, h, w) -def test_append_triband_index_batch(batch: Dict[str, Tensor]) -> None: +def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch["image"].shape aug = AugmentationSequential( AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2), @@ -87,7 +85,7 @@ def test_append_triband_index_batch(batch: Dict[str, Tensor]) -> None: ], ) def test_append_normalized_difference_indices( - sample: Dict[str, Tensor], index: AppendNormalizedDifferenceIndex + sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex ) -> None: c, h, w = sample["image"].shape aug = AugmentationSequential(index(0, 1), data_keys=["image", "mask"]) @@ -97,7 +95,7 @@ def test_append_normalized_difference_indices( @pytest.mark.parametrize("index", [AppendGBNDVI, AppendGRNDVI, AppendRBNDVI]) def test_append_tri_band_normalized_difference_indices( - sample: Dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex + sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex ) -> None: c, h, w = sample["image"].shape aug = AugmentationSequential(index(0, 1, 2), data_keys=["image", "mask"]) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index a7db809e0b1..b4b6d8a54fc 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import Dict - import kornia.augmentation as K import pytest import torch @@ -20,7 +18,7 @@ @pytest.fixture -def batch_gray() -> Dict[str, Tensor]: +def batch_gray() -> dict[str, Tensor]: return { "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), @@ -30,7 +28,7 @@ def batch_gray() -> Dict[str, Tensor]: @pytest.fixture -def batch_rgb() -> Dict[str, Tensor]: +def batch_rgb() -> dict[str, Tensor]: return { "image": torch.tensor( [ @@ -49,7 +47,7 @@ def batch_rgb() -> Dict[str, Tensor]: @pytest.fixture -def batch_multispectral() -> Dict[str, Tensor]: +def batch_multispectral() -> dict[str, Tensor]: return { "image": torch.tensor( [ @@ -69,14 +67,14 @@ def batch_multispectral() -> Dict[str, Tensor]: } -def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> None: +def assert_matching(output: dict[str, Tensor], expected: dict[str, Tensor]) -> None: for key in expected: err = f"output[{key}] != expected[{key}]" equal = torch.allclose(output[key], expected[key]) assert equal, err -def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None: +def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: expected = { "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), @@ -90,7 +88,7 @@ def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None: assert_matching(output, expected) -def test_augmentation_sequential_rgb(batch_rgb: Dict[str, Tensor]) -> None: +def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: expected = { "image": torch.tensor( [ @@ -114,7 +112,7 @@ def test_augmentation_sequential_rgb(batch_rgb: Dict[str, Tensor]) -> None: def test_augmentation_sequential_multispectral( - batch_multispectral: Dict[str, Tensor] + batch_multispectral: dict[str, Tensor] ) -> None: expected = { "image": torch.tensor( @@ -141,7 +139,7 @@ def test_augmentation_sequential_multispectral( def test_augmentation_sequential_image_only( - batch_multispectral: Dict[str, Tensor] + batch_multispectral: dict[str, Tensor] ) -> None: expected = { "image": torch.tensor( @@ -168,7 +166,7 @@ def test_augmentation_sequential_image_only( def test_sequential_transforms_augmentations( - batch_multispectral: Dict[str, Tensor] + batch_multispectral: dict[str, Tensor] ) -> None: expected = { "image": torch.tensor( diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 50e699aa188..40d90942652 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -3,7 +3,7 @@ """Chesapeake Bay High-Resolution Land Cover Project datamodule.""" -from typing import Any, Dict, List +from typing import Any import kornia.augmentation as K import torch.nn as nn @@ -29,7 +29,7 @@ def __init__(self, aug: nn.Module) -> None: super().__init__() self.aug = aug - def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply the augmentation. Args: @@ -58,9 +58,9 @@ class ChesapeakeCVPRDataModule(GeoDataModule): def __init__( self, - train_splits: List[str], - val_splits: List[str], - test_splits: List[str], + train_splits: list[str], + val_splits: list[str], + test_splits: list[str], batch_size: int = 64, patch_size: int = 256, length: int = 1000, @@ -158,8 +158,8 @@ def setup(self, stage: str) -> None: ) def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index fdf00265016..3bfa594a22f 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -3,7 +3,7 @@ """DeepGlobe Land Cover Classification Challenge datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K @@ -24,7 +24,7 @@ class DeepGlobeLandCoverDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 64, + patch_size: Union[tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 3c4a8dd1e83..ea78a0ebb79 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -3,7 +3,7 @@ """ETCI 2021 datamodule.""" -from typing import Any, Dict +from typing import Any import torch from torch import Tensor @@ -63,8 +63,8 @@ def setup(self, stage: str) -> None: self.predict_dataset = ETCI2021(split="test", **self.kwargs) def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 0f5711ec151..dcd2a199eec 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -3,7 +3,7 @@ """Base classes for all :mod:`torchgeo` data modules.""" -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import kornia.augmentation as K import matplotlib.pyplot as plt @@ -34,9 +34,9 @@ class GeoDataModule(LightningDataModule): # type: ignore[misc] def __init__( self, - dataset_class: Type[GeoDataset], + dataset_class: type[GeoDataset], batch_size: int = 1, - patch_size: Union[int, Tuple[int, int]] = 64, + patch_size: Union[int, tuple[int, int]] = 64, length: int = 1000, num_workers: int = 0, **kwargs: Any, @@ -61,11 +61,11 @@ def __init__( self.kwargs = kwargs # Datasets - self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None # Samplers self.sampler: Optional[GeoSampler] = None @@ -91,7 +91,7 @@ def __init__( self.collate_fn = stack_samples # Data augmentation - Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] + Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] self.aug: Transform = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] ) @@ -142,7 +142,7 @@ def setup(self, stage: str) -> None: self.test_dataset, self.patch_size, self.patch_size ) - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for training. Returns: @@ -172,7 +172,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" raise MisconfigurationException(msg) - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for validation. Returns: @@ -202,7 +202,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" raise MisconfigurationException(msg) - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for testing. Returns: @@ -232,7 +232,7 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" raise MisconfigurationException(msg) - def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for prediction. Returns: @@ -263,8 +263,8 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: raise MisconfigurationException(msg) def transfer_batch_to_device( - self, batch: Dict[str, Tensor], device: torch.device, dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], device: torch.device, dataloader_idx: int + ) -> dict[str, Tensor]: """Transfer batch to device. Defines how custom data types are moved to the target device. @@ -285,8 +285,8 @@ def transfer_batch_to_device( return batch def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: @@ -340,7 +340,7 @@ class NonGeoDataModule(LightningDataModule): # type: ignore[misc] def __init__( self, - dataset_class: Type[NonGeoDataset], + dataset_class: type[NonGeoDataset], batch_size: int = 1, num_workers: int = 0, **kwargs: Any, @@ -361,11 +361,11 @@ def __init__( self.kwargs = kwargs # Datasets - self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None - self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None # Data loaders self.train_batch_size: Optional[int] = None @@ -377,7 +377,7 @@ def __init__( self.collate_fn = default_collate # Data augmentation - Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] + Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] self.aug: Transform = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] ) @@ -419,7 +419,7 @@ def setup(self, stage: str) -> None: split="test", **self.kwargs ) - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for training. Returns: @@ -442,7 +442,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" raise MisconfigurationException(msg) - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for validation. Returns: @@ -465,7 +465,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" raise MisconfigurationException(msg) - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for testing. Returns: @@ -488,7 +488,7 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" raise MisconfigurationException(msg) - def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for prediction. Returns: @@ -512,8 +512,8 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: raise MisconfigurationException(msg) def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 1297741eb2f..8a40bddb03a 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -3,7 +3,7 @@ """GID-15 datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K @@ -26,7 +26,7 @@ class GID15DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 64, + patch_size: Union[tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 81503810011..698273b9485 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,7 +3,7 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K @@ -27,7 +27,7 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 64, + patch_size: Union[tuple[int, int], int] = 64, num_workers: int = 0, val_split_pct: float = 0.1, test_split_pct: float = 0.1, diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index 5e8c89f29a0..51453654bab 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -3,7 +3,7 @@ """L7 Irish datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import torch @@ -24,7 +24,7 @@ class L7IrishDataModule(GeoDataModule): def __init__( self, batch_size: int = 1, - patch_size: Union[int, Tuple[int, int]] = 32, + patch_size: Union[int, tuple[int, int]] = 32, length: int = 5, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index a20e4eef312..88364ec134f 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -3,7 +3,7 @@ """L8 Biome datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import torch @@ -24,7 +24,7 @@ class L8BiomeDataModule(GeoDataModule): def __init__( self, batch_size: int = 1, - patch_size: Union[int, Tuple[int, int]] = 32, + patch_size: Union[int, tuple[int, int]] = 32, length: int = 5, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 63ef8c8501e..b3361557dbd 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -3,7 +3,7 @@ """National Agriculture Imagery Program (NAIP) datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K import matplotlib.pyplot as plt @@ -23,7 +23,7 @@ class NAIPChesapeakeDataModule(GeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, Tuple[int, int]] = 256, + patch_size: Union[int, tuple[int, int]] = 256, length: int = 1000, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index cdcab7f7e9c..f740df5eb76 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -3,7 +3,7 @@ """NASA Marine Debris datamodule.""" -from typing import Any, Dict, List +from typing import Any import torch from torch import Tensor @@ -13,7 +13,7 @@ from .utils import dataset_split -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: +def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: """Custom object detection collate fn to handle variable boxes. Args: @@ -22,7 +22,7 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: Returns: batch dict output """ - output: Dict[str, Any] = {} + output: dict[str, Any] = {} output["image"] = torch.stack([sample["image"] for sample in batch]) output["boxes"] = [sample["boxes"] for sample in batch] output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch] diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 624a47c615f..748c4038091 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -3,7 +3,7 @@ """OSCD datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K import torch @@ -65,7 +65,7 @@ class OSCDDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 64, + patch_size: Union[tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index f22558a5908..0bd712d4607 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -3,7 +3,7 @@ """Potsdam datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K @@ -26,7 +26,7 @@ class Potsdam2DDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 64, + patch_size: Union[tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index d412c93d978..e07b3d01e5a 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -3,7 +3,7 @@ """SEN12MS datamodule.""" -from typing import Any, Dict +from typing import Any import torch from sklearn.model_selection import GroupShuffleSplit @@ -99,8 +99,8 @@ def setup(self, stage: str) -> None: self.test_dataset = SEN12MS(split="test", **self.kwargs) def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 90a6f6f5ac6..3a3b7531cba 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -3,7 +3,7 @@ """SpaceNet datamodules.""" -from typing import Any, Dict +from typing import Any import kornia.augmentation as K from torch import Tensor @@ -73,8 +73,8 @@ def setup(self, stage: str) -> None: ) def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index b1df01721c3..832ab1d0318 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -3,7 +3,7 @@ """Common datamodule utilities.""" -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Generator from torch.utils.data import Subset, TensorDataset, random_split @@ -20,7 +20,7 @@ def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float, test_pct: Optional[float] = None, -) -> List[Subset[Any]]: +) -> list[Subset[Any]]: """Split a torch Dataset into train/val/test sets. If ``test_pct`` is not set then only train and validation splits are returned. diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 1128ea76655..883ff7781b9 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -3,7 +3,7 @@ """Vaihingen datamodule.""" -from typing import Any, Tuple, Union +from typing import Any, Union import kornia.augmentation as K @@ -26,7 +26,7 @@ class Vaihingen2DDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 64, + patch_size: Union[tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index 83eea56d613..e8d75b13fcf 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional, cast +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -86,7 +86,7 @@ class ADVANCE(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -118,9 +118,9 @@ def __init__( self.files = self._load_files(self.root) self.classes = sorted({f["cls"] for f in self.files}) - self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -149,7 +149,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> List[Dict[str, str]]: + def _load_files(self, root: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -233,7 +233,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index dccc69440e2..4b7d4ac6462 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -6,7 +6,7 @@ import glob import json import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -61,7 +61,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, cache: bool = True, ) -> None: @@ -126,7 +126,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 409953ae943..0bc1d886457 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -49,7 +49,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -94,7 +94,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 34dc2e734ae..371c1985ab3 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -6,7 +6,7 @@ import json import os from functools import lru_cache -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -174,8 +174,8 @@ def __init__( root: str = "data", chip_size: int = 256, stride: int = 128, - bands: Tuple[str, ...] = all_bands, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + bands: tuple[str, ...] = all_bands, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -228,7 +228,7 @@ def __init__( ]: self.chips_metadata.append((y, x)) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -267,7 +267,7 @@ def __len__(self) -> int: """ return len(self.chips_metadata) - def _validate_bands(self, bands: Tuple[str, ...]) -> None: + def _validate_bands(self, bands: tuple[str, ...]) -> None: """Validate list of bands. Args: @@ -284,8 +284,8 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None: @lru_cache(maxsize=128) def _load_all_imagery( - self, bands: Tuple[str, ...] = all_bands - ) -> Tuple[Tensor, rasterio.Affine, CRS]: + self, bands: tuple[str, ...] = all_bands + ) -> tuple[Tensor, rasterio.Affine, CRS]: """Load all the imagery (across time) for the dataset. Optionally allows for subsetting of the bands that are loaded. @@ -318,8 +318,8 @@ def _load_all_imagery( @lru_cache(maxsize=128) def _load_single_scene( - self, date: str, bands: Tuple[str, ...] - ) -> Tuple[Tensor, rasterio.Affine, CRS]: + self, date: str, bands: tuple[str, ...] + ) -> tuple[Tensor, rasterio.Affine, CRS]: """Load the imagery for a single date. Optionally allows for subsetting of the bands that are loaded. @@ -427,7 +427,7 @@ def _download(self, api_key: Optional[str] = None) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, time_step: int = 0, suptitle: Optional[str] = None, diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 79365d9a717..2118265c569 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -6,7 +6,7 @@ import glob import json import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -269,7 +269,7 @@ def __init__( split: str = "train", bands: str = "all", num_classes: int = 19, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -299,7 +299,7 @@ def __init__( self._verify() self.folders = self._load_folders() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -310,7 +310,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: """ image = self._load_image(index) label = self._load_target(index) - sample: Dict[str, Tensor] = {"image": image, "label": label} + sample: dict[str, Tensor] = {"image": image, "label": label} if self.transforms is not None: sample = self.transforms(sample) @@ -325,7 +325,7 @@ def __len__(self) -> int: """ return len(self.folders) - def _load_folders(self) -> List[Dict[str, str]]: + def _load_folders(self) -> list[dict[str, str]]: """Load folder paths. Returns: @@ -348,7 +348,7 @@ def _load_folders(self) -> List[Dict[str, str]]: ] return folders - def _load_paths(self, index: int) -> List[str]: + def _load_paths(self, index: int) -> list[str]: """Load paths to band files. Args: @@ -513,7 +513,7 @@ def _extract(self, filepath: str) -> None: def _onehot_labels_to_names( self, label_mask: "np.typing.NDArray[np.bool_]" - ) -> List[str]: + ) -> list[str]: """Gets a list of class names given a label mask. Args: @@ -530,7 +530,7 @@ def _onehot_labels_to_names( def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index 536415b0f24..625357c12ce 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -4,7 +4,7 @@ """Canadian Building Footprints dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -62,7 +62,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: float = 0.00001, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -124,7 +124,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index b5743d5fac2..6c2c46f977d 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -325,7 +325,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -401,7 +401,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 43c829d6671..a472c9b20bf 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -6,7 +6,8 @@ import abc import os import sys -from typing import Any, Callable, Dict, List, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Callable, Optional, cast import fiona import matplotlib.pyplot as plt @@ -90,7 +91,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -169,7 +170,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -533,7 +534,7 @@ def __init__( root: str = "data", splits: Sequence[str] = ["de-train"], layers: Sequence[str] = ["naip-new", "lc"], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -612,7 +613,7 @@ def __init__( }, ) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -625,7 +626,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[Dict[str, str]], [hit.object for hit in hits]) + filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query} @@ -739,7 +740,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index c11d2d463b9..a3c344bd944 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -5,7 +5,8 @@ import json import os -from typing import Any, Callable, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -103,7 +104,7 @@ def __init__( root: str = "data", split: str = "train", bands: Sequence[str] = band_names, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -150,7 +151,7 @@ def __len__(self) -> int: """ return len(self.chip_paths) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Returns a sample from dataset. Args: @@ -161,7 +162,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: """ image = self._load_image(index) label = self._load_target(index) - sample: Dict[str, Tensor] = {"image": image, "mask": label} + sample: dict[str, Tensor] = {"image": image, "mask": label} if self.transforms is not None: sample = self.transforms(sample) @@ -219,7 +220,7 @@ def _read_json_data(object_path: str) -> Any: json_data = json.load(read_contents) return json_data - def _load_items(self, item_json: str) -> Dict[str, List[str]]: + def _load_items(self, item_json: str) -> dict[str, list[str]]: """Loads the label item and corresponding source items. Args: @@ -262,7 +263,7 @@ def _load_items(self, item_json: str) -> Dict[str, List[str]]: item_meta["source"] = source_item_paths return item_meta - def _load_collections(self) -> List[Dict[str, Any]]: + def _load_collections(self) -> list[dict[str, Any]]: """Loads the paths to source and label assets for each collection. Returns: @@ -272,7 +273,7 @@ def _load_collections(self) -> List[Dict[str, Any]]: RuntimeError if collection.json is not found in the uncompressed dataset """ indexed_chips = [] - label_collection: List[str] = [] + label_collection: list[str] = [] for c in self.collection_names[self.split]: if "label" in c: label_collection.append(c) @@ -351,7 +352,7 @@ def _download(self, api_key: Optional[str] = None) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index a0ccea6731b..fb4408a09f4 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -172,7 +172,7 @@ def __init__( res: Optional[float] = None, measurement: str = "agb", country: str = all_countries[0], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -253,7 +253,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index d34749cf2d0..e9d8dc02b7f 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -6,7 +6,7 @@ import abc import csv import os -from typing import Callable, Dict, List, Optional, cast +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -47,12 +47,12 @@ def base_url(self) -> str: @property @abc.abstractmethod - def filenames(self) -> List[str]: + def filenames(self) -> list[str]: """List of files to download.""" @property @abc.abstractmethod - def md5s(self) -> List[str]: + def md5s(self) -> list[str]: """List of MD5 checksums of files to download.""" @property @@ -64,7 +64,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -111,7 +111,7 @@ def __init__( self.images.append(row[0]) self.targets.append(row[1]) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -193,7 +193,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index adf00b99609..989b7dc2f76 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -6,7 +6,7 @@ import csv import os from functools import lru_cache -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -117,8 +117,8 @@ def __init__( root: str = "data", chip_size: int = 256, stride: int = 128, - bands: Tuple[str, ...] = band_names, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + bands: tuple[str, ...] = band_names, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -172,7 +172,7 @@ def __init__( ]: self.chips_metadata.append((tile_index, y, x)) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -214,7 +214,7 @@ def __len__(self) -> int: return len(self.chips_metadata) @lru_cache(maxsize=128) - def _load_label_tile(self, tile_name: str) -> Tuple[Tensor, Tensor]: + def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]: """Load a single _tile_ of labels and field_ids. Args: @@ -245,7 +245,7 @@ def _load_label_tile(self, tile_name: str) -> Tuple[Tensor, Tensor]: return (labels, field_ids) - def _validate_bands(self, bands: Tuple[str, ...]) -> None: + def _validate_bands(self, bands: tuple[str, ...]) -> None: """Validate list of bands. Args: @@ -262,7 +262,7 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None: @lru_cache(maxsize=128) def _load_all_image_tiles( - self, tile_name: str, bands: Tuple[str, ...] = band_names + self, tile_name: str, bands: tuple[str, ...] = band_names ) -> Tensor: """Load all the imagery (across time) for a single _tile_. @@ -299,7 +299,7 @@ def _load_all_image_tiles( @lru_cache(maxsize=128) def _load_single_image_tile( - self, tile_name: str, date: str, bands: Tuple[str, ...] + self, tile_name: str, date: str, bands: tuple[str, ...] ) -> Tensor: """Load the imagery for a single tile for a single date. @@ -356,7 +356,7 @@ def _check_integrity(self) -> bool: return images and targets - def get_splits(self) -> Tuple[List[int], List[int]]: + def get_splits(self) -> tuple[list[int], list[int]]: """Get the field_ids for the train/test splits from the dataset directory. Returns: @@ -407,7 +407,7 @@ def _download(self, api_key: Optional[str] = None) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, time_step: int = 0, suptitle: Optional[str] = None, diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index f8bf4645225..4022f5a84f9 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -6,7 +6,7 @@ import json import os from functools import lru_cache -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -67,7 +67,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -108,7 +108,7 @@ def __init__( with open(filename) as f: self.collection = json.load(f)["links"] - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. Args: @@ -124,7 +124,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: source_id.replace("source", "{0}"), ) - sample: Dict[str, Any] = {"image": self._load_image(directory)} + sample: dict[str, Any] = {"image": self._load_image(directory)} sample.update(self._load_features(directory)) if self.transforms is not None: @@ -163,7 +163,7 @@ def _load_image(self, directory: str) -> Tensor: tensor = torch.from_numpy(array).permute((2, 0, 1)).float() return tensor - def _load_features(self, directory: str) -> Dict[str, Any]: + def _load_features(self, directory: str) -> dict[str, Any]: """Load features for a single image. Args: @@ -174,7 +174,7 @@ def _load_features(self, directory: str) -> Dict[str, Any]: """ filename = os.path.join(directory.format("source"), "features.json") with open(filename) as f: - features: Dict[str, Any] = json.load(f) + features: dict[str, Any] = json.load(f) filename = os.path.join(directory.format("labels"), "labels.json") with open(filename) as f: @@ -224,7 +224,7 @@ def _download(self, api_key: Optional[str] = None) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index 1113e9bb64b..694da07f3d1 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -4,7 +4,7 @@ """DeepGlobe Land Cover Classification Challenge dataset.""" import os -from typing import Callable, Dict, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -91,7 +91,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new DeepGlobeLandCover dataset instance. @@ -132,7 +132,7 @@ def __init__( self.image_fns.append(image_path) self.mask_fns.append(mask_path) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -221,7 +221,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 6a11fdc6344..2c7df0350fb 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -5,7 +5,8 @@ import glob import os -from typing import Callable, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -137,7 +138,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new DFC2022 dataset instance. @@ -163,7 +164,7 @@ def __init__( self.class2idx = {c: i for i, c in enumerate(self.classes)} self.files = self._load_files() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -196,7 +197,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self) -> List[Dict[str, str]]: + def _load_files(self) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Returns: @@ -294,7 +295,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index 2866c9b288f..8dfa5a7c957 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -5,7 +5,7 @@ import os import sys -from typing import Any, Dict +from typing import Any import numpy as np from rasterio.crs import CRS @@ -91,7 +91,7 @@ def __init__(self, root: str = "data") -> None: self.index.insert(i, coords) i += 1 - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve metadata indexed by query. Args: diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 5cc7e416ca0..8487389927f 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -5,7 +5,8 @@ import os import sys -from typing import Any, Callable, Dict, List, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Callable, Optional, cast import fiona import matplotlib.pyplot as plt @@ -253,7 +254,7 @@ def __init__( root: str = "data", splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"], layers: Sequence[str] = ["naip", "prior"], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, prior_as_input: bool = False, cache: bool = True, download: bool = False, @@ -330,7 +331,7 @@ def __init__( }, ) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -343,7 +344,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[Dict[str, str]], [hit.object for hit in hits]) + filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query} @@ -450,7 +451,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 7a97de8b0a1..64c4cc82d33 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -69,7 +69,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -139,7 +139,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 2a11edc9204..cdb42a6850d 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -81,7 +81,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -118,7 +118,7 @@ def __init__( self.files = self._load_files(self.root, self.split) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -154,7 +154,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files(self, root: str, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -260,7 +260,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 13be06458af..5a2134f84dc 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -84,7 +84,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -141,7 +141,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 96a61870271..74a5de64536 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,7 +4,8 @@ """EuroSAT dataset.""" import os -from typing import Callable, Dict, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -109,7 +110,7 @@ def __init__( root: str = "data", split: str = "train", bands: Sequence[str] = BAND_SETS["all"], - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -160,7 +161,7 @@ def __init__( is_valid_file=is_in_split, ) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -257,7 +258,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index 9bfdd61c50a..33cc6d01683 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Callable, Optional, cast from xml.etree.ElementTree import Element, parse import matplotlib.patches as patches @@ -19,7 +19,7 @@ from .utils import check_integrity, extract_archive -def parse_pascal_voc(path: str) -> Dict[str, Any]: +def parse_pascal_voc(path: str) -> dict[str, Any]: """Read a PASCAL VOC annotation file. Args: @@ -164,7 +164,7 @@ class FAIR1M(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new FAIR1M dataset instance. @@ -183,7 +183,7 @@ def __init__( glob.glob(os.path.join(self.root, self.labels_root, "*.xml")) ) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -229,8 +229,8 @@ def _load_image(self, path: str) -> Tensor: return tensor def _load_target( - self, points: List[List[Tuple[float, float]]], labels: List[str] - ) -> Tuple[Tensor, Tensor]: + self, points: list[list[tuple[float, float]]], labels: list[str] + ) -> tuple[Tensor, Tensor]: """Load the target mask for a single image. Args: @@ -280,7 +280,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 6c00604c793..1cb6fa567d8 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional from xml.etree import ElementTree import matplotlib.patches as patches @@ -19,7 +19,7 @@ from .utils import check_integrity, download_and_extract_archive, extract_archive -def parse_pascal_voc(path: str) -> Dict[str, Any]: +def parse_pascal_voc(path: str) -> dict[str, Any]: """Read a PASCAL VOC annotation file. Args: @@ -104,7 +104,7 @@ class ForestDamage(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -130,9 +130,9 @@ def __init__( self.files = self._load_files(self.root) - self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -162,7 +162,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> List[Dict[str, str]]: + def _load_files(self, root: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -202,8 +202,8 @@ def _load_image(self, path: str) -> Tensor: return tensor def _load_target( - self, bboxes: List[List[int]], labels_list: List[str] - ) -> Tuple[Tensor, Tensor]: + self, bboxes: list[list[int]], labels_list: list[str] + ) -> tuple[Tensor, Tensor]: """Load the target mask for a single image. Args: @@ -260,7 +260,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index b9e34b764dd..17e6952ab33 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -7,7 +7,7 @@ import os import sys from datetime import datetime, timedelta -from typing import Any, Dict, Tuple +from typing import Any import numpy as np from rasterio.crs import CRS @@ -18,7 +18,7 @@ def _disambiguate_timestamps( year: float, month: float, day: float -) -> Tuple[float, float]: +) -> tuple[float, float]: """Disambiguate partial timestamps. Based on :func:`torchgeo.datasets.utils.disambiguate_timestamps`. @@ -128,7 +128,7 @@ def __init__(self, root: str = "data") -> None: self.index.insert(i, coords) i += 1 - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve metadata indexed by query. Args: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index ef707685bf5..f5ed1d3ec38 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -9,7 +9,8 @@ import os import re import sys -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast +from collections.abc import Sequence +from typing import Any, Callable, Optional, cast import fiona import fiona.transform @@ -32,7 +33,7 @@ from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples -class GeoDataset(Dataset[Dict[str, Any]], abc.ABC): +class GeoDataset(Dataset[dict[str, Any]], abc.ABC): """Abstract base class for datasets containing geospatial information. Geospatial information includes things like: @@ -90,7 +91,7 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC): __add__ = None # type: ignore[assignment] def __init__( - self, transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None + self, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None ) -> None: """Initialize a new Dataset instance. @@ -104,7 +105,7 @@ def __init__( self.index = Index(interleaved=False, properties=Property(dimension=3)) @abc.abstractmethod - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -174,7 +175,7 @@ def __str__(self) -> str: def __getstate__( self, - ) -> Tuple[Dict[str, Any], List[Tuple[Any, Any, Optional[Any]]]]: + ) -> tuple[dict[str, Any], list[tuple[Any, Any, Optional[Any]]]]: """Define how instances are pickled. Returns: @@ -186,9 +187,9 @@ def __getstate__( def __setstate__( self, - state: Tuple[ - Dict[Any, Any], - List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + state: tuple[ + dict[Any, Any], + list[tuple[int, tuple[float, float, float, float, float, float], str]], ], ) -> None: """Define how to unpickle an instance. @@ -287,13 +288,13 @@ class RasterDataset(GeoDataset): separate_files = False #: Names of all available bands in the dataset - all_bands: List[str] = [] + all_bands: list[str] = [] #: Names of RGB bands in the dataset, used for plotting - rgb_bands: List[str] = [] + rgb_bands: list[str] = [] #: Color map for the dataset, used for plotting - cmap: Dict[int, Tuple[int, int, int, int]] = {} + cmap: dict[int, tuple[int, int, int, int]] = {} def __init__( self, @@ -301,7 +302,7 @@ def __init__( crs: Optional[CRS] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -386,7 +387,7 @@ def __init__( self._crs = cast(CRS, crs) self.res = cast(float, res) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -399,7 +400,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[str], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -407,7 +408,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: ) if self.separate_files: - data_list: List[Tensor] = [] + data_list: list[Tensor] = [] filename_regex = re.compile(self.filename_regex, re.VERBOSE) for band in self.bands: band_filepaths = [] @@ -532,7 +533,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: float = 0.0001, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, label_name: Optional[str] = None, ) -> None: """Initialize a new Dataset instance. @@ -588,7 +589,7 @@ def __init__( self._crs = crs - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -652,14 +653,14 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: return sample -class NonGeoDataset(Dataset[Dict[str, Any]], abc.ABC): +class NonGeoDataset(Dataset[dict[str, Any]], abc.ABC): """Abstract base class for datasets lacking geospatial information. This base class is designed for datasets with pre-defined image chips. """ @abc.abstractmethod - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. Args: @@ -702,7 +703,7 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, loader: Optional[Callable[[str], Any]] = pil_loader, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: @@ -730,7 +731,7 @@ def __init__( # Must be set after calling super().__init__() self.transforms = transforms - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -754,7 +755,7 @@ def __len__(self) -> int: """ return len(self.imgs) - def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: + def _load_image(self, index: int) -> tuple[Tensor, Tensor]: """Load a single image and it's class label. Args: @@ -797,9 +798,9 @@ def __init__( dataset1: GeoDataset, dataset2: GeoDataset, collate_fn: Callable[ - [Sequence[Dict[str, Any]]], Dict[str, Any] + [Sequence[dict[str, Any]]], dict[str, Any] ] = concat_samples, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, ) -> None: """Initialize a new Dataset instance. @@ -859,7 +860,7 @@ def _merge_dataset_indices(self) -> None: if i == 0: raise RuntimeError("Datasets have no spatiotemporal intersection") - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image and metadata indexed by query. Args: @@ -924,9 +925,9 @@ def __init__( dataset1: GeoDataset, dataset2: GeoDataset, collate_fn: Callable[ - [Sequence[Dict[str, Any]]], Dict[str, Any] + [Sequence[dict[str, Any]]], dict[str, Any] ] = merge_samples, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, ) -> None: """Initialize a new Dataset instance. @@ -980,7 +981,7 @@ def _merge_dataset_indices(self) -> None: self.index.insert(i, hit.bounds) i += 1 - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image and metadata indexed by query. Args: diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 19602b1c24b..a027bde4770 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -88,7 +88,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -125,7 +125,7 @@ def __init__( self.files = self._load_files(self.root, self.split) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -156,7 +156,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files(self, root: str, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -242,7 +242,7 @@ def _download(self) -> None: ) def plot( - self, sample: Dict[str, Tensor], suptitle: Optional[str] = None + self, sample: dict[str, Tensor], suptitle: Optional[str] = None ) -> plt.Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 4dd6b8aca07..8f7959116e3 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Optional, cast import matplotlib.pyplot as plt import torch @@ -121,7 +121,7 @@ def __init__( crs: Optional[CRS] = None, res: Optional[float] = None, measurement: str = "agb", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -162,7 +162,7 @@ def __init__( super().__init__(root, crs, res, transforms=transforms, cache=cache) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -176,7 +176,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[str], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -227,7 +227,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 3e608f60fe1..043a0cdf5f6 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload +from typing import Any, Callable, Optional, cast, overload import fiona import matplotlib.pyplot as plt @@ -146,7 +146,7 @@ def __init__( root: str = "data", split: str = "train", task: str = "task1", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -193,7 +193,7 @@ def __init__( self.images, self.geometries, self.labels = self._load(root) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -281,7 +281,7 @@ def _load_boxes(self, path: str) -> Tensor: the bounding boxes """ base_path = os.path.basename(path) - geometries = cast(Dict[int, Dict[str, Any]], self.geometries) + geometries = cast(dict[int, dict[str, Any]], self.geometries) # Find object ids and geometries # The train set geometry->image mapping is contained @@ -336,7 +336,7 @@ def _load_target(self, path: str) -> Tensor: def _load( self, root: str - ) -> Tuple[List[str], Optional[Dict[int, Dict[str, Any]]], Any]: + ) -> tuple[list[str], Optional[dict[int, dict[str, Any]]], Any]: """Load files, geometries, and labels. Args: @@ -386,7 +386,7 @@ def _load_labels(self, directory: str) -> Any: df.reset_index() return df - def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]: + def _load_geometries(self, directory: str) -> dict[int, dict[str, Any]]: """Load the shape files containing the geometries. Args: @@ -398,7 +398,7 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]: filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp")) i = 0 - features: Dict[int, Dict[str, Any]] = {} + features: dict[int, dict[str, Any]] = {} for path in filepaths: with fiona.open(path) as src: for feature in src: @@ -413,23 +413,23 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]: @overload def _filter_boxes( - self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor - ) -> Tuple[Tensor, Tensor]: + self, image_size: tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor + ) -> tuple[Tensor, Tensor]: ... @overload def _filter_boxes( - self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: None - ) -> Tuple[Tensor, None]: + self, image_size: tuple[int, int], min_size: int, boxes: Tensor, labels: None + ) -> tuple[Tensor, None]: ... def _filter_boxes( self, - image_size: Tuple[int, int], + image_size: tuple[int, int], min_size: int, boxes: Tensor, labels: Optional[Tensor], - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Optional[Tensor]]: """Clip boxes to image size and filter boxes with sides less than ``min_size``. Args: @@ -492,10 +492,10 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, - hsi_indices: Tuple[int, int, int] = (0, 1, 2), + hsi_indices: tuple[int, int, int] = (0, 1, 2), ) -> plt.Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index a433380576a..ac4fa41bb40 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -6,7 +6,7 @@ import glob import os import sys -from typing import Any, Dict +from typing import Any from rasterio.crs import CRS @@ -98,7 +98,7 @@ def __init__(self, root: str = "data") -> None: self.index.insert(i, coords) i += 1 - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve metadata indexed by query. Args: diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 395d0c909bb..f05aa7e56f6 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -55,7 +55,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, checksum: bool = False, ) -> None: """Initialize a new InriaAerialImageLabeling Dataset instance. @@ -80,7 +80,7 @@ def __init__( self._verify() self.files = self._load_files(root) - def _load_files(self, root: str) -> List[Dict[str, str]]: + def _load_files(self, root: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -142,7 +142,7 @@ def __len__(self) -> int: """ return len(self.files) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -183,7 +183,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> Figure: diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 9c6c55c6698..d5abea73315 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -6,7 +6,8 @@ import glob import os import re -from typing import Any, Callable, Dict, List, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Callable, Optional, cast import matplotlib.pyplot as plt import torch @@ -95,7 +96,7 @@ def __init__( crs: Optional[CRS] = None, res: Optional[float] = None, bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -171,7 +172,7 @@ def _extract(self) -> None: for tarfile in glob.iglob(pathname): extract_archive(tarfile) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -184,14 +185,14 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[str], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( f"query: {query} not found in index with bounds: {self.bounds}" ) - image_list: List[Tensor] = [] + image_list: list[Tensor] = [] filename_regex = re.compile(self.filename_regex, re.VERBOSE) for band in self.all_bands: band_filepaths = [] @@ -239,7 +240,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index a6026a48164..e848c5dcdf4 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -6,7 +6,8 @@ import glob import os import re -from typing import Any, Callable, Dict, List, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Callable, Optional, cast import matplotlib.pyplot as plt import torch @@ -88,7 +89,7 @@ def __init__( crs: Optional[CRS] = None, res: Optional[float] = None, bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -165,7 +166,7 @@ def _extract(self) -> None: for tarfile in glob.iglob(pathname): extract_archive(tarfile) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -178,14 +179,14 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[str], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( f"query: {query} not found in index with bounds: {self.bounds}" ) - image_list: List[Tensor] = [] + image_list: list[Tensor] = [] filename_regex = re.compile(self.filename_regex, re.VERBOSE) for band in self.bands: band_filepaths = [] @@ -227,7 +228,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index aad853bb2f2..9203f71e8b6 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -7,7 +7,7 @@ import hashlib import os from functools import lru_cache -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -22,7 +22,7 @@ from .utils import BoundingBox, download_url, extract_archive, working_dir -class LandCoverAIBase(Dataset[Dict[str, Any]], abc.ABC): +class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): r"""Abstract base class for LandCover.ai Geo and NonGeo datasets. The `LandCover.ai `__ (Land Cover from @@ -125,7 +125,7 @@ def _verify(self) -> None: self._extract() @abc.abstractmethod - def __getitem__(self, query: Any) -> Dict[str, Any]: + def __getitem__(self, query: Any) -> dict[str, Any]: """Retrieve image, mask and metadata indexed by index. Args: @@ -152,7 +152,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -213,7 +213,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -247,7 +247,7 @@ def _verify_data(self) -> bool: masks = glob.glob(mask_query) return len(images) > 0 and len(images) == len(masks) - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -260,7 +260,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - img_filepaths = cast(List[str], [hit.object for hit in hits]) + img_filepaths = cast(list[str], [hit.object for hit in hits]) mask_filepaths = [path.replace("images", "masks") for path in img_filepaths] if not img_filepaths: @@ -302,7 +302,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -330,7 +330,7 @@ def __init__( with open(os.path.join(self.root, split + ".txt")) as f: self.ids = f.readlines() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 618c1282ff9..2816cb9b506 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -4,7 +4,8 @@ """Landsat datasets.""" import abc -from typing import Any, Callable, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Callable, Optional import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -51,7 +52,7 @@ class Landsat(RasterDataset, abc.ABC): @property @abc.abstractmethod - def default_bands(self) -> List[str]: + def default_bands(self) -> list[str]: """Bands to load by default.""" def __init__( @@ -60,7 +61,7 @@ def __init__( crs: Optional[CRS] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -86,7 +87,7 @@ def __init__( def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 8beda317aaa..cf7349339ed 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -55,7 +55,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -92,7 +92,7 @@ def __init__( self.files = self._load_files(self.root, self.directory, self.split) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -124,7 +124,7 @@ def __len__(self) -> int: def _load_files( self, root: str, directory: str, split: str - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -210,7 +210,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index f6e2b83f3d2..22e0a23d57f 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -91,8 +91,8 @@ def __init__( self, root: str = "data", split: str = "train", - scene: List[str] = ["urban", "rural"], - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + scene: list[str] = ["urban", "rural"], + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -145,7 +145,7 @@ def __init__( self.files = self._load_files(self.scene_paths, self.split) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -177,7 +177,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, scene_paths: List[str], split: str) -> List[Dict[str, str]]: + def _load_files(self, scene_paths: list[str], split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -265,7 +265,7 @@ def _download(self) -> None: ) def plot( - self, sample: Dict[str, Tensor], suptitle: Optional[str] = None + self, sample: dict[str, Tensor], suptitle: Optional[str] = None ) -> plt.Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 9310624925c..d3ecb7224ef 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -4,7 +4,7 @@ """Million-AID dataset.""" import glob import os -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -190,7 +190,7 @@ def __init__( root: str = "data", task: str = "multi-class", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new MillionAID dataset instance. @@ -219,7 +219,7 @@ def __init__( self.files = self._load_files(self.root) self.classes = sorted({cls for f in self.files for cls in f["label"]}) - self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} def __len__(self) -> int: """Return the number of data points in the dataset. @@ -229,7 +229,7 @@ def __len__(self) -> int: """ return len(self.files) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -249,7 +249,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: return sample - def _load_files(self, root: str) -> List[Dict[str, Any]]: + def _load_files(self, root: str) -> list[dict[str, Any]]: """Return the paths of the files in the dataset. Args: @@ -333,7 +333,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index 558df96dea4..027ad8aa409 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -3,7 +3,7 @@ """National Agriculture Imagery Program (NAIP) dataset.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import matplotlib.pyplot as plt @@ -49,7 +49,7 @@ class NAIP(RasterDataset): def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 026af81123e..3f124f42dba 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -4,7 +4,7 @@ """NASA Marine Debris dataset.""" import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -60,7 +60,7 @@ class NASAMarineDebris(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -86,7 +86,7 @@ def __init__( self._verify() self.files = self._load_files() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -147,7 +147,7 @@ def _load_target(self, path: str) -> Tensor: tensor = torch.from_numpy(array) return tensor - def _load_files(self) -> List[Dict[str, str]]: + def _load_files(self) -> list[dict[str, str]]: """Load a image and label files. Returns: @@ -221,7 +221,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 94c10b073e6..57e735b59a3 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -7,7 +7,7 @@ import json import os import sys -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Optional, cast import fiona import fiona.transform @@ -206,7 +206,7 @@ def __init__( root: str = "data", crs: Optional[CRS] = None, res: float = 0.0001, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, checksum: bool = False, ) -> None: """Initialize a new Dataset instance. @@ -292,7 +292,7 @@ def __init__( self._crs = crs self._source_crs = source_crs - def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -306,7 +306,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(List[str], [hit.object for hit in hits]) + filepaths = cast(list[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -337,8 +337,8 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: return sample def _filter_geometries( - self, query: BoundingBox, filepaths: List[str] - ) -> List[Dict[str, Any]]: + self, query: BoundingBox, filepaths: list[str] + ) -> list[dict[str, Any]]: """Filters a df read from the polygon csv file based on query and conf thresh. Args: @@ -372,7 +372,7 @@ def _filter_geometries( return shapes - def _wkt_fiona_geom_transform(self, x: str) -> Dict[str, Any]: + def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: """Function to transform a geometry string into new crs. Args: @@ -392,7 +392,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> Dict[str, Any]: geom = fiona.model.Geometry(**x) else: geom = x - transformed: Dict[str, Any] = fiona.transform.transform_geom( + transformed: dict[str, Any] = fiona.transform.transform_geom( self._source_crs.to_dict(), self._crs.to_dict(), geom ) return transformed @@ -431,7 +431,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index b302e8e4d21..6f5b7003d20 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -5,7 +5,8 @@ import glob import os -from typing import Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Callable, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -82,7 +83,7 @@ def __init__( root: str = "data", split: str = "train", bands: str = "all", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -115,7 +116,7 @@ def __init__( self.files = self._load_files() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -145,7 +146,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self) -> List[Dict[str, Union[str, Sequence[str]]]]: + def _load_files(self) -> list[dict[str, Union[str, Sequence[str]]]]: regions = [] labels_root = os.path.join( self.root, @@ -160,7 +161,7 @@ def _load_files(self) -> List[Dict[str, Union[str, Sequence[str]]]]: region = folder.split(os.sep)[-2] mask = os.path.join(labels_root, region, "cm", "cm.png") - def get_image_paths(ind: int) -> List[str]: + def get_image_paths(ind: int) -> list[str]: return sorted( glob.glob( os.path.join(images_root, region, f"imgs_{ind}_rect", "*.tif") @@ -198,7 +199,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor: Returns: the image """ - images: List["np.typing.NDArray[np.int_]"] = [] + images: list["np.typing.NDArray[np.int_]"] = [] for path in paths: with Image.open(path) as img: images.append(np.array(img)) @@ -271,7 +272,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index 1f6c594c2ef..a13e2431b26 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -4,7 +4,7 @@ """PatternNet dataset.""" import os -from typing import Callable, Dict, Optional, cast +from typing import Callable, Optional, cast import matplotlib.pyplot as plt from torch import Tensor @@ -83,7 +83,7 @@ class PatternNet(NonGeoClassificationDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -147,7 +147,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 8ecfb1c6891..5443c807d22 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -4,7 +4,7 @@ """Potsdam dataset.""" import os -from typing import Callable, Dict, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -122,7 +122,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new Potsdam dataset instance. @@ -149,7 +149,7 @@ def __init__( if os.path.exists(image) and os.path.exists(mask): self.files.append(dict(image=image, mask=mask)) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -241,7 +241,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 68cb269b4ef..1e50165e5b8 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Optional import matplotlib.patches as patches import matplotlib.pyplot as plt @@ -62,7 +62,7 @@ class ReforesTree(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -97,9 +97,9 @@ def __init__( self.annot_df = pd.read_csv(os.path.join(root, "mapping", "final_dataset.csv")) - self.class2idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + self.class2idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -129,7 +129,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> List[str]: + def _load_files(self, root: str) -> list[str]: """Return the paths of the files in the dataset. Args: @@ -158,7 +158,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, filepath: str) -> Tuple[Tensor, ...]: + def _load_target(self, filepath: str) -> tuple[Tensor, ...]: """Load boxes and labels for a single image. Args: @@ -220,7 +220,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index a8cc4ca1ef1..2c5e5871b79 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -4,7 +4,7 @@ """RESISC45 dataset.""" import os -from typing import Callable, Dict, Optional, cast +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -159,7 +159,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -243,7 +243,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 45f715b34d7..757f6765448 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -5,7 +5,7 @@ import os import random -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -71,8 +71,8 @@ def __init__( root: str = "data", version: str = "100k", seasons: int = 1, - bands: List[str] = rgb_bands, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + bands: list[str] = rgb_bands, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -111,7 +111,7 @@ def __init__( self._verify() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -229,7 +229,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 2c50ebb2c29..4f49ff67a46 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -4,7 +4,8 @@ """SEN12MS dataset.""" import os -from typing import Callable, Dict, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -63,7 +64,7 @@ class SEN12MS(NonGeoDataset): This download will likely take several hours. """ # noqa: E501 - BAND_SETS: Dict[str, Tuple[str, ...]] = { + BAND_SETS: dict[str, tuple[str, ...]] = { "all": ( "VV", "VH", @@ -166,7 +167,7 @@ def __init__( root: str = "data", split: str = "train", bands: Sequence[str] = BAND_SETS["all"], - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new SEN12MS dataset instance. @@ -212,7 +213,7 @@ def __init__( with open(os.path.join(self.root, split + "_list.txt")) as f: self.ids = [line.rstrip() for line in f.readlines()] - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -230,7 +231,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: image = torch.cat(tensors=[s1, s2], dim=0) image = torch.index_select(image, dim=0, index=self.band_indices) - sample: Dict[str, Tensor] = {"image": image, "mask": lc[0]} + sample: dict[str, Tensor] = {"image": image, "mask": lc[0]} if self.transforms is not None: sample = self.transforms(sample) @@ -313,7 +314,7 @@ def _check_integrity(self) -> bool: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 7bbb1a0b590..2b7a32d285c 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -3,7 +3,8 @@ """Sentinel datasets.""" -from typing import Any, Callable, Dict, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Callable, Optional import matplotlib.pyplot as plt import torch @@ -142,7 +143,7 @@ def __init__( crs: Optional[CRS] = None, res: float = 10, bands: Sequence[str] = ["VV", "VH"], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -186,7 +187,7 @@ def __init__( def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: @@ -295,7 +296,7 @@ def __init__( crs: Optional[CRS] = None, res: float = 10, bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -322,7 +323,7 @@ def __init__( def plot( self, - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 0770228a229..0cd620a78a3 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -4,7 +4,8 @@ """So2Sat dataset.""" import os -from typing import Callable, Dict, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -142,7 +143,7 @@ def __init__( root: str = "data", split: str = "train", bands: Sequence[str] = BAND_SETS["all"], - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new So2Sat dataset instance. @@ -208,7 +209,7 @@ def __init__( with h5py.File(self.fn, "r") as f: self.size: int = f["label"].shape[0] - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -279,7 +280,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index f0e3987dd64..e1f00347dfe 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -9,7 +9,7 @@ import math import os import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional import fiona import matplotlib.pyplot as plt @@ -57,7 +57,7 @@ def dataset_id(self) -> str: @property @abc.abstractmethod - def imagery(self) -> Dict[str, str]: + def imagery(self) -> dict[str, str]: """Mapping of image identifier and filename.""" @property @@ -67,20 +67,20 @@ def label_glob(self) -> str: @property @abc.abstractmethod - def collection_md5_dict(self) -> Dict[str, str]: + def collection_md5_dict(self) -> dict[str, str]: """Mapping of collection id and md5 checksum.""" @property @abc.abstractmethod - def chip_size(self) -> Dict[str, Tuple[int, int]]: + def chip_size(self) -> dict[str, tuple[int, int]]: """Mapping of images and their chip size.""" def __init__( self, root: str, image: str, - collections: List[str] = [], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + collections: list[str] = [], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -126,7 +126,7 @@ def __init__( self.files = self._load_files(root) - def _load_files(self, root: str) -> List[Dict[str, str]]: + def _load_files(self, root: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -146,7 +146,7 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: files.append({"image_path": imgpath, "label_path": lbl_path}) return files - def _load_image(self, path: str) -> Tuple[Tensor, Affine, CRS]: + def _load_image(self, path: str) -> tuple[Tensor, Affine, CRS]: """Load a single image. Args: @@ -162,7 +162,7 @@ def _load_image(self, path: str) -> Tuple[Tensor, Affine, CRS]: return tensor, img.transform, img.crs def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int] + self, path: str, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] ) -> Tensor: """Rasterizes the dataset's labels (in geojson format). @@ -215,7 +215,7 @@ def __len__(self) -> int: """ return len(self.files) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -237,7 +237,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: return sample - def _check_integrity(self) -> List[str]: + def _check_integrity(self) -> list[str]: """Checks the integrity of the dataset structure. Returns: @@ -277,7 +277,7 @@ def _check_integrity(self) -> List[str]: return to_be_downloaded - def _download(self, collections: List[str], api_key: Optional[str] = None) -> None: + def _download(self, collections: list[str], api_key: Optional[str] = None) -> None: """Download the dataset and extract it. Args: @@ -303,7 +303,7 @@ def _download(self, collections: List[str], api_key: Optional[str] = None) -> No def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> Figure: @@ -404,7 +404,7 @@ def __init__( self, root: str = "data", image: str = "rgb", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -519,8 +519,8 @@ def __init__( self, root: str = "data", image: str = "PS-RGB", - collections: List[str] = [], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + collections: list[str] = [], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -640,8 +640,8 @@ def __init__( root: str = "data", image: str = "PS-RGB", speed_mask: Optional[bool] = False, - collections: List[str] = [], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + collections: list[str] = [], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -673,7 +673,7 @@ def __init__( ) def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int] + self, path: str, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] ) -> Tensor: """Rasterizes the dataset's labels (in geojson format). @@ -737,7 +737,7 @@ def _load_mask( def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> Figure: @@ -889,8 +889,8 @@ def __init__( self, root: str = "data", image: str = "PS-RGBNIR", - angles: List[str] = [], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + angles: list[str] = [], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -921,7 +921,7 @@ def __init__( root, image, collections, transforms, download, api_key, checksum ) - def _load_files(self, root: str) -> List[Dict[str, str]]: + def _load_files(self, root: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -1058,8 +1058,8 @@ def __init__( root: str = "data", image: str = "PS-RGB", speed_mask: Optional[bool] = False, - collections: List[str] = [], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + collections: list[str] = [], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -1189,7 +1189,7 @@ def __init__( self, root: str = "data", image: str = "PS-RGB", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, ) -> None: @@ -1290,7 +1290,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, api_key: Optional[str] = None, checksum: bool = False, @@ -1336,7 +1336,7 @@ def __init__( self.files = self._load_files(root) - def _load_files(self, root: str) -> List[Dict[str, str]]: + def _load_files(self, root: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -1363,7 +1363,7 @@ def _load_files(self, root: str) -> List[Dict[str, str]]: files.append({"image_path": img}) return files - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index d85d56b63dd..0164f5438d3 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -3,10 +3,11 @@ """Dataset splitting utilities.""" +from collections.abc import Sequence from copy import deepcopy from itertools import accumulate from math import floor, isclose -from typing import List, Optional, Sequence, Tuple, Union, cast +from typing import Optional, Union, cast from rtree.index import Index, Property from torch import Generator, default_generator, randint, randperm @@ -50,7 +51,7 @@ def random_bbox_assignment( dataset: GeoDataset, lengths: Sequence[float], generator: Optional[Generator] = default_generator, -) -> List[GeoDataset]: +) -> list[GeoDataset]: """Split a GeoDataset randomly assigning its index's BoundingBoxes. This function will go through each BoundingBox in the GeoDataset's index and @@ -104,7 +105,7 @@ def random_bbox_splitting( dataset: GeoDataset, fractions: Sequence[float], generator: Optional[Generator] = default_generator, -) -> List[GeoDataset]: +) -> list[GeoDataset]: """Split a GeoDataset randomly splitting its index's BoundingBoxes. This function will go through each BoundingBox in the GeoDataset's index, @@ -172,7 +173,7 @@ def random_grid_cell_assignment( fractions: Sequence[float], grid_size: int = 6, generator: Optional[Generator] = default_generator, -) -> List[GeoDataset]: +) -> list[GeoDataset]: """Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets. This function will go through each BoundingBox in the GeoDataset's index, overlay @@ -250,7 +251,7 @@ def random_grid_cell_assignment( return new_datasets -def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDataset]: +def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> list[GeoDataset]: """Split a GeoDataset intersecting it with a ROI for each desired new GeoDataset. Args: @@ -288,8 +289,8 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> List[GeoDatas def time_series_split( - dataset: GeoDataset, lengths: Sequence[Union[float, Tuple[float, float]]] -) -> List[GeoDataset]: + dataset: GeoDataset, lengths: Sequence[Union[float, tuple[float, float]]] +) -> list[GeoDataset]: """Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets. Args: @@ -325,7 +326,7 @@ def time_series_split( for offset, length in zip(accumulate(lengths), lengths) ] - lengths = cast(Sequence[Tuple[float, float]], lengths) + lengths = cast(Sequence[tuple[float, float]], lengths) new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index 17a83a8e6d3..72901f86930 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -5,7 +5,7 @@ import os import random -from typing import Callable, Dict, Optional, cast +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -101,7 +101,7 @@ def __init__( root: str = "data", split: str = "s2c", seasons: int = 1, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new SSL4EOS12 instance. @@ -133,7 +133,7 @@ def __init__( self._verify() - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -198,7 +198,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 40449294542..145cda876c5 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -3,7 +3,7 @@ """UC Merced dataset.""" import os -from typing import Callable, Dict, Optional, cast +from typing import Callable, Optional, cast import matplotlib.pyplot as plt import numpy as np @@ -106,7 +106,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -205,7 +205,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, ) -> plt.Figure: diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 980fb639954..d2a166c92ff 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -5,7 +5,8 @@ import glob import os -from typing import Callable, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -98,7 +99,7 @@ def __init__( root: str = "data", split: str = "train", labels: Sequence[str] = ALL_LABELS, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -148,7 +149,7 @@ def __init__( for lab in self.labels } - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -180,7 +181,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self) -> List[str]: + def _load_files(self) -> list[str]: """Loads file names.""" with open(os.path.join(self.root, f"{self.split}_split.txt")) as f: files = f.read().splitlines() @@ -252,7 +253,7 @@ def _extract(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_labels: bool = True, suptitle: Optional[str] = None, ) -> Figure: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index a9dc6627264..55d94abed6f 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -3,6 +3,9 @@ """Common dataset utilities.""" +# https://github.com/sphinx-doc/sphinx/issues/11327 +from __future__ import annotations + import bz2 import collections import contextlib @@ -11,21 +14,10 @@ import os import sys import tarfile +from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, - Union, - cast, - overload, -) +from typing import Any, cast, overload import numpy as np import rasterio @@ -97,7 +89,7 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: pass -def extract_archive(src: str, dst: Optional[str] = None) -> None: +def extract_archive(src: str, dst: str | None = None) -> None: """Extract an archive. Args: @@ -110,7 +102,7 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None: if dst is None: dst = os.path.dirname(src) - suffix_and_extractor: List[Tuple[Union[str, Tuple[str, ...]], Any]] = [ + suffix_and_extractor: list[tuple[str | tuple[str, ...], Any]] = [ (".rar", _rarfile.RarFile), ( (".tar", ".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tbz2", ".tbz", ".txz"), @@ -125,7 +117,7 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None: f.extractall(dst) return - suffix_and_decompressor: List[Tuple[str, Any]] = [ + suffix_and_decompressor: list[tuple[str, Any]] = [ (".bz2", bz2.open), (".gz", gzip.open), (".xz", lzma.open), @@ -144,9 +136,9 @@ def extract_archive(src: str, dst: Optional[str] = None) -> None: def download_and_extract_archive( url: str, download_root: str, - extract_root: Optional[str] = None, - filename: Optional[str] = None, - md5: Optional[str] = None, + extract_root: str | None = None, + filename: str | None = None, + md5: str | None = None, ) -> None: """Download and extract an archive. @@ -171,7 +163,7 @@ def download_and_extract_archive( def download_radiant_mlhub_dataset( - dataset_id: str, download_root: str, api_key: Optional[str] = None + dataset_id: str, download_root: str, api_key: str | None = None ) -> None: """Download a dataset from Radiant Earth. @@ -194,7 +186,7 @@ def download_radiant_mlhub_dataset( def download_radiant_mlhub_collection( - collection_id: str, download_root: str, api_key: Optional[str] = None + collection_id: str, download_root: str, api_key: str | None = None ) -> None: """Download a collection from Radiant Earth. @@ -261,10 +253,10 @@ def __getitem__(self, key: int) -> float: # noqa: D105 pass @overload - def __getitem__(self, key: slice) -> List[float]: # noqa: D105 + def __getitem__(self, key: slice) -> list[float]: # noqa: D105 pass - def __getitem__(self, key: Union[int, slice]) -> Union[float, List[float]]: + def __getitem__(self, key: int | slice) -> float | list[float]: """Index the (minx, maxx, miny, maxy, mint, maxt) tuple. Args: @@ -286,7 +278,7 @@ def __iter__(self) -> Iterator[float]: """ yield from [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt] - def __contains__(self, other: "BoundingBox") -> bool: + def __contains__(self, other: BoundingBox) -> bool: """Whether or not other is within the bounds of this bounding box. Args: @@ -306,7 +298,7 @@ def __contains__(self, other: "BoundingBox") -> bool: and (self.mint <= other.maxt <= self.maxt) ) - def __or__(self, other: "BoundingBox") -> "BoundingBox": + def __or__(self, other: BoundingBox) -> BoundingBox: """The union operator. Args: @@ -326,7 +318,7 @@ def __or__(self, other: "BoundingBox") -> "BoundingBox": max(self.maxt, other.maxt), ) - def __and__(self, other: "BoundingBox") -> "BoundingBox": + def __and__(self, other: BoundingBox) -> BoundingBox: """The intersection operator. Args: @@ -378,7 +370,7 @@ def volume(self) -> float: """ return self.area * (self.maxt - self.mint) - def intersects(self, other: "BoundingBox") -> bool: + def intersects(self, other: BoundingBox) -> bool: """Whether or not two bounding boxes intersect. Args: @@ -398,7 +390,7 @@ def intersects(self, other: "BoundingBox") -> bool: def split( self, proportion: float, horizontal: bool = True - ) -> Tuple["BoundingBox", "BoundingBox"]: + ) -> tuple[BoundingBox, BoundingBox]: """Split BoundingBox in two. Args: @@ -435,7 +427,7 @@ def split( return bbox1, bbox2 -def disambiguate_timestamp(date_str: str, format: str) -> Tuple[float, float]: +def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: """Disambiguate partial timestamps. TorchGeo stores the timestamp of each file in a spatiotemporal R-tree. If the full @@ -510,7 +502,7 @@ def working_dir(dirname: str, create: bool = False) -> Iterator[None]: os.chdir(cwd) -def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List[Any]]: +def _list_dict_to_dict_list(samples: Iterable[dict[Any, Any]]) -> dict[Any, list[Any]]: """Convert a list of dictionaries to a dictionary of lists. Args: @@ -528,7 +520,7 @@ def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List return collated -def _dict_list_to_list_dict(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]: +def _dict_list_to_list_dict(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: """Convert a dictionary of lists to a list of dictionaries. Args: @@ -539,7 +531,7 @@ def _dict_list_to_list_dict(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, .. versionadded:: 0.2 """ - uncollated: List[Dict[Any, Any]] = [ + uncollated: list[dict[Any, Any]] = [ {} for _ in range(max(map(len, sample.values()))) ] for key, values in sample.items(): @@ -548,7 +540,7 @@ def _dict_list_to_list_dict(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, return uncollated -def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: +def stack_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: """Stack a list of samples along a new axis. Useful for forming a mini-batch of samples to pass to @@ -562,14 +554,14 @@ def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: .. versionadded:: 0.2 """ - collated: Dict[Any, Any] = _list_dict_to_dict_list(samples) + collated: dict[Any, Any] = _list_dict_to_dict_list(samples) for key, value in collated.items(): if isinstance(value[0], Tensor): collated[key] = torch.stack(value) return collated -def concat_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: +def concat_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: """Concatenate a list of samples along an existing axis. Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`. @@ -582,7 +574,7 @@ def concat_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: .. versionadded:: 0.2 """ - collated: Dict[Any, Any] = _list_dict_to_dict_list(samples) + collated: dict[Any, Any] = _list_dict_to_dict_list(samples) for key, value in collated.items(): if isinstance(value[0], Tensor): collated[key] = torch.cat(value) @@ -591,7 +583,7 @@ def concat_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: return collated -def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: +def merge_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: """Merge a list of samples. Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`. @@ -604,7 +596,7 @@ def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: .. versionadded:: 0.2 """ - collated: Dict[Any, Any] = {} + collated: dict[Any, Any] = {} for sample in samples: for key, value in sample.items(): if key in collated and isinstance(value, Tensor): @@ -616,7 +608,7 @@ def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: return collated -def unbind_samples(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]: +def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: """Reverse of :func:`stack_samples`. Useful for turning a mini-batch of samples into a list of samples. These individual @@ -636,7 +628,7 @@ def unbind_samples(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]: return _dict_list_to_list_dict(sample) -def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]": +def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: """Load an image file using rasterio. Args: @@ -646,7 +638,7 @@ def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]": the image """ with rasterio.open(path) as f: - array: "np.typing.NDArray[np.int_]" = f.read().astype(np.int32) + array: np.typing.NDArray[np.int_] = f.read().astype(np.int32) # NonGeoClassificationDataset expects images returned with channels last (HWC) array = array.transpose(1, 2, 0) return array @@ -665,8 +657,8 @@ def draw_semantic_segmentation_masks( image: Tensor, mask: Tensor, alpha: float = 0.5, - colors: Optional[Sequence[Union[str, Tuple[int, int, int]]]] = None, -) -> "np.typing.NDArray[np.uint8]": + colors: Sequence[str | tuple[int, int, int]] | None = None, +) -> np.typing.NDArray[np.uint8]: """Overlay a semantic segmentation mask onto an image. Args: @@ -690,8 +682,8 @@ def draw_semantic_segmentation_masks( def rgb_to_mask( - rgb: "np.typing.NDArray[np.uint8]", colors: List[Tuple[int, int, int]] -) -> "np.typing.NDArray[np.uint8]": + rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]] +) -> np.typing.NDArray[np.uint8]: """Converts an RGB colormap mask to a integer mask. Args: @@ -705,7 +697,7 @@ def rgb_to_mask( # we can map is 255 h, w = rgb.shape[:2] - mask: "np.typing.NDArray[np.uint8]" = np.zeros(shape=(h, w), dtype=np.uint8) + mask: np.typing.NDArray[np.uint8] = np.zeros(shape=(h, w), dtype=np.uint8) for i, c in enumerate(colors): cmask = rgb == c # Only update mask if class is present in mask @@ -715,11 +707,11 @@ def rgb_to_mask( def percentile_normalization( - img: "np.typing.NDArray[np.int_]", + img: np.typing.NDArray[np.int_], lower: float = 2, upper: float = 98, - axis: Optional[Union[int, Sequence[int]]] = None, -) -> "np.typing.NDArray[np.int_]": + axis: int | Sequence[int] | None = None, +) -> np.typing.NDArray[np.int_]: """Applies percentile normalization to an input image. Specifically, this will rescale the values in the input such that values <= the @@ -741,7 +733,7 @@ def percentile_normalization( assert lower < upper lower_percentile = np.percentile(img, lower, axis=axis) upper_percentile = np.percentile(img, upper, axis=axis) - img_normalized: "np.typing.NDArray[np.int_]" = np.clip( + img_normalized: np.typing.NDArray[np.int_] = np.clip( (img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1 ) return img_normalized diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 8509607f2f0..78370f31585 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -4,7 +4,7 @@ """Vaihingen dataset.""" import os -from typing import Callable, Dict, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -121,7 +121,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new Vaihingen2D dataset instance. @@ -148,7 +148,7 @@ def __init__( if os.path.exists(image) and os.path.exists(mask): self.files.append(dict(image=image, mask=mask)) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -242,7 +242,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 214cbbe24a6..4dbda994433 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -4,7 +4,7 @@ """NWPU VHR-10 dataset.""" import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -18,7 +18,7 @@ def convert_coco_poly_to_mask( - segmentations: List[int], height: int, width: int + segmentations: list[int], height: int, width: int ) -> Tensor: """Convert coco polygons to mask tensor. @@ -53,7 +53,7 @@ class ConvertCocoAnnotations: https://github.com/pytorch/vision/blob/v0.14.0/references/detection/coco_utils.py """ - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: """Converts MS COCO fields (boxes, masks & labels) from list of ints to tensors. Args: @@ -182,7 +182,7 @@ def __init__( self, root: str = "data", split: str = "positive", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -236,7 +236,7 @@ def __init__( self.coco_convert = ConvertCocoAnnotations() self.ids = list(sorted(self.coco.imgs.keys())) - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. Args: @@ -247,7 +247,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: """ id_ = index % len(self) + 1 - sample: Dict[str, Any] = { + sample: dict[str, Any] = { "image": self._load_image(id_), "label": self._load_target(id_), } @@ -298,7 +298,7 @@ def _load_image(self, id_: int) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, id_: int) -> Dict[str, Any]: + def _load_target(self, id_: int) -> dict[str, Any]: """Load the annotations for a single image. Args: @@ -365,7 +365,7 @@ def _download(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, show_feats: Optional[str] = "both", diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index bf289844a07..3b423247a52 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -66,7 +66,7 @@ def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: """Initialize a new xView2 dataset instance. @@ -89,7 +89,7 @@ def __init__( self.class2idx = {c: i for i, c in enumerate(self.classes)} self.files = self._load_files(root, split) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -121,7 +121,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files(self, root: str, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -221,7 +221,7 @@ def _verify(self) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index a03a29209ee..008f4e34d9d 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -4,7 +4,8 @@ """ZueriCrop dataset.""" import os -from typing import Callable, Dict, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Callable, Optional import matplotlib.pyplot as plt import torch @@ -64,7 +65,7 @@ def __init__( self, root: str = "data", bands: Sequence[str] = band_names, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, download: bool = False, checksum: bool = False, ) -> None: @@ -103,7 +104,7 @@ def __init__( "h5py is not installed and is required to use this dataset" ) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: @@ -154,7 +155,7 @@ def _load_image(self, index: int) -> Tensor: tensor = torch.index_select(tensor, dim=1, index=self.band_indices) return tensor - def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]: + def _load_target(self, index: int) -> tuple[Tensor, Tensor, Tensor]: """Load the target mask for a single image. Args: @@ -262,7 +263,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], time_step: int = 0, show_titles: bool = True, suptitle: Optional[str] = None, diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 95cf5b5c84a..8553819550f 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -10,7 +10,7 @@ * https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py """ # noqa: E501 -from typing import Any, Callable, List, Union +from typing import Any, Callable, Union import torch.nn as nn from torchvision.models._api import WeightsEnum @@ -79,7 +79,7 @@ def get_weight(name: str) -> WeightsEnum: return eval(name) -def list_models() -> List[str]: +def list_models() -> list[str]: """List the registered models. .. versionadded:: 0.4 diff --git a/torchgeo/models/changestar.py b/torchgeo/models/changestar.py index 0e5d55312aa..172c36c9c81 100644 --- a/torchgeo/models/changestar.py +++ b/torchgeo/models/changestar.py @@ -3,8 +3,6 @@ """ChangeStar implementations.""" -from typing import Dict, List - import torch import torch.nn as nn from einops import rearrange @@ -41,7 +39,7 @@ def __init__( scale_factor: number of upsampling factor """ super().__init__() - layers: List[Module] = [ + layers: list[Module] = [ nn.modules.Sequential( nn.modules.Conv2d(in_channels, inner_channels, 3, 1, 1), nn.modules.BatchNorm2d(inner_channels), @@ -64,7 +62,7 @@ def __init__( self.convs = nn.modules.Sequential(*layers) - def forward(self, bi_feature: Tensor) -> List[Tensor]: + def forward(self, bi_feature: Tensor) -> list[Tensor]: """Forward pass of the model. Args: @@ -129,7 +127,7 @@ def __init__( raise ValueError(f"Unknown inference_mode: {inference_mode}") self.inference_mode = inference_mode - def forward(self, x: Tensor) -> Dict[str, Tensor]: + def forward(self, x: Tensor) -> dict[str, Tensor]: """Forward pass of the model. Args: @@ -151,7 +149,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: # change detection c12, c21 = self.changemixin(bi_feature) - results: Dict[str, Tensor] = {} + results: dict[str, Tensor] = {} if not self.training: results.update({"bi_seg_logit": bi_seg_logit}) if self.inference_mode == "t1t2": diff --git a/torchgeo/models/farseg.py b/torchgeo/models/farseg.py index 13f6601f901..a1faae16a5b 100644 --- a/torchgeo/models/farseg.py +++ b/torchgeo/models/farseg.py @@ -5,7 +5,7 @@ import math from collections import OrderedDict -from typing import List, cast +from typing import cast import torch.nn.functional as F import torchvision @@ -118,7 +118,7 @@ class _FSRelation(Module): def __init__( self, scene_embedding_channels: int, - in_channels_list: List[int], + in_channels_list: list[int], out_channels: int, ) -> None: """Initialize the _FSRelation module. @@ -157,7 +157,7 @@ def __init__( self.normalizer = Sigmoid() - def forward(self, scene_feature: Tensor, features: List[Tensor]) -> List[Tensor]: + def forward(self, scene_feature: Tensor, features: list[Tensor]) -> list[Tensor]: """Forward pass of the model.""" # [N, C, H, W] content_feats = [ @@ -184,7 +184,7 @@ def __init__( in_channels: int, out_channels: int, num_classes: int, - in_feature_output_strides: List[int] = [4, 8, 16, 32], + in_feature_output_strides: list[int] = [4, 8, 16, 32], out_feature_output_stride: int = 4, ) -> None: """Initialize the _LightWeightDecoder module. @@ -233,7 +233,7 @@ def __init__( UpsamplingBilinear2d(scale_factor=4), ) - def forward(self, features: List[Tensor]) -> Tensor: + def forward(self, features: list[Tensor]) -> Tensor: """Forward pass of the model.""" inner_feat_list = [] for idx, block in enumerate(self.blocks): diff --git a/torchgeo/models/fcsiam.py b/torchgeo/models/fcsiam.py index c8230b5f8f3..4d95a0eddcf 100644 --- a/torchgeo/models/fcsiam.py +++ b/torchgeo/models/fcsiam.py @@ -3,7 +3,8 @@ """Fully convolutional change detection (FCCD) implementations.""" -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import segmentation_models_pytorch as smp import torch diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 6bc106f167a..ef9bb05f10f 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -4,7 +4,8 @@ """TorchGeo batch samplers.""" import abc -from typing import Iterator, List, Optional, Tuple, Union +from collections.abc import Iterator +from typing import Optional, Union import torch from rtree.index import Index, Property @@ -15,7 +16,7 @@ from .utils import _to_tuple, get_random_bounding_box, tile_to_chips -class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC): +class BatchGeoSampler(Sampler[list[BoundingBox]], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. Unlike PyTorch's :class:`~torch.utils.data.BatchSampler`, :class:`BatchGeoSampler` @@ -46,7 +47,7 @@ def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> No self.roi = roi @abc.abstractmethod - def __iter__(self) -> Iterator[List[BoundingBox]]: + def __iter__(self) -> Iterator[list[BoundingBox]]: """Return a batch of indices of a dataset. Returns: @@ -65,7 +66,7 @@ class RandomBatchGeoSampler(BatchGeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[Tuple[float, float], float], + size: Union[tuple[float, float], float], batch_size: int, length: Optional[int] = None, roi: Optional[BoundingBox] = None, @@ -129,7 +130,7 @@ def __init__( if torch.sum(self.areas) == 0: self.areas += 1 - def __iter__(self) -> Iterator[List[BoundingBox]]: + def __iter__(self) -> Iterator[list[BoundingBox]]: """Return the indices of a dataset. Returns: diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 3c4f529c501..26daa41e2d6 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,7 +4,8 @@ """TorchGeo samplers.""" import abc -from typing import Callable, Iterable, Iterator, Optional, Tuple, Union +from collections.abc import Iterable, Iterator +from typing import Callable, Optional, Union import torch from rtree.index import Index, Property @@ -68,7 +69,7 @@ class RandomGeoSampler(GeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[Tuple[float, float], float], + size: Union[tuple[float, float], float], length: Optional[int], roi: Optional[BoundingBox] = None, units: Units = Units.PIXELS, @@ -176,8 +177,8 @@ class GridGeoSampler(GeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[Tuple[float, float], float], - stride: Union[Tuple[float, float], float], + size: Union[tuple[float, float], float], + stride: Union[tuple[float, float], float], roi: Optional[BoundingBox] = None, units: Units = Units.PIXELS, ) -> None: diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index d36f9a0550d..9c05cc867b4 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -4,7 +4,7 @@ """Common sampler utilities.""" import math -from typing import Optional, Tuple, Union, overload +from typing import Optional, Union, overload import torch @@ -12,16 +12,16 @@ @overload -def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]: +def _to_tuple(value: Union[tuple[int, int], int]) -> tuple[int, int]: ... @overload -def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: +def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]: ... -def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: +def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]: """Convert value to a tuple if it is not already a tuple. Args: @@ -37,7 +37,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float + bounds: BoundingBox, size: Union[tuple[float, float], float], res: float ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -81,9 +81,9 @@ def get_random_bounding_box( def tile_to_chips( bounds: BoundingBox, - size: Tuple[float, float], - stride: Optional[Tuple[float, float]] = None, -) -> Tuple[int, int]: + size: tuple[float, float], + stride: Optional[tuple[float, float]] = None, +) -> tuple[int, int]: r"""Compute number of :term:`chips ` that can be sampled from a :term:`tile`. Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 03b7c76acde..eb59078413e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -4,7 +4,7 @@ """BYOL tasks.""" import os -from typing import Any, Dict, Optional, Tuple, cast +from typing import Any, Optional, cast import timm import torch @@ -47,7 +47,7 @@ class SimCLRAugmentation(nn.Module): https://arxiv.org/pdf/2002.05709.pdf for more details. """ - def __init__(self, image_size: Tuple[int, int] = (256, 256)) -> None: + def __init__(self, image_size: tuple[int, int] = (256, 256)) -> None: """Initialize a module for applying SimCLR augmentations. Args: @@ -217,7 +217,7 @@ class BYOL(nn.Module): def __init__( self, model: nn.Module, - image_size: Tuple[int, int] = (256, 256), + image_size: tuple[int, int] = (256, 256), hidden_layer: int = -2, in_channels: int = 4, projection_size: int = 256, @@ -338,7 +338,7 @@ def __init__(self, **kwargs: Any) -> None: # Creates `self.hparams` from kwargs self.save_hyperparameters() - self.hyperparams = cast(Dict[str, Any], self.hparams) + self.hyperparams = cast(dict[str, Any], self.hparams) self.config_task() @@ -353,7 +353,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: """ return self.model(*args, **kwargs) - def configure_optimizers(self) -> Dict[str, Any]: + def configure_optimizers(self) -> dict[str, Any]: """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 7974c4ed8b7..6d412502143 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -4,7 +4,7 @@ """Classification tasks.""" import os -from typing import Any, Dict, cast +from typing import Any, cast import matplotlib.pyplot as plt import timm @@ -98,7 +98,7 @@ def __init__(self, **kwargs: Any) -> None: # Creates `self.hparams` from kwargs self.save_hyperparameters() - self.hyperparams = cast(Dict[str, Any], self.hparams) + self.hyperparams = cast(dict[str, Any], self.hparams) self.config_task() @@ -247,7 +247,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Tensor: y_hat: Tensor = self(x).softmax(dim=-1) return y_hat - def configure_optimizers(self) -> Dict[str, Any]: + def configure_optimizers(self) -> dict[str, Any]: """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index e61cbfa41b5..4df35d98346 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -4,7 +4,7 @@ """Detection tasks.""" from functools import partial -from typing import Any, Dict, List, cast +from typing import Any, cast import matplotlib.pyplot as plt import torch @@ -166,7 +166,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__() # Creates `self.hparams` from kwargs self.save_hyperparameters() - self.hyperparams = cast(Dict[str, Any], self.hparams) + self.hyperparams = cast(dict[str, Any], self.hparams) self.config_task() @@ -284,7 +284,7 @@ def on_test_epoch_end(self) -> None: self.log_dict(renamed_metrics) self.test_metrics.reset() - def predict_step(self, *args: Any, **kwargs: Any) -> List[Dict[str, Tensor]]: + def predict_step(self, *args: Any, **kwargs: Any) -> list[dict[str, Tensor]]: """Compute and return the predictions. Args: @@ -295,10 +295,10 @@ def predict_step(self, *args: Any, **kwargs: Any) -> List[Dict[str, Tensor]]: """ batch = args[0] x = batch["image"] - y_hat: List[Dict[str, Tensor]] = self(x) + y_hat: list[dict[str, Tensor]] = self(x) return y_hat - def configure_optimizers(self) -> Dict[str, Any]: + def configure_optimizers(self) -> dict[str, Any]: """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 84a4fc20186..45daf0788b7 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -4,7 +4,7 @@ """Regression tasks.""" import os -from typing import Any, Dict, cast +from typing import Any, cast import matplotlib.pyplot as plt import timm @@ -76,7 +76,7 @@ def __init__(self, **kwargs: Any) -> None: # Creates `self.hparams` from kwargs self.save_hyperparameters() - self.hyperparams = cast(Dict[str, Any], self.hparams) + self.hyperparams = cast(dict[str, Any], self.hparams) self.config_task() self.train_metrics = MetricCollection( @@ -200,7 +200,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Tensor: y_hat: Tensor = self(x) return y_hat - def configure_optimizers(self) -> Dict[str, Any]: + def configure_optimizers(self) -> dict[str, Any]: """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3396ff7b5f7..759b4ffe41c 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -4,7 +4,7 @@ """Segmentation tasks.""" import warnings -from typing import Any, Dict, cast +from typing import Any, cast import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -105,7 +105,7 @@ def __init__(self, **kwargs: Any) -> None: # Creates `self.hparams` from kwargs self.save_hyperparameters() - self.hyperparams = cast(Dict[str, Any], self.hparams) + self.hyperparams = cast(dict[str, Any], self.hparams) if not isinstance(kwargs["ignore_index"], (int, type(None))): raise ValueError("ignore_index must be an int or None") @@ -262,7 +262,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Tensor: y_hat: Tensor = self(x).softmax(dim=1) return y_hat - def configure_optimizers(self) -> Dict[str, Any]: + def configure_optimizers(self) -> dict[str, Any]: """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 068b620efc3..530a8e33c66 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -5,7 +5,7 @@ import warnings from collections import OrderedDict -from typing import Optional, Tuple, Union, cast +from typing import Optional, Union, cast import torch import torch.nn as nn @@ -13,7 +13,7 @@ from torch.nn.modules import Conv2d, Module -def extract_backbone(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: +def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]: """Extracts a backbone from a lightning checkpoint file. Args: @@ -54,7 +54,7 @@ def extract_backbone(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: return name, state_dict -def _get_input_layer_name_and_module(model: Module) -> Tuple[str, Module]: +def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]: """Retrieve the input layer name and module from a timm model. Args: @@ -120,8 +120,8 @@ def reinit_initial_conv_layer( layer: Conv2d, new_in_channels: int, keep_rgb_weights: bool, - new_stride: Optional[Union[int, Tuple[int, int]]] = None, - new_padding: Optional[Union[str, Union[int, Tuple[int, int]]]] = None, + new_stride: Optional[Union[int, tuple[int, int]]] = None, + new_padding: Optional[Union[str, Union[int, tuple[int, int]]]] = None, ) -> Conv2d: """Clones a Conv2d layer while optionally retaining some of the original weights. diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index 6899959a4c6..eb72afff4a7 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -8,7 +8,7 @@ - https://github.com/awesome-spectral-indices/awesome-spectral-indices """ -from typing import Dict, Optional +from typing import Optional import torch from kornia.augmentation import IntensityAugmentationBase2D @@ -42,8 +42,8 @@ def __init__(self, index_a: int, index_b: int) -> None: def apply_transform( self, input: Tensor, - params: Dict[str, Tensor], - flags: Dict[str, int], + params: dict[str, Tensor], + flags: dict[str, int], transform: Optional[Tensor] = None, ) -> Tensor: """Apply the transform. @@ -317,8 +317,8 @@ def __init__(self, index_a: int, index_b: int, index_c: int) -> None: def apply_transform( self, input: Tensor, - params: Dict[str, Tensor], - flags: Dict[str, int], + params: dict[str, Tensor], + flags: dict[str, int], transform: Optional[Tensor] = None, ) -> Tensor: """Apply the transform. diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 5a71406ba70..decdba79da8 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -3,7 +3,7 @@ """TorchGeo transforms.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import kornia.augmentation as K import torch @@ -25,7 +25,7 @@ class AugmentationSequential(Module): def __init__( self, *args: Union[K.base._AugmentationBase, K.ImageSequential], - data_keys: List[str], + data_keys: list[str], ) -> None: """Initialize a new augmentation sequential instance. @@ -36,7 +36,7 @@ def __init__( super().__init__() self.data_keys = data_keys - keys: List[Union[str, int, DataKey]] = [] + keys: list[Union[str, int, DataKey]] = [] for key in data_keys: if key == "image": keys.append("input") @@ -47,7 +47,7 @@ def __init__( self.augs = K.AugmentationSequential(*args, data_keys=keys) - def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Perform augmentations and update data dict. Args: @@ -67,11 +67,11 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") inputs = [batch[k] for k in self.data_keys] - outputs_list: Union[Tensor, List[Tensor]] = self.augs(*inputs) + outputs_list: Union[Tensor, list[Tensor]] = self.augs(*inputs) outputs_list = ( outputs_list if isinstance(outputs_list, list) else [outputs_list] ) - outputs: Dict[str, Tensor] = { + outputs: dict[str, Tensor] = { k: v for k, v in zip(self.data_keys, outputs_list) } batch.update(outputs) @@ -90,7 +90,7 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: class _RandomNCrop(K.GeometricAugmentationBase2D): """Take N random crops of a tensor.""" - def __init__(self, size: Tuple[int, int], num: int) -> None: + def __init__(self, size: tuple[int, int], num: int) -> None: """Initialize a new _RandomNCrop instance. Args: @@ -102,7 +102,7 @@ def __init__(self, size: Tuple[int, int], num: int) -> None: self.flags = {"size": size, "num": num} def compute_transformation( - self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any] + self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any] ) -> Tensor: """Compute the transformation. @@ -120,8 +120,8 @@ def compute_transformation( def apply_transform( self, input: Tensor, - params: Dict[str, Tensor], - flags: Dict[str, Any], + params: dict[str, Tensor], + flags: dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: """Apply the transform. @@ -144,7 +144,7 @@ def apply_transform( class _NCropGenerator(K.random_generator.CropGenerator): """Generate N random crops.""" - def __init__(self, size: Union[Tuple[int, int], Tensor], num: int) -> None: + def __init__(self, size: Union[tuple[int, int], Tensor], num: int) -> None: """Initialize a new _NCropGenerator instance. Args: @@ -155,8 +155,8 @@ def __init__(self, size: Union[Tuple[int, int], Tensor], num: int) -> None: self.num = num def forward( - self, batch_shape: Tuple[int, ...], same_on_batch: bool = False - ) -> Dict[str, Tensor]: + self, batch_shape: tuple[int, ...], same_on_batch: bool = False + ) -> dict[str, Tensor]: """Generate the crops. Args: diff --git a/train.py b/train.py index eb3e70d7f26..de2e6f85691 100755 --- a/train.py +++ b/train.py @@ -6,7 +6,7 @@ """torchgeo model training script.""" import os -from typing import Any, Dict, Tuple, Type, cast +from typing import Any, cast import lightning.pytorch as pl from lightning.pytorch import LightningDataModule, LightningModule, Trainer @@ -45,8 +45,8 @@ SemanticSegmentationTask, ) -TASK_TO_MODULES_MAPPING: Dict[ - str, Tuple[Type[LightningModule], Type[LightningDataModule]] +TASK_TO_MODULES_MAPPING: dict[ + str, tuple[type[LightningModule], type[LightningDataModule]] ] = { "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule), @@ -161,9 +161,9 @@ def main(conf: DictConfig) -> None: # Choose task to run based on arguments or configuration ###################################### # Convert the DictConfig into a dictionary so that we can pass as kwargs. - task_args = cast(Dict[str, Any], OmegaConf.to_object(conf.experiment.module)) + task_args = cast(dict[str, Any], OmegaConf.to_object(conf.experiment.module)) datamodule_args = cast( - Dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule) + dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule) ) datamodule: LightningDataModule @@ -202,7 +202,7 @@ def main(conf: DictConfig) -> None: monitor=monitor_metric, min_delta=0.00, patience=18, mode=mode ) - trainer_args = cast(Dict[str, Any], OmegaConf.to_object(conf.trainer)) + trainer_args = cast(dict[str, Any], OmegaConf.to_object(conf.trainer)) trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback] trainer_args["logger"] = [tb_logger, csv_logger]