Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implement map_over_datasets #2

Closed
wants to merge 10 commits into from
61 changes: 61 additions & 0 deletions DATATREE_MIGRATION_GUIDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Migration guide for users of `xarray-contrib/datatree`

_15th October 2024_

This guide is for previous users of the prototype `datatree.DataTree` class in the `xarray-contrib/datatree repository`. That repository has now been archived, and will not be maintained. This guide is intended to help smooth your transition to using the new, updated `xarray.DataTree` class.

> [!IMPORTANT]
> There are breaking changes! You should not expect that code written with `xarray-contrib/datatree` will work without any modifications. At the absolute minimum you will need to change the top-level import statement, but there are other changes too.

We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself, integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.

### Data model changes

The most important changes made are to the data model of `DataTree`. Whilst previously data in different nodes was unrelated and therefore unconstrained, now trees have "internal alignment" - meaning that dimensions and indexes in child nodes must exactly align with those in their parents.

These alignment checks happen at tree construction time, meaning there are some netCDF4 files and zarr stores that could previously be opened as `datatree.DataTree` objects using `datatree.open_datatree`, but now cannot be opened as `xr.DataTree` objects using `xr.open_datatree`. For these cases we added a new opener function `xr.open_groups`, which returns a `dict[str, Dataset]`. This is intended as a fallback for tricky cases, where the idea is that you can still open the entire contents of the file using `open_groups`, edit the `Dataset` objects, then construct a valid tree from the edited dictionary using `DataTree.from_dict`.

The alignment checks allowed us to add "Coordinate Inheritance", a much-requested feature where indexed coordinate variables are now "inherited" down to child nodes. This allows you to define common coordinates in a parent group that are then automatically available on every child node. The distinction between a locally-defined coordinate variables and an inherited coordinate that was defined on a parent node is reflected in the `DataTree.__repr__`. Generally if you prefer not to have these variables be inherited you can get more similar behaviour to the old `datatree` package by removing indexes from coordinates, as this prevents inheritance.

For further documentation see the page in the user guide on Hierarchical Data.

### Integrated backends

Previously `datatree.open_datatree` used a different codepath from `xarray.open_dataset`, and was hard-coded to only support opening netCDF files and Zarr stores.
Now xarray's backend entrypoint system has been generalized to include `open_datatree` and the new `open_groups`.
This means we can now extend other xarray backends to support `open_datatree`! If you are the maintainer of an xarray backend we encourage you to add support for `open_datatree` and `open_groups`!

Additionally:
- A `group` kwarg has been added to `open_datatree` for choosing which group in the file should become the root group of the created tree.
- Various performance improvements have been made, which should help when opening netCDF files and Zarr stores with large numbers of groups.
- We anticipate further performance improvements being possible for datatree IO.

### API changes

A number of other API changes have been made, which should only require minor modifications to your code:
- The top-level import has changed, from `from datatree import DataTree, open_datatree` to `from xarray import DataTree, open_datatree`. Alternatively you can now just use the `import xarray as xr` namespace convention for everything datatree-related.
- The `DataTree.ds` property has been changed to `DataTree.dataset`, though `DataTree.ds` remains as an alias for `DataTree.dataset`.
- Similarly the `ds` kwarg in the `DataTree.__init__` constructor has been replaced by `dataset`, i.e. use `DataTree(dataset=)` instead of `DataTree(ds=...)`.
- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherit=True/False`.
- `DataTree.copy()` also has a new `inherit` keyword argument for controlling whether or not coordinates defined on parents are copied (only relevant when copying a non-root node).
- The `DataTree.parent` property is now read-only. To assign a ancestral relationships directly you must instead use the `.children` property on the parent node, which remains settable.
- Similarly the `parent` kwarg has been removed from the `DataTree.__init__` constuctor.
- DataTree objects passed to the `children` kwarg in `DataTree.__init__` are now shallow-copied.
- `DataTree.as_array` has been replaced by `DataTree.to_dataarray`.
- A number of methods which were not well tested have been (temporarily) disabled. In general we have tried to only keep things that are known to work, with the plan to increase API surface incrementally after release.

## Thank you!

Thank you for trying out `xarray-contrib/datatree`!

We welcome contributions of any kind, including good ideas that never quite made it into the original datatree repository. Please also let us know if we have forgotten to mention a change that should have been listed in this guide.

Sincerely, the datatree team:

Tom Nicholas,
Owen Littlejohns,
Matt Savoie,
Eni Awowale,
Alfonso Ladino,
Justus Magin,
Stephan Hoyer
18 changes: 9 additions & 9 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -849,20 +849,20 @@ Aggregate data in all nodes in the subtree simultaneously.
DataTree.cumsum
DataTree.cumprod

.. ndarray methods
.. ---------------
ndarray methods
---------------

.. Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data in all nodes in the subtree.
Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data in all nodes in the subtree.

.. .. autosummary::
.. :toctree: generated/
.. autosummary::
:toctree: generated/

.. DataTree.argsort
DataTree.argsort
DataTree.conj
DataTree.conjugate
DataTree.round
.. DataTree.astype
.. DataTree.clip
.. DataTree.conj
.. DataTree.conjugate
.. DataTree.round
.. DataTree.rank

.. Reshaping and reorganising
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ New Features
`Matt Savoie <https://github.com/flamingbear>`_,
`Stephan Hoyer <https://github.com/shoyer>`_ and
`Tom Nicholas <https://github.com/TomNicholas>`_.
- A migration guide for users of the prototype `xarray-contrib/datatree repository <https://github.com/xarray-contrib/datatree>`_ has been added, and can be found in the `DATATREE_MIGRATION_GUIDE.md` file in the repository root.
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`).
By `Eni Awowale <https://github.com/eni-awowale>`_.
- Added support for vectorized interpolation using additional interpolators
Expand Down
162 changes: 162 additions & 0 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xarray.core.types import (
DaCompatible,
DsCompatible,
DtCompatible,
Self,
T_Xarray,
VarCompatible,
Expand All @@ -23,6 +24,167 @@
from xarray.core.types import T_DataArray as T_DA


class DataTreeOpsMixin:
__slots__ = ()

def _binary_op(
self, other: DtCompatible, f: Callable, reflexive: bool = False
) -> Self:
raise NotImplementedError

def __add__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.add)

