Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ellipsis typehints #7017

Merged
merged 2 commits into from
Sep 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ ignore =
E501 # line too long - let black worry about that
E731 # do not assign a lambda expression, use a def
W503 # line break before binary operator
exclude=
exclude =
.eggs
doc
builtins =
ellipsis

[isort]
profile = black
Expand Down
25 changes: 15 additions & 10 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
from .types import CombineAttrsOptions, JoinOptions
from .types import CombineAttrsOptions, Ellipsis, JoinOptions

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -1622,7 +1622,11 @@ def cross(
return c


def dot(*arrays, dims=None, **kwargs):
def dot(
*arrays,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.

Expand Down Expand Up @@ -1711,10 +1715,7 @@ def dot(*arrays, dims=None, **kwargs):
if len(arrays) == 0:
raise TypeError("At least one array should be given.")

if isinstance(dims, str):
dims = (dims,)

common_dims = set.intersection(*[set(arr.dims) for arr in arrays])
common_dims: set[Hashable] = set.intersection(*(set(arr.dims) for arr in arrays))
all_dims = []
for arr in arrays:
all_dims += [d for d in arr.dims if d not in all_dims]
Expand All @@ -1724,21 +1725,25 @@ def dot(*arrays, dims=None, **kwargs):

if dims is ...:
dims = all_dims
elif isinstance(dims, str):
dims = (dims,)
elif dims is None:
# find dimensions that occur more than once
dim_counts = Counter()
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dims = tuple(d for d, c in dim_counts.items() if c > 1)

dims = tuple(dims) # make dims a tuple
dot_dims: set[Hashable] = set(dims) # type:ignore[arg-type]

# dimensions to be parallelized
broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims)
broadcast_dims = common_dims - dot_dims
input_core_dims = [
[d for d in arr.dims if d not in broadcast_dims] for arr in arrays
]
output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)]
output_core_dims = [
[d for d in all_dims if d not in dot_dims and d not in broadcast_dims]
]

