Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expand_dims #8407

Merged
merged 19 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,7 +2596,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
"""
Use sparse-array as backend.
"""
from xarray.namedarray.utils import _default as _default_named
from xarray.namedarray._typing import _default as _default_named

if sparse_format is _default:
sparse_format = _default_named
Expand Down
52 changes: 52 additions & 0 deletions xarray/namedarray/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import numpy as np

from xarray.namedarray._typing import (
Default,
_arrayapi,
_Axis,
_default,
_Dim,
_DType,
_ScalarType,
_ShapeType,
Expand Down Expand Up @@ -144,3 +148,51 @@ def real(
xp = _get_data_namespace(x)
out = x._new(data=xp.real(x._data))
return out


# %% Manipulation functions
def expand_dims(
x: NamedArray[Any, _DType],
/,
*,
dim: _Dim | Default = _default,
axis: _Axis = 0,
) -> NamedArray[Any, _DType]:
"""
Expands the shape of an array by inserting a new dimension of size one at the
position specified by dims.

Parameters
----------
x :
Array to expand.
dim :
Dimension name. New dimension will be stored in the axis position.
axis :
(Not recommended) Axis position (zero-based). Default is 0.

Returns
-------
out :
An expanded output array having the same data type as x.

Examples
--------
>>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]]))
>>> expand_dims(x)
<xarray.NamedArray (dim_2: 1, x: 2, y: 2)>
Array([[[1., 2.],
[3., 4.]]], dtype=float64)
>>> expand_dims(x, dim="z")
<xarray.NamedArray (z: 1, x: 2, y: 2)>
Array([[[1., 2.],
[3., 4.]]], dtype=float64)
"""
xp = _get_data_namespace(x)
dims = x.dims
if dim is _default:
dim = f"dim_{len(dims)}"
d = list(dims)
d.insert(axis, dim)
out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
return out
14 changes: 14 additions & 0 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Mapping, Sequence
from enum import Enum
from types import ModuleType
from typing import (
Any,
Callable,
Final,
Protocol,
SupportsIndex,
TypeVar,
Expand All @@ -15,6 +17,14 @@

import numpy as np


# Singleton type, as per https://github.com/python/typing/pull/240
class Default(Enum):
token: Final = 0


_default = Default.token

# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
Expand Down Expand Up @@ -49,6 +59,10 @@ def dtype(self) -> _DType_co:
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True)

_Axis = int
_Axes = tuple[_Axis, ...]
_AxisLike = Union[_Axis, _Axes]

_Chunks = tuple[_Shape, ...]

_Dim = Hashable
Expand Down
5 changes: 3 additions & 2 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_arrayapi,
_arrayfunction_or_api,
_chunkedarray,
_default,
_dtype,
_DType_co,
_ScalarType_co,
Expand All @@ -33,13 +34,14 @@
_SupportsImag,
_SupportsReal,
)
from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array
from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array

if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray

from xarray.core.types import Dims
from xarray.namedarray._typing import (
Default,
_AttrsLike,
_Chunks,
_Dim,
Expand All @@ -52,7 +54,6 @@
_ShapeType,
duckarray,
)
from xarray.namedarray.utils import Default

try:
from dask.typing import (
Expand Down
15 changes: 1 addition & 14 deletions xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@

import sys
from collections.abc import Hashable
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Final,
)
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -31,14 +26,6 @@
DaskCollection: Any = NDArray # type: ignore


# Singleton type, as per https://github.com/python/typing/pull/240
class Default(Enum):
token: Final = 0


_default = Default.token


def module_available(module: str) -> bool:
"""Checks whether a module is installed without importing it.

Expand Down
10 changes: 7 additions & 3 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,27 @@
import pytest

from xarray.core.indexing import ExplicitlyIndexed
from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co
from xarray.namedarray._typing import (
_arrayfunction_or_api,
_default,
_DType_co,
_ShapeType_co,
)
from xarray.namedarray.core import NamedArray, from_array
from xarray.namedarray.utils import _default

if TYPE_CHECKING:
from types import ModuleType

from numpy.typing import ArrayLike, DTypeLike, NDArray

from xarray.namedarray._typing import (
Default,
_AttrsLike,
_DimsLike,
_DType,
_Shape,
duckarray,
)
from xarray.namedarray.utils import Default


class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):
Expand Down
Loading