def __sub__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.sub)

def __mul__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.mul)

def __pow__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.pow)

def __truediv__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.truediv)

def __floordiv__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.floordiv)

def __mod__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.mod)

def __and__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.and_)

def __xor__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.xor)

def __or__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.or_)

def __lshift__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.lshift)

def __rshift__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.rshift)

def __lt__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.lt)

def __le__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.le)

def __gt__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.gt)

def __ge__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.ge)

def __eq__(self, other: DtCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)

def __ne__(self, other: DtCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)

def __rsub__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)

def __rmul__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)

def __rpow__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)

def __rtruediv__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)

def __rfloordiv__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)

def __rmod__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)

def __rand__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)

def __rxor__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)

def __ror__(self, other: DtCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)

def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError

def __neg__(self) -> Self:
return self._unary_op(operator.neg)

def __pos__(self) -> Self:
return self._unary_op(operator.pos)

def __abs__(self) -> Self:
return self._unary_op(operator.abs)

def __invert__(self) -> Self:
return self._unary_op(operator.invert)

def round(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.round_, *args, **kwargs)

def argsort(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.argsort, *args, **kwargs)

def conj(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conj, *args, **kwargs)

def conjugate(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conjugate, *args, **kwargs)

__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
__pow__.__doc__ = operator.pow.__doc__
__truediv__.__doc__ = operator.truediv.__doc__
__floordiv__.__doc__ = operator.floordiv.__doc__
__mod__.__doc__ = operator.mod.__doc__
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
__ge__.__doc__ = operator.ge.__doc__
__eq__.__doc__ = nputils.array_eq.__doc__
__ne__.__doc__ = nputils.array_ne.__doc__
__radd__.__doc__ = operator.add.__doc__
__rsub__.__doc__ = operator.sub.__doc__
__rmul__.__doc__ = operator.mul.__doc__
__rpow__.__doc__ = operator.pow.__doc__
__rtruediv__.__doc__ = operator.truediv.__doc__
__rfloordiv__.__doc__ = operator.floordiv.__doc__
__rmod__.__doc__ = operator.mod.__doc__
__rand__.__doc__ = operator.and_.__doc__
__rxor__.__doc__ = operator.xor.__doc__
__ror__.__doc__ = operator.or_.__doc__
__neg__.__doc__ = operator.neg.__doc__
__pos__.__doc__ = operator.pos.__doc__
__abs__.__doc__ = operator.abs.__doc__
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__


class DatasetOpsMixin:
__slots__ = ()

Expand Down
3 changes: 2 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4765,9 +4765,10 @@ def _unary_op(self, f: Callable, *args, **kwargs) -> Self:
def _binary_op(
self, other: DaCompatible, f: Callable, reflexive: bool = False
) -> Self:
from xarray.core.datatree import DataTree
from xarray.core.groupby import GroupBy

if isinstance(other, Dataset | GroupBy):
if isinstance(other, DataTree | Dataset | GroupBy):
return NotImplemented
if isinstance(other, DataArray):
align_type = OPTIONS["arithmetic_join"]
Expand Down
9 changes: 4 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7784,9 +7784,10 @@ def _unary_op(self, f, *args, **kwargs) -> Self:

def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
from xarray.core.dataarray import DataArray
from xarray.core.datatree import DataTree
from xarray.core.groupby import GroupBy

if isinstance(other, GroupBy):
if isinstance(other, DataTree | GroupBy):
return NotImplemented
align_type = OPTIONS["arithmetic_join"] if join is None else join
if isinstance(other, DataArray | Dataset):
Expand Down Expand Up @@ -10375,10 +10376,8 @@ def groupby(
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
squeeze : bool, default: True
If "group" is a dimension of any arrays in this dataset, `squeeze`
controls whether the subarrays have a dimension of length 1 along
that dimension or if the dimension is squeezed out.
squeeze : False
This argument is deprecated.
restore_coord_dims : bool, default: False
If True, also restore the dimension order of multi-dimensional
coordinates.
Expand Down
Loading