Skip to content

Commit

Permalink
Support non-str Hashables in DataArray (#8559)
Browse files Browse the repository at this point in the history
* support hashable dims in DataArray
* add whats-new
* remove uneccessary except ImportErrors
* improve some typing
  • Loading branch information
headtr1ck authored Jan 14, 2024
1 parent 08c8f9a commit 357a444
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 49 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- Support non-string hashable dimensions in :py:class:`xarray.DataArray` (:issue:`8546`, :pull:`8559`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
Expand Down
59 changes: 26 additions & 33 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
Generic,
Literal,
NoReturn,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -61,6 +63,7 @@
ReprObject,
_default,
either_dict_or_kwargs,
hashable,
)
from xarray.core.variable import (
IndexVariable,
Expand All @@ -73,23 +76,11 @@
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

if TYPE_CHECKING:
from typing import TypeVar, Union

from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from iris.cube import Cube as iris_Cube
from numpy.typing import ArrayLike

try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore[misc,assignment]
try:
from iris.cube import Cube as iris_Cube
except ImportError:
iris_Cube = None

from xarray.backends import ZarrStore
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
from xarray.core.groupby import DataArrayGroupBy
Expand Down Expand Up @@ -140,7 +131,9 @@ def _check_coords_dims(shape, coords, dim):


def _infer_coords_and_dims(
shape, coords, dims
shape: tuple[int, ...],
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
dims: str | Iterable[Hashable] | None,
) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]:
"""All the logic for creating a new DataArray"""

Expand All @@ -157,8 +150,7 @@ def _infer_coords_and_dims(

if isinstance(dims, str):
dims = (dims,)

if dims is None:
elif dims is None:
dims = [f"dim_{n}" for n in range(len(shape))]
if coords is not None and len(coords) == len(shape):
# try to infer dimensions from coords
Expand All @@ -168,16 +160,15 @@ def _infer_coords_and_dims(
for n, (dim, coord) in enumerate(zip(dims, coords)):
coord = as_variable(coord, name=dims[n]).to_index_variable()
dims[n] = coord.name
dims = tuple(dims)
elif len(dims) != len(shape):
dims_tuple = tuple(dims)
if len(dims_tuple) != len(shape):
raise ValueError(
"different number of dimensions on data "
f"and dims: {len(shape)} vs {len(dims)}"
f"and dims: {len(shape)} vs {len(dims_tuple)}"
)
else:
for d in dims:
if not isinstance(d, str):
raise TypeError(f"dimension {d} is not a string")
for d in dims_tuple:
if not hashable(d):
raise TypeError(f"Dimension {d} is not hashable")

new_coords: Mapping[Hashable, Any]

Expand All @@ -189,17 +180,21 @@ def _infer_coords_and_dims(
for k, v in coords.items():
new_coords[k] = as_variable(v, name=k)
elif coords is not None:
for dim, coord in zip(dims, coords):
for dim, coord in zip(dims_tuple, coords):
var = as_variable(coord, name=dim)
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims)
_check_coords_dims(shape, new_coords, dims_tuple)

return new_coords, dims
return new_coords, dims_tuple


def _check_data_shape(data, coords, dims):
def _check_data_shape(
data: Any,
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
dims: str | Iterable[Hashable] | None,
) -> Any:
if data is dtypes.NA:
data = np.nan
if coords is not None and utils.is_scalar(data, include_0d=False):
Expand Down Expand Up @@ -405,10 +400,8 @@ class DataArray(
def __init__(
self,
data: Any = dtypes.NA,
coords: Sequence[Sequence[Any] | pd.Index | DataArray]
| Mapping[Any, Any]
| None = None,
dims: Hashable | Sequence[Hashable] | None = None,
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
dims: str | Iterable[Hashable] | None = None,
name: Hashable | None = None,
attrs: Mapping | None = None,
# internal parameters
Expand Down
11 changes: 2 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@
from xarray.util.deprecation_helpers import _deprecate_positional_args

if TYPE_CHECKING:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from numpy.typing import ArrayLike

from xarray.backends import AbstractDataStore, ZarrStore
Expand Down Expand Up @@ -164,15 +166,6 @@
)
from xarray.core.weighted import DatasetWeighted

try:
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore[misc,assignment]
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
_DATETIMEINDEX_COMPONENTS = [
Expand Down
11 changes: 6 additions & 5 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def _importorskip(
raise ImportError("Minimum version not satisfied")
except ImportError:
has = False
func = pytest.mark.skipif(not has, reason=f"requires {modname}")

reason = f"requires {modname}"
if minversion is not None:
reason += f">={minversion}"
func = pytest.mark.skipif(not has, reason=reason)
return has, func


Expand Down Expand Up @@ -122,10 +126,7 @@ def _importorskip(
not has_pandas_version_two, reason="requires pandas 2.0.0"
)
has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0")
has_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")
requires_h5netcdf_ros3 = pytest.mark.skipif(
not has_h5netcdf_ros3[0], reason="requires h5netcdf 1.3.0"
)
has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")

has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip(
"netCDF4", "1.6.2"
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ def test_constructor_invalid(self) -> None:
with pytest.raises(ValueError, match=r"not a subset of the .* dim"):
DataArray(data, {"x": [0, 1, 2]})

with pytest.raises(TypeError, match=r"is not a string"):
DataArray(data, dims=["x", None])
with pytest.raises(TypeError, match=r"is not hashable"):
DataArray(data, dims=["x", []]) # type: ignore[list-item]

with pytest.raises(ValueError, match=r"conflicting sizes for dim"):
DataArray([1, 2, 3], coords=[("x", [0, 1])])
Expand Down
53 changes: 53 additions & 0 deletions xarray/tests/test_hashable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Union

import pytest

from xarray import DataArray, Dataset, Variable

if TYPE_CHECKING:
from xarray.core.types import TypeAlias

DimT: TypeAlias = Union[int, tuple, "DEnum", "CustomHashable"]


class DEnum(Enum):
dim = "dim"


class CustomHashable:
def __init__(self, a: int) -> None:
self.a = a

def __hash__(self) -> int:
return self.a


parametrize_dim = pytest.mark.parametrize(
"dim",
[
pytest.param(5, id="int"),
pytest.param(("a", "b"), id="tuple"),
pytest.param(DEnum.dim, id="enum"),
pytest.param(CustomHashable(3), id="HashableObject"),
],
)


@parametrize_dim
def test_hashable_dims(dim: DimT) -> None:
v = Variable([dim], [1, 2, 3])
da = DataArray([1, 2, 3], dims=[dim])
Dataset({"a": ([dim], [1, 2, 3])})

# alternative constructors
DataArray(v)
Dataset({"a": v})
Dataset({"a": da})


@parametrize_dim
def test_dataset_variable_hashable_names(dim: DimT) -> None:
Dataset({dim: ("x", [1, 2, 3])})

0 comments on commit 357a444

Please sign in to comment.