From 162e34aecae7797314d82b11f10a7c188f8cf023 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 12 Jun 2024 15:43:40 -0600 Subject: [PATCH] Fix typing --- xarray/core/dataset.py | 21 ++++++++++----------- xarray/core/types.py | 3 ++- xarray/core/variable.py | 5 +++-- xarray/namedarray/core.py | 6 +++--- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 282ca671129..99206a9cc29 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -159,7 +159,6 @@ QueryParserOptions, ReindexMethodOptions, SideOptions, - T_NormalizedChunks, T_Xarray, ) from xarray.core.weighted import DatasetWeighted @@ -281,17 +280,17 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): def _maybe_chunk( - name, - var, - chunks, + name: Hashable, + var: Variable, + chunks: Mapping[Any, T_ChunkDim] | None, token=None, lock=None, - name_prefix="xarray-", - overwrite_encoded_chunks=False, - inline_array=False, + name_prefix: str = "xarray-", + overwrite_encoded_chunks: bool = False, + inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, -): +) -> Variable: from xarray.namedarray.daskmanager import DaskManager if chunks is not None: @@ -2730,7 +2729,7 @@ def chunk( f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}" ) - def _resolve_frequency(name: Hashable, freq: str) -> tuple[int]: + def _resolve_frequency(name: Hashable, freq: str) -> tuple[int, ...]: variable = self._variables.get(name, None) if variable is None: raise ValueError( @@ -2742,7 +2741,7 @@ def _resolve_frequency(name: Hashable, freq: str) -> tuple[int]: f"Received variable {name!r} with dtype {variable.dtype!r} instead." ) - chunks = tuple( + chunks: tuple[int, ...] = tuple( DataArray( np.ones(variable.shape, dtype=int), dims=(name,), @@ -2756,7 +2755,7 @@ def _resolve_frequency(name: Hashable, freq: str) -> tuple[int]: ) return chunks - chunks_mapping_ints: T_NormalizedChunks = { + chunks_mapping_ints: Mapping[Any, T_ChunkDim] = { name: ( _resolve_frequency(name, chunks) if isinstance(chunks, str) and chunks != "auto" diff --git a/xarray/core/types.py b/xarray/core/types.py index 33024cb4797..259f961e5f5 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -183,7 +183,8 @@ def copy( # FYI in some cases we don't allow `None`, which this doesn't take account of. T_FreqStr: TypeAlias = str -T_ChunkDim: TypeAlias = Union[T_FreqStr, int, Literal["auto"], None, tuple[int, ...]] +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +T_ChunkDimFreq: TypeAlias = Union[T_FreqStr, T_ChunkDim] # We allow the tuple form of this (though arguably we could transition to named dims only) T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] T_NormalizedChunks = tuple[tuple[int, ...], ...] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f0685882595..7e1ed3956fe 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,7 +8,7 @@ from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast +from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast import numpy as np import pandas as pd @@ -63,6 +63,7 @@ PadReflectOptions, QuantileMethods, Self, + T_Chunks, T_DuckArray, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -2522,7 +2523,7 @@ def _to_dense(self) -> Variable: def chunk( # type: ignore[override] self, - chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + chunks: T_Chunks = {}, name: str | None = None, lock: bool | None = None, inline_array: bool | None = None, diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 960ab9d4d1d..b24e48194ad 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray - from xarray.core.types import Dims + from xarray.core.types import Dims, T_Chunks from xarray.namedarray._typing import ( Default, _AttrsLike, @@ -748,7 +748,7 @@ def sizes(self) -> dict[_Dim, _IntOrUnknown]: def chunk( self, - chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + chunks: T_Chunks = {}, chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None, from_array_kwargs: Any = None, **chunks_kwargs: Any, @@ -834,7 +834,7 @@ def chunk( ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] if is_dict_like(chunks): - chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) # type: ignore[assignment] + chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) data_chunked = chunkmanager.from_array(ndata, chunks, **from_array_kwargs) # type: ignore[arg-type]