Skip to content

Commit

Permalink
Merge pull request #144 from csiro-coasts/task/py3.10-type-hints
Browse files Browse the repository at this point in the history
Upgrade to Python 3.10 using pyupgrade
  • Loading branch information
mx-moth authored Jul 8, 2024
2 parents 628f441 + ed2c4ff commit 9965a8b
Show file tree
Hide file tree
Showing 27 changed files with 156 additions and 174 deletions.
3 changes: 1 addition & 2 deletions docs/developing/grass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import enum
from collections.abc import Hashable, Sequence
from functools import cached_property
from typing import Optional

import numpy
import xarray
Expand Down Expand Up @@ -33,7 +32,7 @@ class Grass(DimensionConvention[GrassGridKind, GrassIndex]):
default_grid_kind = GrassGridKind.field

@classmethod
def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]:
def check_dataset(cls, dataset: xarray.Dataset) -> int | None:
# A Grass dataset is recognised by the 'Conventions' global attribute
if dataset.attrs['Conventions'] == 'Grass 1.0':
return Specificity.HIGH
Expand Down
6 changes: 3 additions & 3 deletions docs/roles.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from collections.abc import Iterable
from typing import Callable, cast
from collections.abc import Callable, Iterable
from typing import cast

import yaml
from docutils import nodes, utils
Expand Down Expand Up @@ -81,7 +81,7 @@ class Citation(Directive):

def load_citation_file(self) -> dict:
citation_file = self.options['citation_file']
with open(citation_file, 'r') as f:
with open(citation_file) as f:
return cast(dict, yaml.load(f, yaml.Loader))

