Skip to content

Commit

Permalink
Add Ellipsis typehints (#7017)
Browse files Browse the repository at this point in the history
* use ellipsis in dot

* add ellipsis to more funcs
  • Loading branch information
headtr1ck authored Sep 11, 2022
1 parent c0011e1 commit b018442
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 34 deletions.
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ ignore =
E501 # line too long - let black worry about that
E731 # do not assign a lambda expression, use a def
W503 # line break before binary operator
exclude=
exclude =
.eggs
doc
builtins =
ellipsis

[isort]
profile = black
Expand Down
25 changes: 15 additions & 10 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
from .types import CombineAttrsOptions, JoinOptions
from .types import CombineAttrsOptions, Ellipsis, JoinOptions

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -1622,7 +1622,11 @@ def cross(
return c


def dot(*arrays, dims=None, **kwargs):
def dot(
*arrays,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
Expand Down Expand Up @@ -1711,10 +1715,7 @@ def dot(*arrays, dims=None, **kwargs):
if len(arrays) == 0:
raise TypeError("At least one array should be given.")

if isinstance(dims, str):
dims = (dims,)

common_dims = set.intersection(*[set(arr.dims) for arr in arrays])
common_dims: set[Hashable] = set.intersection(*(set(arr.dims) for arr in arrays))
all_dims = []
for arr in arrays:
all_dims += [d for d in arr.dims if d not in all_dims]
Expand All @@ -1724,21 +1725,25 @@ def dot(*arrays, dims=None, **kwargs):

if dims is ...:
dims = all_dims
elif isinstance(dims, str):
dims = (dims,)
elif dims is None:
        # find dimensions that occur more than once
dim_counts = Counter()
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dims = tuple(d for d, c in dim_counts.items() if c > 1)

dims = tuple(dims) # make dims a tuple
dot_dims: set[Hashable] = set(dims) # type:ignore[arg-type]

# dimensions to be parallelized
broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims)
broadcast_dims = common_dims - dot_dims
input_core_dims = [
[d for d in arr.dims if d not in broadcast_dims] for arr in arrays
]
output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)]
output_core_dims = [
[d for d in all_dims if d not in dot_dims and d not in broadcast_dims]
]