# construct einsum subscripts, such as '...abc,...ab->...c'
# Note: input_core_dims are always moved to the last position
Expand Down
9 changes: 5 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .types import (
CoarsenBoundaryOptions,
DatetimeUnitOptions,
Ellipsis,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -3769,7 +3770,7 @@ def imag(self: T_DataArray) -> T_DataArray:
def dot(
self: T_DataArray,
other: T_DataArray,
dims: Hashable | Sequence[Hashable] | None = None,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent! Great that we can roll out the `str | Iterable[Hashable]` annotation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as soon as you try to type some internals, you will run into problems with `Hashable | Sequence[Hashable]`.
This really seems to be the only way.

We should probably add a parser for this in utils so that all methods handle it the same way.
Maybe I can come up with a PR for that.

) -> T_DataArray:
"""Perform dot product of two DataArrays along their shared dims.

Expand All @@ -3779,7 +3780,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims : ..., Hashable or sequence of Hashable, optional
dims : ..., str or Iterable of Hashable, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.

Expand Down Expand Up @@ -4773,7 +4774,7 @@ def idxmax(
# https://github.com/python/mypy/issues/12846 is resolved
def argmin(
self,
dim: Hashable | Sequence[Hashable] | None = None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand Down Expand Up @@ -4878,7 +4879,7 @@ def argmin(
# https://github.com/python/mypy/issues/12846 is resolved
def argmax(
self,
dim: Hashable | Sequence[Hashable] = None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
CombineAttrsOptions,
CompatOptions,
DatetimeUnitOptions,
Ellipsis,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -4255,7 +4256,7 @@ def _get_stack_index(

def _stack_once(
self: T_Dataset,
dims: Sequence[Hashable],
dims: Sequence[Hashable | Ellipsis],
new_dim: Hashable,
index_cls: type[Index],
create_index: bool | None = True,
Expand Down Expand Up @@ -4314,10 +4315,10 @@ def _stack_once(

def stack(
self: T_Dataset,
dimensions: Mapping[Any, Sequence[Hashable]] | None = None,
dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
**dimensions_kwargs: Sequence[Hashable],
**dimensions_kwargs: Sequence[Hashable | Ellipsis],
) -> T_Dataset:
"""
Stack any number of existing dimensions into a single new dimension.
Expand Down
11 changes: 9 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np

if TYPE_CHECKING:
from .common import DataWithCoords

from .common import AbstractArray, DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, GroupBy
Expand All @@ -29,13 +30,19 @@
# from typing_extensions import Self
# except ImportError:
# Self: Any = None
Self: Any = None
Self = TypeVar("Self")

Ellipsis = ellipsis

else:
Self: Any = None
Ellipsis: Any = None


T_Dataset = TypeVar("T_Dataset", bound="Dataset")
T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_Variable = TypeVar("T_Variable", bound="Variable")
T_Array = TypeVar("T_Array", bound="AbstractArray")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, that was not supposed to be checked in, but I guess it doesn't hurt.

T_Index = TypeVar("T_Index", bound="Index")

T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@

if TYPE_CHECKING:
from .types import (
Ellipsis,
ErrorOptionsWithWarn,
PadModeOptions,
PadReflectOptions,
Expand Down Expand Up @@ -1478,7 +1479,7 @@ def roll(self, shifts=None, **shifts_kwargs):

def transpose(
self,
*dims: Hashable,
*dims: Hashable | Ellipsis,
missing_dims: ErrorOptionsWithWarn = "raise",
) -> Variable:
"""Return a new Variable object with transposed dimensions.
Expand Down Expand Up @@ -2555,7 +2556,7 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float):
def _unravel_argminmax(
self,
argminmax: str,
dim: Hashable | Sequence[Hashable] | None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None,
axis: int | None,
keep_attrs: bool | None,
skipna: bool | None,
Expand Down Expand Up @@ -2624,7 +2625,7 @@ def _unravel_argminmax(

def argmin(
self,
dim: Hashable | Sequence[Hashable] = None,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
axis: int = None,
keep_attrs: bool = None,
skipna: bool = None,
Expand Down
16 changes: 8 additions & 8 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .computation import apply_ufunc, dot
from .npcompat import ArrayLike
from .pycompat import is_duck_dask_array
from .types import T_Xarray
from .types import Ellipsis, T_Xarray

# Weighted quantile methods are a subset of the numpy supported quantile methods.
QUANTILE_METHODS = Literal[
Expand Down Expand Up @@ -206,7 +206,7 @@ def _check_dim(self, dim: Hashable | Iterable[Hashable] | None):
def _reduce(
da: DataArray,
weights: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | Ellipsis | None = None,
skipna: bool | None = None,
) -> DataArray:
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
Expand All @@ -227,7 +227,7 @@ def _reduce(
return dot(da, weights, dims=dim)

def _sum_of_weights(
self, da: DataArray, dim: Hashable | Iterable[Hashable] | None = None
self, da: DataArray, dim: str | Iterable[Hashable] | None = None
) -> DataArray:
"""Calculate the sum of weights, accounting for missing values"""

Expand All @@ -251,7 +251,7 @@ def _sum_of_weights(
def _sum_of_squares(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
Expand All @@ -263,7 +263,7 @@ def _sum_of_squares(
def _weighted_sum(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
Expand All @@ -273,7 +273,7 @@ def _weighted_sum(
def _weighted_mean(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
Expand All @@ -287,7 +287,7 @@ def _weighted_mean(
def _weighted_var(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
Expand All @@ -301,7 +301,7 @@ def _weighted_var(
def _weighted_std(
self,
da: DataArray,
dim: Hashable | Iterable[Hashable] | None = None,
dim: str | Iterable[Hashable] | None = None,
skipna: bool | None = None,
) -> DataArray:
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,7 @@ def apply_truncate_x_x_valid(obj):


@pytest.mark.parametrize("use_dask", [True, False])
def test_dot(use_dask) -> None:
def test_dot(use_dask: bool) -> None:
if use_dask:
if not has_dask:
pytest.skip("test for dask.")
Expand Down Expand Up @@ -1862,7 +1862,7 @@ def test_dot(use_dask) -> None:


@pytest.mark.parametrize("use_dask", [True, False])
def test_dot_align_coords(use_dask) -> None:
def test_dot_align_coords(use_dask: bool) -> None:
# GH 3694

if use_dask:
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6420,7 +6420,7 @@ def test_deepcopy_obj_array() -> None:
assert x0.values[0] is not x1.values[0]


def test_clip(da) -> None:
def test_clip(da: DataArray) -> None:
with raise_if_dask_computes():
result = da.clip(min=0.5)
assert result.min(...) >= 0.5
Expand Down