Skip to content
forked from pydata/xarray

Commit

Permalink
Fix typing errors using mypy 1.2 (pydata#7752)
Browse files Browse the repository at this point in the history
* test newest mypy

* Update ci-additional.yaml

* remove ignores

* add typing

* Use ClassVar

* Generalize data_vars typing concat.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use a normal method to retrieve a type of Variable

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ignore plotfunc error

* force reinstall

* remove outdated comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Apr 15, 2023
1 parent b889208 commit 6b7515c
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 28 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
python xarray/util/print_versions.py
- name: Install mypy
run: |
python -m pip install 'mypy<0.990'
python -m pip install mypy --force-reinstall
- name: Run mypy
run: |
Expand Down Expand Up @@ -173,7 +173,7 @@ jobs:
python xarray/util/print_versions.py
- name: Install mypy
run: |
python -m pip install 'mypy<0.990'
python -m pip install mypy --force-reinstall
- name: Run mypy
run: |
Expand Down
5 changes: 2 additions & 3 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,8 @@ def _nested_combine(
return combined


# Define type for arbitrarily-nested list of lists recursively
# Currently mypy cannot handle this but other linters can (https://stackoverflow.com/a/53845083/3154101)
DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] # type: ignore[misc]
# Define type for arbitrarily-nested list of lists recursively:
DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]]


def combine_nested(
Expand Down
16 changes: 9 additions & 7 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, cast, overload
from typing import TYPE_CHECKING, Any, Union, cast, overload

import pandas as pd

Expand All @@ -27,12 +27,14 @@
JoinOptions,
)

T_DataVars = Union[ConcatOptions, Iterable[Hashable]]


@overload
def concat(
objs: Iterable[T_Dataset],
dim: Hashable | T_DataArray | pd.Index,
data_vars: ConcatOptions | list[Hashable] = "all",
data_vars: T_DataVars = "all",
coords: ConcatOptions | list[Hashable] = "different",
compat: CompatOptions = "equals",
positions: Iterable[Iterable[int]] | None = None,
Expand All @@ -47,7 +49,7 @@ def concat(
def concat(
objs: Iterable[T_DataArray],
dim: Hashable | T_DataArray | pd.Index,
data_vars: ConcatOptions | list[Hashable] = "all",
data_vars: T_DataVars = "all",
coords: ConcatOptions | list[Hashable] = "different",
compat: CompatOptions = "equals",
positions: Iterable[Iterable[int]] | None = None,
Expand All @@ -61,7 +63,7 @@ def concat(
def concat(
objs,
dim,
data_vars="all",
data_vars: T_DataVars = "all",
coords="different",
compat: CompatOptions = "equals",
positions=None,
Expand Down Expand Up @@ -291,7 +293,7 @@ def _calc_concat_dim_index(
return dim, index


def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat):
def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, compat):
"""
Determine which dataset variables need to be concatenated in the result,
"""
Expand Down Expand Up @@ -445,7 +447,7 @@ def _parse_datasets(
def _dataset_concat(
datasets: list[T_Dataset],
dim: str | T_DataArray | pd.Index,
data_vars: str | list[str],
data_vars: T_DataVars,
coords: str | list[str],
compat: CompatOptions,
positions: Iterable[Iterable[int]] | None,
Expand Down Expand Up @@ -665,7 +667,7 @@ def get_indexes(name):
def _dataarray_concat(
arrays: Iterable[T_DataArray],
dim: str | T_DataArray | pd.Index,
data_vars: str | list[str],
data_vars: T_DataVars,
coords: str | list[str],
compat: CompatOptions,
positions: Iterable[Iterable[int]] | None,
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def _construct(
window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim}

window_dims = self._mapping_to_list(
window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506
window_dim, allow_default=False, allow_allsame=False
)
strides = self._mapping_to_list(stride, default=1)

Expand Down Expand Up @@ -753,7 +753,7 @@ def construct(
window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim}

window_dims = self._mapping_to_list(
window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506
window_dim, allow_default=False, allow_allsame=False
)
strides = self._mapping_to_list(stride, default=1)

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index):
dtype = np.dtype("O")
elif hasattr(array, "categories"):
# category isn't a real numpy dtype
dtype = array.categories.dtype # type: ignore[union-attr]
dtype = array.categories.dtype
elif not is_valid_numpy_dtype(array.dtype):
dtype = np.dtype("O")
else:
Expand Down
3 changes: 1 addition & 2 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,7 @@ def test_concat_data_vars_typing(self) -> None:
actual = concat(objs, dim="x", data_vars="minimal")
assert_identical(data, actual)

def test_concat_data_vars(self):
# TODO: annotating this func fails
def test_concat_data_vars(self) -> None:
data = Dataset({"foo": ("x", np.random.randn(10))})
objs: list[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))]
for data_vars in ["minimal", "different", "all", [], ["foo"]]:
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,7 +2042,7 @@ def test_seaborn_palette_as_cmap(self) -> None:
def test_convenient_facetgrid(self) -> None:
a = easy_array((10, 15, 4))
d = DataArray(a, dims=["y", "x", "z"])
g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2)
g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015

assert_array_equal(g.axs.shape, [2, 2])
for (y, x), ax in np.ndenumerate(g.axs):
Expand All @@ -2051,7 +2051,7 @@ def test_convenient_facetgrid(self) -> None:
assert "x" == ax.get_xlabel()

# Inferring labels
g = self.plotfunc(d, col="z", col_wrap=2)
g = self.plotfunc(d, col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015
assert_array_equal(g.axs.shape, [2, 2])
for (y, x), ax in np.ndenumerate(g.axs):
assert ax.has_data()
Expand Down
20 changes: 11 additions & 9 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from copy import copy, deepcopy
from datetime import datetime, timedelta
from textwrap import dedent
Expand Down Expand Up @@ -61,8 +62,10 @@ def var():
return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5))


class VariableSubclassobjects:
cls: staticmethod[Variable]
class VariableSubclassobjects(ABC):
@abstractmethod
def cls(self, *args, **kwargs) -> Variable:
raise NotImplementedError

def test_properties(self):
data = 0.5 * np.arange(10)
Expand Down Expand Up @@ -1056,7 +1059,8 @@ def test_rolling_window_errors(self, dim, window, window_dim, center):


class TestVariable(VariableSubclassobjects):
cls = staticmethod(Variable)
def cls(self, *args, **kwargs) -> Variable:
return Variable(*args, **kwargs)

@pytest.fixture(autouse=True)
def setup(self):
Expand Down Expand Up @@ -2228,13 +2232,10 @@ def test_coarsen_keep_attrs(self, operation="mean"):
assert new.attrs == _attrs


def _init_dask_variable(*args, **kwargs):
return Variable(*args, **kwargs).chunk()


@requires_dask
class TestVariableWithDask(VariableSubclassobjects):
cls = staticmethod(_init_dask_variable)
def cls(self, *args, **kwargs) -> Variable:
return Variable(*args, **kwargs).chunk()

def test_chunk(self):
unblocked = Variable(["dim_0", "dim_1"], np.ones((3, 4)))
Expand Down Expand Up @@ -2346,7 +2347,8 @@ def test_as_sparse(self):


class TestIndexVariable(VariableSubclassobjects):
cls = staticmethod(IndexVariable)
def cls(self, *args, **kwargs) -> IndexVariable:
return IndexVariable(*args, **kwargs)

def test_init(self):
with pytest.raises(ValueError, match=r"must be 1-dimensional"):
Expand Down

0 comments on commit 6b7515c

Please sign in to comment.