# construct einsum subscripts, such as '...abc,...ab->...c'
# Note: input_core_dims are always moved to the last position
Expand Down
9 changes: 5 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .types import (
CoarsenBoundaryOptions,
DatetimeUnitOptions,
Ellipsis,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -3769,7 +3770,7 @@ def imag(self: T_DataArray) -> T_DataArray:
def dot(
self: T_DataArray,
other: T_DataArray,
dims: Hashable | Sequence[Hashable] | None = None,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
) -> T_DataArray:
"""Perform dot product of two DataArrays along their shared dims.
Expand All @@ -3779,7 +3780,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims : ..., Hashable or sequence of Hashable, optional
dims : ..., str or Iterable of Hashable, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.
Expand Down Expand Up @@ -4773,7 +4774,7 @@ def idxmax(
# https://github.com/python/mypy/issues/12846 is resolved
def argmin(
self,
dim: Hashable | Sequence[Hashable] | None = None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand Down Expand Up @@ -4878,7 +4879,7 @@ def argmin(
# https://github.com/python/mypy/issues/12846 is resolved
def argmax(
self,
dim: Hashable | Sequence[Hashable] = None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
CombineAttrsOptions,
CompatOptions,
DatetimeUnitOptions,
Ellipsis,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -4255,7 +4256,7 @@ def _get_stack_index(

def _stack_once(
self: T_Dataset,
dims: Sequence[Hashable],
dims: Sequence[Hashable | Ellipsis],
new_dim: Hashable,
index_cls: type[Index],
create_index: bool | None = True,
Expand Down Expand Up @@ -4314,10 +4315,10 @@ def _stack_once(

def stack(
self: T_Dataset,
dimensions: Mapping[Any, Sequence[Hashable]] | None = None,
dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
**dimensions_kwargs: Sequence[Hashable],
**dimensions_kwargs: Sequence[Hashable | Ellipsis],
) -> T_Dataset:
"""
Stack any number of existing dimensions into a single new dimension.
Expand Down
11 changes: 9 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np

if TYPE_CHECKING:
from .common import DataWithCoords

from .common import AbstractArray, DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, GroupBy
Expand All @@ -29,13 +30,19 @@
# from typing_extensions import Self
# except ImportError:
# Self: Any = None
Self: Any = None
Self = TypeVar("Self")

Ellipsis = ellipsis

else:
Self: Any = None
Ellipsis: Any = None


T_Dataset = TypeVar("T_Dataset", bound="Dataset")
T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_Variable = TypeVar("T_Variable", bound="Variable")
T_Array = TypeVar("T_Array", bound="AbstractArray")
T_Index = TypeVar("T_Index", bound="Index")

T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@

if TYPE_CHECKING:
from .types import (
Ellipsis,
ErrorOptionsWithWarn,
PadModeOptions,
PadReflectOptions,
Expand Down Expand Up @@ -1478,7 +1479,7 @@ def roll(self, shifts=None, **shifts_kwargs):

def transpose(
self,
*dims: Hashable,
*dims: Hashable | Ellipsis,
missing_dims: ErrorOptionsWithWarn = "raise",
) -> Variable:
"""Return a new Variable object with transposed dimensions.
Expand Down Expand Up @@ -2555,7 +2556,7 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float):
def _unravel_argminmax(
self,
argminmax: str,
dim: Hashable | Sequence[Hashable] | None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None,
axis: int | None,
keep_attrs: bool | None,
skipna: bool | None,
Expand Down Expand Up @@ -2624,7 +2625,7 @@ def _unravel_argminmax(

def argmin(
self,
dim: Hashable | Sequence[Hashable] = None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
axis: int = None,
keep_attrs: bool = None,
skipna: bool = None,
Expand Down
16 changes: 8 additions & 8 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .computation import apply_ufunc, dot
from .npcompat import ArrayLike
from .pycompat import is_duck_dask_array
from .types import T_Xarray
from .types import Ellipsis, T_Xarray

# Weighted quantile methods are a subset of the numpy supported quantile methods.
QUANTILE_METHODS = Literal[
Expand Down Expand Up @@ -206,7 +206,7 @@ def _check_dim(self, dim: Hashable | Iterable[Hashable] | None):
def _reduce(
da: DataArray,
weights: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | Ellipsis | None = None,
skipna: bool | None = None,
) -> DataArray:
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
Expand All @@ -227,7 +227,7 @@ def _reduce(
return dot(da, weights, dims=dim)

def _sum_of_weights(
self, da: DataArray, dim: Hashable | Iterable[Hashable] | None = None
self, da: DataArray, dim: str | Iterable[Hashable] | None = None
) -> DataArray:
"""Calculate the sum of weights, accounting for missing values"""

Expand All @@ -251,7 +251,7 @@ def _sum_of_weights(
def _sum_of_squares(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
Expand All @@ -263,7 +263,7 @@ def _sum_of_squares(
def _weighted_sum(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
Expand All @@ -273,7 +273,7 @@ def _weighted_sum(
def _weighted_mean(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
Expand All @@ -287,7 +287,7 @@ def _weighted_mean(
def _weighted_var(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
Expand All @@ -301,7 +301,7 @@ def _weighted_var(
def _weighted_std(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,7 @@ def apply_truncate_x_x_valid(obj):


@pytest.mark.parametrize("use_dask", [True, False])
def test_dot(use_dask) -> None:
def test_dot(use_dask: bool) -> None:
if use_dask:
if not has_dask:
pytest.skip("test for dask.")
Expand Down Expand Up @@ -1862,7 +1862,7 @@ def test_dot(use_dask) -> None:


@pytest.mark.parametrize("use_dask", [True, False])
def test_dot_align_coords(use_dask) -> None:
def test_dot_align_coords(use_dask: bool) -> None:
# GH 3694

if use_dask:
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6420,7 +6420,7 @@ def test_deepcopy_obj_array() -> None:
assert x0.values[0] is not x1.values[0]


def test_clip(da) -> None:
def test_clip(da: DataArray) -> None:
with raise_if_dask_computes():
result = da.clip(min=0.5)
assert result.min(...) >= 0.5
Expand Down

0 comments on commit b018442

Please sign in to comment.