Skip to content

Commit

Permalink
Add T_DuckArray type hint to Variable.data (#8203)
Browse files Browse the repository at this point in the history
* Add T_DuckArray

* Add type to variable.data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* Update variable.py

* Update variable.py

* Update variable.py

* Update variable.py

* Update variable.py

* chunk renaming

* Update parallelcompat.py

* fix attrs?

* Update alignment.py

* Update test_parallelcompat.py

* Update test_variable.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_variable.py

* Update test_variable.py

* Update test_variable.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe <[email protected]>
  • Loading branch information
3 people authored Sep 19, 2023
1 parent 3d59258 commit 2b444af
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 46 deletions.
6 changes: 3 additions & 3 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray


def reindex_variables(
Expand Down Expand Up @@ -173,7 +173,7 @@ def __init__(

def _normalize_indexes(
self,
indexes: Mapping[Any, Any],
indexes: Mapping[Any, Any | T_DuckArray],
) -> tuple[NormalizedIndexes, NormalizedIndexVars]:
"""Normalize the indexes/indexers used for re-indexing or alignment.
Expand All @@ -194,7 +194,7 @@ def _normalize_indexes(
f"Indexer has dimensions {idx.dims} that are different "
f"from that to be indexed along '{k}'"
)
data = as_compatible_data(idx)
data: T_DuckArray = as_compatible_data(idx)
pd_idx = safe_cast_to_index(data)
pd_idx.name = k
if isinstance(pd_idx, pd.MultiIndex):
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7481,7 +7481,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset:
else:
variables[k] = f(v, *args, **kwargs)
if keep_attrs:
variables[k].attrs = v._attrs
variables[k]._attrs = v._attrs
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(variables, attrs=attrs)

Expand Down
4 changes: 2 additions & 2 deletions xarray/core/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
T_ChunkedArray = TypeVar("T_ChunkedArray")

if TYPE_CHECKING:
from xarray.core.types import T_Chunks, T_NormalizedChunks
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks


@functools.lru_cache(maxsize=1)
Expand Down Expand Up @@ -257,7 +257,7 @@ def normalize_chunks(

@abstractmethod
def from_array(
self, data: np.ndarray, chunks: T_Chunks, **kwargs
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
) -> T_ChunkedArray:
"""
Create a chunked array from a non-chunked numpy-like array.
Expand Down
4 changes: 4 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def copy(
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
T_Alignable = TypeVar("T_Alignable", bound="Alignable")

# Temporary placeholder for indicating an array API compliant type.
# Hopefully we can narrow this down further in the future:
T_DuckArray = TypeVar("T_DuckArray", bound=Any)

ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
import pandas as pd

if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -253,7 +253,7 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]:
return isinstance(value, (list, tuple))


def is_duck_array(value: Any) -> bool:
def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]:
if isinstance(value, np.ndarray):
return True
return (
Expand Down
76 changes: 45 additions & 31 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -66,6 +66,7 @@
PadModeOptions,
PadReflectOptions,
QuantileMethods,
T_DuckArray,
T_Variable,
)

Expand All @@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError):
# TODO: move this to an xarray.exceptions module?


def as_variable(obj, name=None) -> Variable | IndexVariable:
def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable:
"""Convert an object into a Variable.
Parameters
Expand Down Expand Up @@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
elif isinstance(obj, (set, dict)):
raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}")
elif name is not None:
data = as_compatible_data(obj)
data: T_DuckArray = as_compatible_data(obj)
if data.ndim != 1:
raise MissingDimensionsError(
f"cannot set variable {name!r} with {data.ndim!r}-dimensional data "
Expand Down Expand Up @@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data):
return data


def as_compatible_data(data, fastpath: bool = False):
def as_compatible_data(
data: T_DuckArray | ArrayLike, fastpath: bool = False
) -> T_DuckArray:
"""Prepare and wrap data to put in a Variable.
- If data does not have the necessary attributes, convert it to ndarray.
Expand All @@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False):
"""
if fastpath and getattr(data, "ndim", 0) > 0:
# can't use fastpath (yet) for scalars
return _maybe_wrap_data(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

from xarray.core.dataarray import DataArray

Expand All @@ -252,7 +255,7 @@ def as_compatible_data(data, fastpath: bool = False):

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
data = _possibly_convert_datetime_or_timedelta_index(data)
return _maybe_wrap_data(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

if isinstance(data, tuple):
data = utils.to_0d_object_array(data)
Expand All @@ -279,7 +282,7 @@ def as_compatible_data(data, fastpath: bool = False):
if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return data
return cast("T_DuckArray", data)

# validate whether the data is valid data types.
data = np.asarray(data)
Expand Down Expand Up @@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):

__slots__ = ("_dims", "_data", "_attrs", "_encoding")

def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
def __init__(
self,
dims,
data: T_DuckArray | ArrayLike,
attrs=None,
encoding=None,
fastpath=False,
):
"""
Parameters
----------
Expand All @@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
Well-behaved code to serialize a Variable should ignore
unrecognized encoding items.
"""
self._data = as_compatible_data(data, fastpath=fastpath)
self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath)
self._dims = self._parse_dimensions(dims)
self._attrs = None
self._attrs: dict[Any, Any] | None = None
self._encoding = None
if attrs is not None:
self.attrs = attrs
Expand Down Expand Up @@ -410,7 +420,7 @@ def _in_memory(self):
)

@property
def data(self) -> Any:
def data(self: T_Variable):
"""
The Variable's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.
Expand All @@ -429,12 +439,12 @@ def data(self) -> Any:
return self.values

@data.setter
def data(self, data):
def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None:
data = as_compatible_data(data)
if data.shape != self.shape:
if data.shape != self.shape: # type: ignore[attr-defined]
raise ValueError(
f"replacement data must match the Variable's shape. "
f"replacement data has shape {data.shape}; Variable has shape {self.shape}"
f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined]
)
self._data = data

Expand Down Expand Up @@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable:
return self._replace(encoding={})

def copy(
self: T_Variable, deep: bool = True, data: ArrayLike | None = None
self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None
) -> T_Variable:
"""Returns a copy of this object.
Expand Down Expand Up @@ -1058,24 +1068,26 @@ def copy(
def _copy(
self: T_Variable,
deep: bool = True,
data: ArrayLike | None = None,
data: T_DuckArray | ArrayLike | None = None,
memo: dict[int, Any] | None = None,
) -> T_Variable:
if data is None:
ndata = self._data
data_old = self._data

if isinstance(ndata, indexing.MemoryCachedArray):
if isinstance(data_old, indexing.MemoryCachedArray):
# don't share caching between copies
ndata = indexing.MemoryCachedArray(ndata.array)
ndata = indexing.MemoryCachedArray(data_old.array)
else:
ndata = data_old

if deep:
ndata = copy.deepcopy(ndata, memo)

else:
ndata = as_compatible_data(data)
if self.shape != ndata.shape:
if self.shape != ndata.shape: # type: ignore[attr-defined]
raise ValueError(
f"Data shape {ndata.shape} must match shape of object {self.shape}"
f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
)

attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
Expand Down Expand Up @@ -1248,11 +1260,11 @@ def chunk(
inline_array=inline_array,
)

data = self._data
if chunkmanager.is_chunked_array(data):
data = chunkmanager.rechunk(data, chunks) # type: ignore[arg-type]
data_old = self._data
if chunkmanager.is_chunked_array(data_old):
data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
else:
if isinstance(data, indexing.ExplicitlyIndexed):
if isinstance(data_old, indexing.ExplicitlyIndexed):
# Unambiguously handle array storage backends (like NetCDF4 and h5py)
# that can't handle general array indexing. For example, in netCDF4 you
# can do "outer" indexing along two dimensions independent, which works
Expand All @@ -1261,20 +1273,22 @@ def chunk(
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
# different indexing types in an explicit way:
# https://github.com/dask/dask/issues/2883
data = indexing.ImplicitToExplicitIndexingAdapter(
data, indexing.OuterIndexer
ndata = indexing.ImplicitToExplicitIndexingAdapter(
data_old, indexing.OuterIndexer
)
else:
ndata = data_old

if utils.is_dict_like(chunks):
chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape))
chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))

data = chunkmanager.from_array(
data,
data_chunked = chunkmanager.from_array(
ndata,
chunks, # type: ignore[arg-type]
**_from_array_kwargs,
)

return self._replace(data=data)
return self._replace(data=data_chunked)

def to_numpy(self) -> np.ndarray:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
guess_chunkmanager,
list_chunkmanagers,
)
from xarray.core.types import T_Chunks, T_NormalizedChunks
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
from xarray.tests import has_dask, requires_dask


Expand Down Expand Up @@ -76,7 +76,7 @@ def normalize_chunks(
return normalize_chunks(chunks, shape, limit, dtype, previous_chunks)

def from_array(
self, data: np.ndarray, chunks: T_Chunks, **kwargs
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
) -> DummyChunkedArray:
from dask import array as da

Expand Down
12 changes: 7 additions & 5 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from copy import copy, deepcopy
from datetime import datetime, timedelta
from textwrap import dedent
from typing import Generic

import numpy as np
import pandas as pd
Expand All @@ -26,6 +27,7 @@
VectorizedIndexer,
)
from xarray.core.pycompat import array_type
from xarray.core.types import T_DuckArray
from xarray.core.utils import NDArrayMixin
from xarray.core.variable import as_compatible_data, as_variable
from xarray.tests import (
Expand Down Expand Up @@ -2529,7 +2531,7 @@ def test_to_index_variable_copy(self) -> None:
assert a.dims == ("x",)


class TestAsCompatibleData:
class TestAsCompatibleData(Generic[T_DuckArray]):
def test_unchanged_types(self):
types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray)
for t in types:
Expand Down Expand Up @@ -2610,17 +2612,17 @@ def test_tz_datetime(self) -> None:
times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
actual = as_compatible_data(times_s)
actual: T_DuckArray = as_compatible_data(times_s)
assert actual.array == times_s
assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz)

series = pd.Series(times_s)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
actual = as_compatible_data(series)
actual2: T_DuckArray = as_compatible_data(series)

np.testing.assert_array_equal(actual, series.values)
assert actual.dtype == np.dtype("datetime64[ns]")
np.testing.assert_array_equal(actual2, series.values)
assert actual2.dtype == np.dtype("datetime64[ns]")

def test_full_like(self) -> None:
# For more thorough tests, see test_variable.py
Expand Down

0 comments on commit 2b444af

Please sign in to comment.