Skip to content

Commit

Permalink
first try at generic dimension type
Browse files (browse the repository at this point in the history)
  • Loading branch information
headtr1ck committed Oct 5, 2023
1 parent e09609c commit 66ba4e8
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 24 deletions.
73 changes: 57 additions & 16 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import math
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload

import numpy as np

Expand Down Expand Up @@ -39,8 +39,6 @@
PostPersistCallable: Any # type: ignore[no-redef]

# T_NamedArray = TypeVar("T_NamedArray", bound="NamedArray[T_DuckArray]")
DimsInput = Union[str, Iterable[Hashable]]
Dims = tuple[Hashable, ...]
AttrsInput = Union[Mapping[Any, Any], None]


Expand Down Expand Up @@ -75,7 +73,10 @@ def as_compatible_data(
return cast(T_DuckArray, np.asarray(data))


class NamedArray(Generic[T_DuckArray]):
T_Dim = TypeVar("T_Dim", bound=Hashable)


class NamedArray(Generic[T_Dim, T_DuckArray]):

"""A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array.
Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names,
Expand All @@ -84,20 +85,60 @@ class NamedArray(Generic[T_DuckArray]):
__slots__ = ("_data", "_dims", "_attrs")

_data: T_DuckArray
_dims: Dims
_dims: tuple[T_Dim, ...]
_attrs: dict[Any, Any] | None

@overload
def __init__(
self: NamedArray[str, T_DuckArray],
dims: str,
data: T_DuckArray,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

@overload
def __init__(
self: NamedArray[str, np.ndarray[Any, np.dtype[np.generic]]],
dims: str,
data: np.typing.ArrayLike,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

@overload
def __init__(
self: NamedArray[T_Dim, T_DuckArray],
dims: Iterable[T_Dim],
data: T_DuckArray,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

@overload
def __init__(
self: NamedArray[T_Dim, np.ndarray[Any, np.dtype[np.generic]]],
dims: Iterable[T_Dim],
data: np.typing.ArrayLike,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

def __init__(
self,
dims: DimsInput,
dims: str | Iterable[T_Dim],
data: T_DuckArray | np.typing.ArrayLike,
attrs: AttrsInput = None,
fastpath: bool = False,
):
) -> None:
"""
Parameters
----------
dims : str or iterable of str
dims : str or iterable of hashable
Name(s) of the dimension(s).
data : T_DuckArray or np.typing.ArrayLike
The actual data that populates the array. Should match the shape specified by `dims`.
Expand Down Expand Up @@ -194,22 +235,22 @@ def nbytes(self) -> int:
return self.size * self.dtype.itemsize

@property
def dims(self) -> Dims:
def dims(self) -> tuple[T_Dim, ...]:
"""Tuple of dimension names with which this NamedArray is associated."""
return self._dims

@dims.setter
def dims(self, value: DimsInput) -> None:
def dims(self, value: str | Iterable[T_Dim]) -> None:
self._dims = self._parse_dimensions(value)

def _parse_dimensions(self, dims: DimsInput) -> Dims:
dims = (dims,) if isinstance(dims, str) else tuple(dims)
if len(dims) != self.ndim:
def _parse_dimensions(self, dims: str | Iterable[T_Dim]) -> tuple[T_Dim, ...]:
pdims = (dims,) if isinstance(dims, str) else tuple(dims)
if len(pdims) != self.ndim:
raise ValueError(
f"dimensions {dims} must have the same length as the "
f"dimensions {pdims} must have the same length as the "
f"number of data dimensions, ndim={self.ndim}"
)
return dims
return pdims # type: ignore[return-value]

@property
def attrs(self) -> dict[Any, Any]:
Expand Down Expand Up @@ -397,7 +438,7 @@ def sizes(self) -> dict[Hashable, int]:

def _replace(
self,
dims: DimsInput | Default = _default,
dims: str | Iterable[T_Dim] | Default = _default,
data: T_DuckArray | np.typing.ArrayLike | Default = _default,
attrs: AttrsInput | Default = _default,
) -> Self:
Expand Down
16 changes: 8 additions & 8 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class CustomArrayIndexable(CustomArrayBase, xr.core.indexing.ExplicitlyIndexed):

def test_properties() -> None:
data = 0.5 * np.arange(10).reshape(2, 5)
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(["x", "y"], data, {"key": "value"})
assert named_array.dims == ("x", "y")
assert np.array_equal(named_array.data, data)
Expand All @@ -104,7 +104,7 @@ def test_properties() -> None:


def test_attrs() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5))
assert named_array.attrs == {}
named_array.attrs["key"] = "value"
Expand All @@ -114,7 +114,7 @@ def test_attrs() -> None:


def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(["x", "y", "z"], random_inputs)
assert np.array_equal(named_array.data, random_inputs)
with pytest.raises(ValueError):
Expand All @@ -130,7 +130,7 @@ def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
],
)
def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray([], data)
assert named_array.data == data
assert named_array.dims == ()
Expand All @@ -142,7 +142,7 @@ def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:


def test_0d_object() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray([], (10, 12, 12))
expected_data = np.empty((), dtype=object)
expected_data[()] = (10, 12, 12)
Expand All @@ -157,7 +157,7 @@ def test_0d_object() -> None:


def test_0d_datetime() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray([], np.datetime64("2000-01-01"))
assert named_array.dtype == np.dtype("datetime64[D]")

Expand All @@ -179,7 +179,7 @@ def test_0d_datetime() -> None:
def test_0d_timedelta(
timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64]
) -> None:
named_array: NamedArray[np.ndarray[Any, np.dtype[np.timedelta64]]]
named_array: NamedArray[str, np.ndarray[Any, np.dtype[np.timedelta64]]]
named_array = NamedArray([], timedelta)
assert named_array.dtype == expected_dtype
assert named_array.data == timedelta
Expand All @@ -196,7 +196,7 @@ def test_0d_timedelta(
],
)
def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(dims, np.random.random(data_shape))
assert named_array.dims == tuple(dims)
if raises:
Expand Down

0 comments on commit 66ba4e8

Please sign in to comment.