def run(self) -> list[nodes.Node]:
Expand Down
2 changes: 1 addition & 1 deletion scripts/min_deps_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse_requirements(
Yield (package name, major version, minor version, patch version)
"""
for line_number, line in enumerate(open(fname, 'r'), start=1):
for line_number, line in enumerate(open(fname), start=1):
if '#' in line:
line = line[:line.index('#')]
line = line.strip()
Expand Down
5 changes: 2 additions & 3 deletions scripts/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import shlex
import subprocess
import sys
from typing import Optional

PROJECT = pathlib.Path(__file__).parent.parent

Expand All @@ -29,7 +28,7 @@


def main(
args: Optional[list[str]] = None,
args: list[str] | None = None,
) -> None:
parser = argparse.ArgumentParser()
add_options(parser)
Expand Down Expand Up @@ -222,7 +221,7 @@ def output(*args: str) -> bytes:

def yn(
prompt: str,
default: Optional[bool] = None,
default: bool | None = None,
) -> bool:
examples = {True: '[Yn]', False: '[yN]', None: '[yn]'}[default]
prompt = f'{prompt.strip()} {examples} '
Expand Down
5 changes: 2 additions & 3 deletions src/emsarray/cli/command.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import argparse
from typing import Optional

from emsarray.cli import utils

Expand All @@ -24,11 +23,11 @@ def name(self) -> str:

#: A short description of what this subcommand does,
#: shown as part of the usage message for the base command.
help: Optional[str] = None
help: str | None = None

#: A longer description of what this subcommand does,
#: shown as part of the usage message for this subcommand.
description: Optional[str] = None
description: str | None = None

def add_parser(self, subparsers: argparse._SubParsersAction) -> None:
parser = subparsers.add_parser(
Expand Down
2 changes: 1 addition & 1 deletion src/emsarray/cli/commands/export_geometry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Callable

import xarray

Expand Down
7 changes: 4 additions & 3 deletions src/emsarray/cli/commands/plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import argparse
import functools
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar
from typing import Any, TypeVar

import emsarray
from emsarray.cli import BaseCommand, CommandException
Expand All @@ -28,7 +29,7 @@ def __init__(
dest: str,
*,
value_type: Callable = str,
default: Optional[dict[str, Any]] = None,
default: dict[str, Any] | None = None,
**kwargs: Any,
):
if default is None:
Expand All @@ -42,7 +43,7 @@ def __call__(
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Any,
option_string: Optional[str] = None,
option_string: str | None = None,
) -> None:
super().__call__
holder = getattr(namespace, self.dest, {})
Expand Down
10 changes: 5 additions & 5 deletions src/emsarray/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import re
import sys
import textwrap
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from functools import wraps
from pathlib import Path
from typing import Callable, Optional, Protocol
from typing import Protocol

from shapely.geometry import box, shape
from shapely.geometry.base import BaseGeometry
Expand All @@ -30,7 +30,7 @@
class MainCallable(Protocol):
def __call__(
self,
argv: Optional[list[str]] = None,
argv: list[str] | None = None,
handle_errors: bool = True,
) -> None:
...
Expand Down Expand Up @@ -112,7 +112,7 @@ def decorator(
) -> MainCallable:
@wraps(fn)
def wrapper(
argv: Optional[list[str]] = None,
argv: list[str] | None = None,
handle_errors: bool = True,
) -> None:
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -172,7 +172,7 @@ def nice_console_errors() -> Iterator:

class DoubleNewlineDescriptionFormatter(argparse.HelpFormatter):
def _fill_text(self, text: str, width: int, indent: str) -> str:
fill_text = super(DoubleNewlineDescriptionFormatter, self)._fill_text
fill_text = super()._fill_text

return '\n\n'.join(
fill_text(paragraph, width, indent)
Expand Down
4 changes: 2 additions & 2 deletions src/emsarray/compat/shapely.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from collections.abc import Iterable
from typing import Generic, TypeVar, Union, cast
from typing import Generic, TypeVar, cast

import numpy
import shapely
Expand Down Expand Up @@ -31,7 +31,7 @@ class SpatialIndex(Generic[T]):

def __init__(
self,
items: Union[numpy.ndarray, Iterable[tuple[BaseGeometry, T]]],
items: numpy.ndarray | Iterable[tuple[BaseGeometry, T]],
):
self.items = numpy.array(items, dtype=self.dtype)

Expand Down
62 changes: 30 additions & 32 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import enum
import logging
import warnings
from collections.abc import Hashable, Iterable, Sequence
from collections.abc import Callable, Hashable, Iterable, Sequence
from functools import cached_property
from typing import (
TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, cast
)
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

import numpy
import xarray
Expand Down Expand Up @@ -166,7 +164,7 @@ def check_validity(cls, dataset: xarray.Dataset) -> None:

@classmethod
@abc.abstractmethod
def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]:
def check_dataset(cls, dataset: xarray.Dataset) -> int | None:
"""
Check if a dataset uses this convention.
Expand Down Expand Up @@ -582,7 +580,7 @@ def wind_index(
self,
linear_index: int,
*,
grid_kind: Optional[GridKind] = None,
grid_kind: GridKind | None = None,
) -> Index:
"""Convert a linear index to a conventnion native index.
Expand Down Expand Up @@ -635,7 +633,7 @@ def wind_index(
def unravel_index(
self,
linear_index: int,
grid_kind: Optional[GridKind] = None,
grid_kind: GridKind | None = None,
) -> Index:
"""An alias for :meth:`Convention.wind_index()`.
Expand Down Expand Up @@ -772,7 +770,7 @@ def ravel(
self,
data_array: xarray.DataArray,
*,
linear_dimension: Optional[Hashable] = None,
linear_dimension: Hashable | None = None,
) -> xarray.DataArray:
"""
Flatten the surface dimensions of a :class:`~xarray.DataArray`,
Expand Down Expand Up @@ -815,9 +813,9 @@ def wind(
self,
data_array: xarray.DataArray,
*,
grid_kind: Optional[GridKind] = None,
axis: Optional[int] = None,
linear_dimension: Optional[Hashable] = None,
grid_kind: GridKind | None = None,
axis: int | None = None,
linear_dimension: Hashable | None = None,
) -> xarray.DataArray:
"""
Wind a flattened :class:`~xarray.DataArray`
Expand Down Expand Up @@ -935,9 +933,9 @@ def data_crs(self) -> 'CRS':
def plot_on_figure(
self,
figure: 'Figure',
scalar: Optional[DataArrayOrName] = None,
vector: Optional[tuple[DataArrayOrName, DataArrayOrName]] = None,
title: Optional[str] = None,
scalar: DataArrayOrName | None = None,
vector: tuple[DataArrayOrName, DataArrayOrName] | None = None,
title: str | None = None,
**kwargs: Any,
) -> None:
"""Plot values for a :class:`~xarray.DataArray`
Expand Down Expand Up @@ -1015,10 +1013,10 @@ def plot(self, *args: Any, **kwargs: Any) -> None:
def animate_on_figure(
self,
figure: 'Figure',
scalar: Optional[DataArrayOrName] = None,
vector: Optional[tuple[DataArrayOrName, DataArrayOrName]] = None,
coordinate: Optional[DataArrayOrName] = None,
title: Optional[Union[str, Callable[[Any], str]]] = None,
scalar: DataArrayOrName | None = None,
vector: tuple[DataArrayOrName, DataArrayOrName] | None = None,
coordinate: DataArrayOrName | None = None,
title: str | Callable[[Any], str] | None = None,
**kwargs: Any,
) -> 'FuncAnimation':
"""
Expand Down Expand Up @@ -1115,7 +1113,7 @@ def animate_on_figure(
@utils.timed_func
def make_poly_collection(
self,
data_array: Optional[DataArrayOrName] = None,
data_array: DataArrayOrName | None = None,
**kwargs: Any,
) -> 'PolyCollection':
"""
Expand Down Expand Up @@ -1192,7 +1190,7 @@ def make_poly_collection(

def make_patch_collection(
self,
data_array: Optional[DataArrayOrName] = None,
data_array: DataArrayOrName | None = None,
**kwargs: Any,
) -> 'PolyCollection':
warnings.warn(
Expand All @@ -1206,8 +1204,8 @@ def make_patch_collection(
def make_quiver(
self,
axes: 'Axes',
u: Optional[DataArrayOrName] = None,
v: Optional[DataArrayOrName] = None,
u: DataArrayOrName | None = None,
v: DataArrayOrName | None = None,
**kwargs: Any,
) -> 'Quiver':
"""
Expand Down Expand Up @@ -1238,7 +1236,7 @@ def make_quiver(
# sometimes preferring to fill them in later,
# so `u` and `v` are optional.
# If they are not provided, we set default quiver values of `numpy.nan`.
values: Union[tuple[numpy.ndarray, numpy.ndarray], tuple[float, float]]
values: tuple[numpy.ndarray, numpy.ndarray] | tuple[float, float]
values = numpy.nan, numpy.nan

if u is not None and v is not None:
Expand Down Expand Up @@ -1331,7 +1329,7 @@ def mask(self) -> numpy.ndarray:
return cast(numpy.ndarray, mask)

@cached_property
def geometry(self) -> Union[Polygon, MultiPolygon]:
def geometry(self) -> Polygon | MultiPolygon:
"""
A :class:`shapely.Polygon` or :class:`shapely.MultiPolygon` that represents
the geometry of the entire dataset.
Expand Down Expand Up @@ -1438,7 +1436,7 @@ def spatial_index(self) -> SpatialIndex[SpatialIndexItem[Index]]:
def get_index_for_point(
self,
point: Point,
) -> Optional[SpatialIndexItem[Index]]:
) -> SpatialIndexItem[Index] | None:
"""
Find the index for a :class:`~shapely.Point` in the dataset.
Expand Down Expand Up @@ -1761,8 +1759,8 @@ def ocean_floor(self) -> xarray.Dataset:
def normalize_depth_variables(
self,
*,
positive_down: Optional[bool] = None,
deep_to_shallow: Optional[bool] = None,
positive_down: bool | None = None,
deep_to_shallow: bool | None = None,
) -> xarray.Dataset:
"""An alias for :func:`emsarray.operations.depth.normalize_depth_variables`"""
return depth.normalize_depth_variables(
Expand Down Expand Up @@ -1895,7 +1893,7 @@ def wind_index(
self,
linear_index: int,
*,
grid_kind: Optional[GridKind] = None,
grid_kind: GridKind | None = None,
) -> Index:
if grid_kind is None:
grid_kind = self.default_grid_kind
Expand All @@ -1907,7 +1905,7 @@ def ravel(
self,
data_array: xarray.DataArray,
*,
linear_dimension: Optional[Hashable] = None,
linear_dimension: Hashable | None = None,
) -> xarray.DataArray:
kind = self.get_grid_kind(data_array)
dimensions = self.grid_dimensions[kind]
Expand All @@ -1919,9 +1917,9 @@ def wind(
self,
data_array: xarray.DataArray,
*,
grid_kind: Optional[GridKind] = None,
axis: Optional[int] = None,
linear_dimension: Optional[Hashable] = None,
grid_kind: GridKind | None = None,
axis: int | None = None,
linear_dimension: Hashable | None = None,
) -> xarray.DataArray:
if axis is not None:
linear_dimension = data_array.dims[axis]
Expand Down
Loading

0 comments on commit 9965a8b

Please sign in to comment.