From 6b7515c20efb9a280370eb755d2410958e893a96 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 15 Apr 2023 20:31:57 +0200 Subject: [PATCH] Fix typing errors using mypy 1.2 (#7752) * test newest mypy * Update ci-additional.yaml * remove ignores * add typing * Use ClassVar * Generalize data_vars typing concat. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use a normal method to retrieve a type of Variable * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ignore plotfunc error * force reinstall * remove outdated comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 4 ++-- xarray/core/combine.py | 5 ++--- xarray/core/concat.py | 16 +++++++++------- xarray/core/rolling.py | 4 ++-- xarray/core/utils.py | 2 +- xarray/tests/test_concat.py | 3 +-- xarray/tests/test_plot.py | 4 ++-- xarray/tests/test_variable.py | 20 +++++++++++--------- 8 files changed, 30 insertions(+), 28 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 6f069af5da6..d9956570991 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -119,7 +119,7 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install 'mypy<0.990' + python -m pip install mypy --force-reinstall - name: Run mypy run: | @@ -173,7 +173,7 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install 'mypy<0.990' + python -m pip install mypy --force-reinstall - name: Run mypy run: | diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 946f71e5d28..8106c295f5a 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -369,9 +369,8 @@ def _nested_combine( return combined -# Define type for arbitrarily-nested list of lists recursively -# Currently mypy cannot handle this but other linters can (https://stackoverflow.com/a/53845083/3154101) -DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] # type: ignore[misc] +# Define type for arbitrarily-nested list of lists recursively: +DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] def combine_nested( diff --git a/xarray/core/concat.py b/xarray/core/concat.py index f092911948f..dcf2a23d311 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, Union, cast, overload import pandas as pd @@ -27,12 +27,14 @@ JoinOptions, ) + T_DataVars = Union[ConcatOptions, Iterable[Hashable]] + @overload def concat( objs: Iterable[T_Dataset], dim: Hashable | T_DataArray | pd.Index, - data_vars: ConcatOptions | list[Hashable] = "all", + data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, @@ -47,7 +49,7 @@ def concat( def concat( objs: Iterable[T_DataArray], dim: Hashable | T_DataArray | pd.Index, - data_vars: ConcatOptions | list[Hashable] = "all", + data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, @@ -61,7 +63,7 @@ def concat( def concat( objs, dim, - data_vars="all", + data_vars: T_DataVars = "all", coords="different", compat: CompatOptions = "equals", positions=None, @@ -291,7 +293,7 @@ def _calc_concat_dim_index( return dim, index -def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat): +def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, compat): """ Determine which dataset variables need to be concatenated in the result, """ @@ -445,7 +447,7 @@ def _parse_datasets( def _dataset_concat( datasets: list[T_Dataset], dim: str | T_DataArray | pd.Index, - data_vars: str | list[str], + data_vars: T_DataVars, coords: str | list[str], compat: CompatOptions, positions: Iterable[Iterable[int]] | None, @@ -665,7 +667,7 @@ def get_indexes(name): def _dataarray_concat( arrays: Iterable[T_DataArray], dim: str | T_DataArray | pd.Index, - data_vars: str | list[str], + data_vars: T_DataVars, coords: str | list[str], compat: CompatOptions, positions: Iterable[Iterable[int]] | None, diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 8b9f31bfdfd..7eb4e9c7687 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -376,7 +376,7 @@ def _construct( window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} window_dims = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506 + window_dim, allow_default=False, allow_allsame=False ) strides = self._mapping_to_list(stride, default=1) @@ -753,7 +753,7 @@ def construct( window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} window_dims = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506 + window_dim, allow_default=False, allow_allsame=False ) strides = self._mapping_to_list(stride, default=1) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 08625fe7d95..1c90a2410f2 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -113,7 +113,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index): dtype = np.dtype("O") elif hasattr(array, "categories"): # category isn't a real numpy dtype - dtype = array.categories.dtype # type: ignore[union-attr] + dtype = array.categories.dtype elif not is_valid_numpy_dtype(array.dtype): dtype = np.dtype("O") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 9021ce2522b..030f653e031 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -538,8 +538,7 @@ def test_concat_data_vars_typing(self) -> None: actual = concat(objs, dim="x", data_vars="minimal") assert_identical(data, actual) - def test_concat_data_vars(self): - # TODO: annotating this func fails + def test_concat_data_vars(self) -> None: data = Dataset({"foo": ("x", np.random.randn(10))}) objs: list[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] for data_vars in ["minimal", "different", "all", [], ["foo"]]: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b587b890ef0..18ca49670ba 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2042,7 +2042,7 @@ def test_seaborn_palette_as_cmap(self) -> None: def test_convenient_facetgrid(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) - g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015 assert_array_equal(g.axs.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axs): @@ -2051,7 +2051,7 @@ def test_convenient_facetgrid(self) -> None: assert "x" == ax.get_xlabel() # Inferring labels - g = self.plotfunc(d, col="z", col_wrap=2) + g = self.plotfunc(d, col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015 assert_array_equal(g.axs.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axs): assert ax.has_data() diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index f9fa79dd8c0..b92db16e34b 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from abc import ABC, abstractmethod from copy import copy, deepcopy from datetime import datetime, timedelta from textwrap import dedent @@ -61,8 +62,10 @@ def var(): return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5)) -class VariableSubclassobjects: - cls: staticmethod[Variable] +class VariableSubclassobjects(ABC): + @abstractmethod + def cls(self, *args, **kwargs) -> Variable: + raise NotImplementedError def test_properties(self): data = 0.5 * np.arange(10) @@ -1056,7 +1059,8 @@ def test_rolling_window_errors(self, dim, window, window_dim, center): class TestVariable(VariableSubclassobjects): - cls = staticmethod(Variable) + def cls(self, *args, **kwargs) -> Variable: + return Variable(*args, **kwargs) @pytest.fixture(autouse=True) def setup(self): @@ -2228,13 +2232,10 @@ def test_coarsen_keep_attrs(self, operation="mean"): assert new.attrs == _attrs -def _init_dask_variable(*args, **kwargs): - return Variable(*args, **kwargs).chunk() - - @requires_dask class TestVariableWithDask(VariableSubclassobjects): - cls = staticmethod(_init_dask_variable) + def cls(self, *args, **kwargs) -> Variable: + return Variable(*args, **kwargs).chunk() def test_chunk(self): unblocked = Variable(["dim_0", "dim_1"], np.ones((3, 4))) @@ -2346,7 +2347,8 @@ def test_as_sparse(self): class TestIndexVariable(VariableSubclassobjects): - cls = staticmethod(IndexVariable) + def cls(self, *args, **kwargs) -> IndexVariable: + return IndexVariable(*args, **kwargs) def test_init(self): with pytest.raises(ValueError, match=r"must be 1-dimensional"):