diff --git a/pyproject.toml b/pyproject.toml index ccc94cd766c..fb76fc6eeb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,27 +141,42 @@ exclude_lines = [ "@overload", ] +# https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] -python_version = "3.10" +# Import discovery ignore_missing_imports = true -show_error_codes = true -exclude = "(build|data|dist|docs/src|images|logo|logs|output)/" +exclude = "(build|data|dist|docs/.*|images|logo|.*logs|output|requirements)/" -# Strict -warn_unused_configs = true +# Disallow dynamic typing (TODO: work in progress) +disallow_any_unimported = false +disallow_any_expr = false +disallow_any_decorated = false +disallow_any_explicit = false disallow_any_generics = true disallow_subclassing_any = true + +# Untyped definitions and calls disallow_untyped_calls = true disallow_untyped_defs = true disallow_incomplete_defs = true -check_untyped_defs = true disallow_untyped_decorators = true -no_implicit_optional = true + +# Configuring warnings warn_redundant_casts = true warn_unused_ignores = true +warn_no_return = true warn_return_any = true -no_implicit_reexport = true +warn_unreachable = true + +# Miscellaneous strictness flags strict_equality = true +strict = true + +# Configuring error messages +pretty = true + +# Miscellaneous +warn_unused_configs = true [tool.pytest.ini_options] # Skip slow tests by default diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 73861011a7b..ae3550608d9 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -14,7 +14,7 @@ import shutil import subprocess import sys -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, TypeAlias, cast, overload @@ -367,7 +367,9 @@ def working_dir(dirname: Path, create: bool = False) -> Iterator[None]: os.chdir(cwd) -def _list_dict_to_dict_list(samples: Iterable[dict[Any, Any]]) -> dict[Any, list[Any]]: +def _list_dict_to_dict_list( + samples: Iterable[Mapping[Any, Any]], +) -> dict[Any, list[Any]]: """Convert a list of dictionaries to a dictionary of lists. Args: @@ -385,7 +387,9 @@ def _list_dict_to_dict_list(samples: Iterable[dict[Any, Any]]) -> dict[Any, list return collated -def _dict_list_to_list_dict(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: +def _dict_list_to_list_dict( + sample: Mapping[Any, Sequence[Any]], +) -> list[dict[Any, Any]]: """Convert a dictionary of lists to a list of dictionaries. Args: @@ -405,7 +409,7 @@ def _dict_list_to_list_dict(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, return uncollated -def stack_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: +def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: """Stack a list of samples along a new axis. Useful for forming a mini-batch of samples to pass to @@ -426,7 +430,7 @@ def stack_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: return collated -def concat_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: +def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: """Concatenate a list of samples along an existing axis. Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`. @@ -448,7 +452,7 @@ def concat_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: return collated -def merge_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: +def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: """Merge a list of samples. Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`. @@ -473,7 +477,7 @@ def merge_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: return collated -def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: +def unbind_samples(sample: MutableMapping[Any, Any]) -> list[dict[Any, Any]]: """Reverse of :func:`stack_samples`. Useful for turning a mini-batch of samples into a list of samples. These individual diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index b2210eb2518..15a1d7c5960 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -55,7 +55,7 @@ def __init__( self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Any]) -> dict[str, Any]: """Perform augmentations and update data dict. Args: @@ -99,7 +99,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Convert boxes to default [N, 4] if 'boxes' in batch: - batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # type:ignore[assignment] + batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # Torchmetrics does not support masks with a channel dimension if 'mask' in batch and batch['mask'].shape[1] == 1: