Skip to content

Commit

Permalink
Update test_namedarray.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Oct 13, 2023
1 parent 598a871 commit ab26e87
Showing 1 changed file with 34 additions and 35 deletions.
69 changes: 34 additions & 35 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,49 @@

import xarray as xr
from xarray.namedarray.core import NamedArray, from_array
from xarray.namedarray.utils import T_DuckArray, _arrayfunction_or_api
from xarray.namedarray._typing import _arrayfunction_or_api

if TYPE_CHECKING:
from types import ModuleType

from numpy.typing import NDArray

from xarray.namedarray.utils import (
from xarray.namedarray._typing import (
_DimsLike,
_Shape,
_DType_co,
_ScalarType,
_ScalarType_co,
DuckArray,
_ShapeType_co,
duckarray,
)


class CustomArrayBase(xr.core.indexing.NDArrayMixin, Generic[T_DuckArray]):
def __init__(self, array: T_DuckArray) -> None:
self.array: T_DuckArray = array
class CustomArrayBase(xr.core.indexing.NDArrayMixin, Generic[_ShapeType_co, _DType_co]):
def __init__(self, array: duckarray[Any, _DType_co]) -> None:
self.array: duckarray[Any, _DType_co] = array

@property
def dtype(self) -> np.dtype[np.generic]:
def dtype(self) -> _DType_co:
return self.array.dtype

@property
def shape(self) -> _Shape:
return self.array.shape

@property
def real(self) -> Any:
return self.array.real

@property
def imag(self) -> Any:
return self.array.imag

def astype(self, dtype: np.typing.DTypeLike) -> Any:
return self.array.astype(dtype)


class CustomArray(CustomArrayBase[T_DuckArray], Generic[T_DuckArray]):
class CustomArray(
CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
):
def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]:
return np.array(self.array)


class CustomArrayIndexable(
CustomArrayBase[T_DuckArray],
CustomArrayBase[_ShapeType_co, _DType_co],
xr.core.indexing.ExplicitlyIndexed,
Generic[T_DuckArray],
Generic[_ShapeType_co, _DType_co],
):
def __array_namespace__(self) -> ModuleType:
return np
Expand All @@ -79,7 +76,7 @@ def test_from_array(
expected: np.ndarray[Any, Any],
raise_error: bool,
) -> None:
actual: NamedArray[np.ndarray[Any, Any]]
actual: NamedArray[Any, Any]
if raise_error:
with pytest.raises(TypeError, match="already a Named array"):
actual = from_array(dims, data) # type: ignore
Expand Down Expand Up @@ -108,19 +105,19 @@ def test_from_array_with_explicitly_indexed(
random_inputs: np.ndarray[Any, Any]
) -> None:
array = CustomArray(random_inputs)
output: NamedArray[CustomArray[np.ndarray[Any, Any]]]
output: NamedArray[Any, Any]
output = from_array(("x", "y", "z"), array)
assert isinstance(output.data, np.ndarray)

array2 = CustomArrayIndexable(random_inputs)
output2: NamedArray[CustomArrayIndexable[np.ndarray[Any, Any]]]
output2: NamedArray[Any, Any]
output2 = from_array(("x", "y", "z"), array2)
assert isinstance(output2.data, CustomArrayIndexable)


def test_properties() -> None:
data = 0.5 * np.arange(10).reshape(2, 5)
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = NamedArray(["x", "y"], data, {"key": "value"})
assert named_array.dims == ("x", "y")
assert np.array_equal(named_array.data, data)
Expand All @@ -133,7 +130,7 @@ def test_properties() -> None:


def test_attrs() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5))
assert named_array.attrs == {}
named_array.attrs["key"] = "value"
Expand All @@ -143,7 +140,7 @@ def test_attrs() -> None:


def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = NamedArray(["x", "y", "z"], random_inputs)
assert np.array_equal(named_array.data, random_inputs)
with pytest.raises(ValueError):
Expand All @@ -159,7 +156,7 @@ def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
],
)
def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = from_array([], data)
assert named_array.data == data
assert named_array.dims == ()
Expand All @@ -171,7 +168,7 @@ def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:


def test_0d_object() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = from_array([], (10, 12, 12))
expected_data = np.empty((), dtype=object)
expected_data[()] = (10, 12, 12)
Expand All @@ -186,7 +183,7 @@ def test_0d_object() -> None:


def test_0d_datetime() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = from_array([], np.datetime64("2000-01-01"))
assert named_array.dtype == np.dtype("datetime64[D]")

Expand All @@ -208,7 +205,7 @@ def test_0d_datetime() -> None:
def test_0d_timedelta(
timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64]
) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = from_array([], timedelta)
assert named_array.dtype == expected_dtype
assert named_array.data == timedelta
Expand All @@ -225,7 +222,7 @@ def test_0d_timedelta(
],
)
def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[Any, Any]
named_array = NamedArray(dims, np.asarray(np.random.random(data_shape)))
assert named_array.dims == tuple(dims)
if raises:
Expand All @@ -237,20 +234,22 @@ def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) ->


def test_duck_array_class() -> None:
def test_duck_array_typevar(a: T_DuckArray) -> T_DuckArray:
def test_duck_array_typevar(a: DuckArray[_ScalarType]) -> DuckArray[_ScalarType]:
# Mypy checks a is valid:
b: T_DuckArray = a
b: DuckArray[_ScalarType] = a

# Runtime check if valid:
if isinstance(b, _arrayfunction_or_api):
# TODO: cast is a mypy workaround for https://github.com/python/mypy/issues/10817
# pyright doesn't need it.
return cast(T_DuckArray, b)
return cast(DuckArray[_ScalarType], b)
else:
raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi")

numpy_a: NDArray[np.int64] = np.array([2.1, 4], dtype=np.dtype(np.int64))
custom_a: CustomArrayIndexable[NDArray[np.int64]] = CustomArrayIndexable(numpy_a)
custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] = CustomArrayIndexable(
numpy_a
)

test_duck_array_typevar(numpy_a)
test_duck_array_typevar(custom_a)

0 comments on commit ab26e87

Please sign in to comment.