Use shape and dtype as typevars in NamedArray (#8294)
* Add from_array function

* Update core.py

* some fixes

* Update test_namedarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* fixes

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* more

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_namedarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* fixes

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_namedarray.py

* Update test_namedarray.py

* Update test_namedarray.py

* Update test_namedarray.py

* Update test_namedarray.py

* Update test_namedarray.py

* Update test_namedarray.py

* move to NDArray instead

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename and align more with numpy typing

* Add duck type testing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docstring

* Update test_namedarray.py

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* fixes

* final

* Follow numpy's example more with typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update utils.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Create _array_api.py

* Create _typing.py

* Update core.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _typing.py

* Update core.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Will this make pre-commit happy?

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _array_api.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more

* Update core.py

* fixes

* Update test_namedarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* Use Self because Variable subclasses

* fixes

* Update test_namedarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* Update core.py

* Update core.py

* Update variable.py

* Update variable.py

* fix array api, add docstrings

* Fix typing so that a different array gets correct typing

* add _new with correct typing in variable

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* shape usually stays the same when copying

* Update variable.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_namedarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_namedarray.py

* same shape when astyping

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Delete test_namedarray_sketching.py

* typos

* remove any typing for now

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update indexing.py

* add namespace to some explicitindexing stuff

* Update variable.py

* Update duck_array_ops.py

* Update duck_array_ops.py

* fixes

* Update variable.py

* Fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_variable.py

* Revert "Update test_variable.py"

This reverts commit 6572abe.

* Update _array_api.py

* Update _array_api.py

* Update _array_api.py

* as_compatible_data loses the typing

* Update indexing.py

* Update core.py

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update variable.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update variable.py

* Update variable.py

* Update indexing.py

* Update xarray/core/variable.py

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* Update xarray/core/variable.py

Co-authored-by: Michael Niklas  <[email protected]>

* Apply suggestions from code review

Co-authored-by: Michael Niklas  <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michael Niklas <[email protected]>
3 people authored Oct 18, 2023
1 parent c25c825 commit 087fe45
Showing 7 changed files with 1,114 additions and 337 deletions.
36 changes: 30 additions & 6 deletions xarray/core/variable.py
@@ -369,6 +369,25 @@ def __init__(
         if encoding is not None:
             self.encoding = encoding

+    def _new(
+        self,
+        dims=_default,
+        data=_default,
+        attrs=_default,
+    ):
+        dims_ = copy.copy(self._dims) if dims is _default else dims
+
+        if attrs is _default:
+            attrs_ = None if self._attrs is None else self._attrs.copy()
+        else:
+            attrs_ = attrs
+
+        if data is _default:
+            return type(self)(dims_, copy.copy(self._data), attrs_)
+        else:
+            cls_ = type(self)
+            return cls_(dims_, data, attrs_)
+
     @property
     def _in_memory(self):
         return isinstance(
@@ -905,16 +924,17 @@ def _copy(
                 ndata = data_old
             else:
                 # don't share caching between copies
-                ndata = indexing.MemoryCachedArray(data_old.array)
+                # TODO: MemoryCachedArray doesn't match the array api:
+                ndata = indexing.MemoryCachedArray(data_old.array)  # type: ignore[assignment]

             if deep:
                 ndata = copy.deepcopy(ndata, memo)

         else:
             ndata = as_compatible_data(data)
-            if self.shape != ndata.shape:
+            if self.shape != ndata.shape:  # type: ignore[attr-defined]
                 raise ValueError(
-                    f"Data shape {ndata.shape} must match shape of object {self.shape}"
+                    f"Data shape {ndata.shape} must match shape of object {self.shape}"  # type: ignore[attr-defined]
                 )

         attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
@@ -1054,7 +1074,8 @@ def chunk(
                 # Using OuterIndexer is a pragmatic choice: dask does not yet handle
                 # different indexing types in an explicit way:
                 # https://github.com/dask/dask/issues/2883
-                ndata = indexing.ImplicitToExplicitIndexingAdapter(
+                # TODO: ImplicitToExplicitIndexingAdapter doesn't match the array api:
+                ndata = indexing.ImplicitToExplicitIndexingAdapter(  # type: ignore[assignment]
                     data_old, indexing.OuterIndexer
                 )

@@ -2608,6 +2629,9 @@ class IndexVariable(Variable):

     __slots__ = ()

+    # TODO: PandasIndexingAdapter doesn't match the array api:
+    _data: PandasIndexingAdapter  # type: ignore[assignment]
+
     def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         super().__init__(dims, data, attrs, encoding, fastpath)
         if self.ndim != 1:
@@ -2756,9 +2780,9 @@ def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None):

         else:
             ndata = as_compatible_data(data)
-            if self.shape != ndata.shape:
+            if self.shape != ndata.shape:  # type: ignore[attr-defined]
                 raise ValueError(
-                    f"Data shape {ndata.shape} must match shape of object {self.shape}"
+                    f"Data shape {ndata.shape} must match shape of object {self.shape}"  # type: ignore[attr-defined]
                 )

         attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
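
Reviewer note: the `_new` method added to `Variable` in the first hunk relies on a sentinel default so that "argument omitted" can be told apart from an explicit `None`. The following is only a minimal sketch of that pattern under stated assumptions: `Tagged` and `SubTagged` are hypothetical toy classes, and the `_default` object here is a local stand-in, since xarray's own sentinel is not shown in this diff.

```python
import copy

# Hypothetical sentinel standing in for the `_default` object used in the diff.
_default = object()


class Tagged:
    """Toy stand-in for Variable: holds dims, data and optional attrs."""

    def __init__(self, dims, data, attrs=None):
        self._dims = tuple(dims)
        self._data = data
        self._attrs = None if attrs is None else dict(attrs)

    def _new(self, dims=_default, data=_default, attrs=_default):
        # An omitted argument falls back to a shallow copy of the current value.
        dims_ = copy.copy(self._dims) if dims is _default else dims
        if attrs is _default:
            attrs_ = None if self._attrs is None else self._attrs.copy()
        else:
            attrs_ = attrs
        data_ = copy.copy(self._data) if data is _default else data
        # type(self) means a subclass instance produces a subclass instance.
        return type(self)(dims_, data_, attrs_)


class SubTagged(Tagged):
    pass


base = SubTagged(("x",), [1, 2, 3], {"units": "m"})
cleared = base._new(attrs=None)    # explicit None: attrs are dropped
renamed = base._new(dims=("y",))   # attrs omitted: copied from base
```

With a plain `attrs=None` default, the `cleared` call could not be distinguished from a call that leaves `attrs` out, which is presumably why the diff compares against `_default` with `is`.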
122 changes: 122 additions & 0 deletions xarray/namedarray/_array_api.py
@@ -0,0 +1,122 @@
+from types import ModuleType
+from typing import Any
+
+import numpy as np
+
+from xarray.namedarray._typing import (
+    _arrayapi,
+    _DType,
+    _ScalarType,
+    _ShapeType,
+    _SupportsImag,
+    _SupportsReal,
+)
+from xarray.namedarray.core import NamedArray
+
+
+def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
+    if isinstance(x._data, _arrayapi):
+        return x._data.__array_namespace__()
+    else:
+        return np
+
+
+def astype(
+    x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
+) -> NamedArray[_ShapeType, _DType]:
+    """
+    Copies an array to a specified data type irrespective of Type Promotion Rules.
+
+    Parameters
+    ----------
+    x : NamedArray
+        Array to cast.
+    dtype : _DType
+        Desired data type.
+    copy : bool, optional
+        Specifies whether to copy an array when the specified dtype matches the data
+        type of the input array x.
+        If True, a newly allocated array must always be returned.
+        If False and the specified dtype matches the data type of the input array,
+        the input array must be returned; otherwise, a newly allocated array must be
+        returned. Default: True.
+
+    Returns
+    -------
+    out : NamedArray
+        An array having the specified data type. The returned array must have the
+        same shape as x.
+
+    Examples
+    --------
+    >>> narr = NamedArray(("x",), np.array([1.5, 2.5]))
+    >>> astype(narr, np.dtype(int)).data
+    array([1, 2])
+    """
+    if isinstance(x._data, _arrayapi):
+        xp = x._data.__array_namespace__()
+        return x._new(data=xp.astype(x, dtype, copy=copy))
+
+    # np.astype doesn't exist yet:
+    return x._new(data=x._data.astype(dtype, copy=copy))  # type: ignore[attr-defined]
+
+
+def imag(
+    x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], /  # type: ignore[type-var]
+) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
+    """
+    Returns the imaginary component of a complex number for each element x_i of the
+    input array x.
+
+    Parameters
+    ----------
+    x : NamedArray
+        Input array. Should have a complex floating-point data type.
+
+    Returns
+    -------
+    out : NamedArray
+        An array containing the element-wise results. The returned array must have a
+        floating-point data type with the same floating-point precision as x
+        (e.g., if x is complex64, the returned array must have the floating-point
+        data type float32).
+
+    Examples
+    --------
+    >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
+    >>> imag(narr).data
+    array([2., 4.])
+    """
+    xp = _get_data_namespace(x)
+    out = x._new(data=xp.imag(x._data))
+    return out
+
+
+def real(
+    x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], /  # type: ignore[type-var]
+) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
+    """
+    Returns the real component of a complex number for each element x_i of the
+    input array x.
+
+    Parameters
+    ----------
+    x : NamedArray
+        Input array. Should have a complex floating-point data type.
+
+    Returns
+    -------
+    out : NamedArray
+        An array containing the element-wise results. The returned array must have a
+        floating-point data type with the same floating-point precision as x
+        (e.g., if x is complex64, the returned array must have the floating-point
+        data type float32).
+
+    Examples
+    --------
+    >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
+    >>> real(narr).data
+    array([1., 2.])
+    """
+    xp = _get_data_namespace(x)
+    return x._new(data=xp.real(x._data))
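
Reviewer note: `_get_data_namespace` above picks the module that implements `real` and `imag` by checking whether the wrapped data exposes `__array_namespace__`, and falls back to `np` otherwise. The sketch below illustrates that dispatch using only the standard library. Everything in it is an assumption made for illustration: `SupportsArrayNamespace` is a hypothetical stand-in for the `_arrayapi` protocol (whose definition is not part of this excerpt), and `toy_xp` and `fallback_xp` are toy modules, with `fallback_xp` playing the role that `np` plays in the diff.

```python
from types import ModuleType
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsArrayNamespace(Protocol):
    """Hypothetical stand-in for the `_arrayapi` protocol used in the diff."""

    def __array_namespace__(self) -> ModuleType: ...


# Toy namespaces: one advertised by the array itself, one used as the fallback.
toy_xp = ModuleType("toy_xp")
toy_xp.real = lambda data: [complex(v).real for v in data.values]
fallback_xp = ModuleType("fallback_xp")
fallback_xp.real = lambda data: [complex(v).real for v in data]


class ToyArray:
    """Minimal container that advertises its own namespace."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self) -> ModuleType:
        return toy_xp


def get_namespace(data: Any) -> ModuleType:
    # isinstance() against a runtime_checkable Protocol only checks that the
    # method exists; it does not validate the signature.
    if isinstance(data, SupportsArrayNamespace):
        return data.__array_namespace__()
    return fallback_xp


def real(data: Any) -> list:
    # Same shape as the diff's real(): look up the namespace, then delegate.
    return get_namespace(data).real(data)


from_protocol = real(ToyArray([1 + 2j, 2 + 4j]))
from_fallback = real([1 + 2j, 2 + 4j])
```

Both calls go through the same `real` wrapper, and only the namespace lookup differs. That appears to be the design the new module aims for.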